"""Token counting strategies for different LLM providers."""
import hashlib
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
|
|
from letta.helpers.decorators import async_redis_cache
|
|
from letta.llm_api.anthropic_client import AnthropicClient
|
|
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
|
from letta.log import get_logger
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.enums import ProviderType
|
|
from letta.schemas.message import Message
|
|
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
|
|
|
if TYPE_CHECKING:
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.user import User
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class TokenCounter(ABC):
    """Abstract base class for token counting strategies.

    Each concrete counter pairs a message-conversion step (``convert_messages``,
    which produces the provider-specific dict format) with async token counts
    for raw text, message lists, and tool definitions.
    """

    @abstractmethod
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a text string."""

    @abstractmethod
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in a list of messages (already in this counter's dict format)."""

    @abstractmethod
    async def count_tool_tokens(self, tools: List[Any]) -> int:
        """Count tokens in tool definitions."""

    @abstractmethod
    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to the appropriate dict format for this counter."""
|
|
|
|
|
|
class AnthropicTokenCounter(TokenCounter):
    """Token counter using Anthropic's API.

    All counts are delegated to ``AnthropicClient.count_tokens`` and cached in
    redis for an hour, keyed on the model plus a truncated SHA-256 of the input.
    """

    def __init__(self, anthropic_client: AnthropicClient, model: str):
        # Both attributes are read by the cache key lambdas below, so they
        # must be set before any counting method is called.
        self.client = anthropic_client
        self.model = model

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a raw string by wrapping it in a single user message."""
        if not text:
            return 0
        return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}])

    @trace_method
    @async_redis_cache(
        key_func=lambda self,
        messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in a list of Anthropic-format message dicts."""
        if not messages:
            return 0
        return await self.client.count_tokens(model=self.model, messages=messages)

    @trace_method
    @async_redis_cache(
        key_func=lambda self,
        tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count tokens in tool definitions (passed through to the client as-is)."""
        if not tools:
            return 0
        return await self.client.count_tokens(model=self.model, tools=tools)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to Anthropic message dicts."""
        return Message.to_anthropic_dicts_from_list(messages, current_model=self.model)
|
|
|
|
|
|
class ApproxTokenCounter(TokenCounter):
    """Cheap token estimator based on serialized byte length.

    Uses the same heuristic as codex-cli: assume roughly 4 bytes of UTF-8
    per token, which is a reasonable average for English text. This avoids
    both the latency of tiktoken and the memory cost of loading tokenizer
    models — inputs are simply JSON-serialized and their byte length is
    divided by 4, rounded up.
    """

    # Heuristic divisor: average number of UTF-8 bytes per token.
    APPROX_BYTES_PER_TOKEN = 4

    def __init__(self, model: str | None = None):
        # No tokenizer is ever loaded, so the model name is optional and
        # kept only for interface parity with the other counters.
        self.model = model

    def _approx_token_count(self, text: str) -> int:
        """Return ceil(utf8_byte_length / 4); empty input counts as 0 tokens."""
        if not text:
            return 0
        byte_count = len(text.encode("utf-8"))
        # Ceiling division via negated floor division.
        return -(-byte_count // self.APPROX_BYTES_PER_TOKEN)

    async def count_text_tokens(self, text: str) -> int:
        """Approximate the token count of a raw string."""
        return self._approx_token_count(text) if text else 0

    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Approximate the token count of a list of message dicts."""
        if not messages:
            return 0
        return self._approx_token_count(json.dumps(messages))

    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Approximate the token count of tool definitions."""
        if not tools:
            return 0
        serialized = json.dumps([tool.model_dump() for tool in tools])
        return self._approx_token_count(serialized)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Approximate counting operates on OpenAI-format message dicts."""
        return Message.to_openai_dicts_from_list(messages)
|
|
|
|
|
|
class GeminiTokenCounter(TokenCounter):
    """Token counter using Google's Gemini token counting API.

    All counts are delegated to ``GoogleVertexClient.count_tokens`` and cached
    in redis for an hour, keyed on the model plus a truncated SHA-256 of the
    input.
    """

    def __init__(self, gemini_client: GoogleVertexClient, model: str):
        # Both attributes are read by the cache key lambdas below, so they
        # must be set before any counting method is called.
        self.client = gemini_client
        self.model = model

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"gemini_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a raw string via Google's API."""
        if not text:
            return 0
        # For text counting, wrap in a simple user message format for Google
        return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "parts": [{"text": text}]}])

    @trace_method
    @async_redis_cache(
        key_func=lambda self,
        messages: f"gemini_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in a list of Google-format message dicts."""
        if not messages:
            return 0
        return await self.client.count_tokens(model=self.model, messages=messages)

    @trace_method
    @async_redis_cache(
        key_func=lambda self,
        tools: f"gemini_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count tokens in tool definitions (passed through to the client as-is)."""
        if not tools:
            return 0
        return await self.client.count_tokens(model=self.model, tools=tools)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to Google message dicts."""
        google_messages = Message.to_google_dicts_from_list(messages, current_model=self.model)
        return google_messages
|
|
|
|
|
|
class TiktokenCounter(TokenCounter):
    """Token counter using tiktoken.

    Counts are computed locally (no API calls) and cached in redis for an
    hour, keyed on the model plus a truncated SHA-256 of the input.
    """

    def __init__(self, model: str):
        # Used to select the tiktoken encoding and to build cache keys.
        self.model = model

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a text string with tiktoken.

        Falls back to the cl100k_base encoding when the model name is not
        known to tiktoken. Re-raises any failure after logging it.
        """
        # NOTE: previously this method re-imported get_logger and built a
        # local logger, shadowing the module-level one; use the module logger.
        if not text:
            return 0

        text_length = len(text)
        text_preview = text[:100] + "..." if len(text) > 100 else text
        logger.debug(f"TiktokenCounter.count_text_tokens: model={self.model}, text_length={text_length}, preview={repr(text_preview)}")

        try:
            # Imported lazily so tokenizer data is only loaded when needed.
            import tiktoken

            try:
                encoding = tiktoken.encoding_for_model(self.model)
            except KeyError:
                logger.debug(f"Model {self.model} not found in tiktoken. Using cl100k_base encoding.")
                encoding = tiktoken.get_encoding("cl100k_base")
            result = len(encoding.encode(text))
            logger.debug(f"TiktokenCounter.count_text_tokens: completed successfully, tokens={result}")
            return result
        except Exception as e:
            logger.error(f"TiktokenCounter.count_text_tokens: FAILED with {type(e).__name__}: {e}, text_length={text_length}")
            raise

    @trace_method
    @async_redis_cache(
        key_func=lambda self,
        messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens across a list of OpenAI-format message dicts.

        Delegates to letta's num_tokens_from_messages helper; re-raises any
        failure after logging it.
        """
        if not messages:
            return 0

        num_messages = len(messages)
        total_content_length = sum(len(str(m.get("content", ""))) for m in messages)
        logger.debug(
            f"TiktokenCounter.count_message_tokens: model={self.model}, num_messages={num_messages}, total_content_length={total_content_length}"
        )

        try:
            from letta.local_llm.utils import num_tokens_from_messages

            result = num_tokens_from_messages(messages=messages, model=self.model)
            logger.debug(f"TiktokenCounter.count_message_tokens: completed successfully, tokens={result}")
            return result
        except Exception as e:
            logger.error(f"TiktokenCounter.count_message_tokens: FAILED with {type(e).__name__}: {e}, num_messages={num_messages}")
            raise

    @trace_method
    @async_redis_cache(
        key_func=lambda self,
        tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count tokens in tool definitions via their function schemas."""
        if not tools:
            return 0
        from letta.local_llm.utils import num_tokens_from_functions

        # Extract function definitions from OpenAITool objects
        functions = [t.function.model_dump() for t in tools]
        return num_tokens_from_functions(functions=functions, model=self.model)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Tiktoken counts against OpenAI-format message dicts."""
        return Message.to_openai_dicts_from_list(messages)
|
|
|
|
|
|
def create_token_counter(
    model_endpoint_type: ProviderType,
    model: Optional[str] = None,
    actor: Optional["User"] = None,
    agent_id: Optional[str] = None,
) -> "TokenCounter":
    """
    Factory function to create the appropriate token counter based on model configuration.

    Args:
        model_endpoint_type: Provider type of the model endpoint.
        model: Model name, passed through to provider-specific counters.
        actor: User used to construct the provider LLM client.
        agent_id: Agent identifier, used only for logging.

    Returns:
        The appropriate TokenCounter instance: GeminiTokenCounter for Google
        endpoints, AnthropicTokenCounter for anthropic endpoints, and the
        byte-length ApproxTokenCounter for everything else.
    """
    from letta.llm_api.llm_client import LLMClient
    from letta.settings import settings

    if model_endpoint_type in ("google_vertex", "google_ai"):
        # Google endpoints count via Gemini's token counting API.
        client = LLMClient.create(provider_type=model_endpoint_type, actor=actor)
        token_counter: TokenCounter = GeminiTokenCounter(client, model)
    elif model_endpoint_type == "anthropic":
        # Anthropic endpoints count via Anthropic's API.
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
        token_counter = AnthropicTokenCounter(anthropic_client, model)
    else:
        # Everything else falls back to the fast byte-based approximation.
        token_counter = ApproxTokenCounter()

    logger.info(
        f"Using {type(token_counter).__name__} for agent_id={agent_id}, model={model}, "
        f"model_endpoint_type={model_endpoint_type}, "
        f"environment={settings.environment}"
    )

    return token_counter
|