Files
letta-server/letta/services/context_window_calculator/token_counter.py
2025-12-15 12:02:19 -08:00

318 lines
12 KiB
Python

import hashlib
import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.anthropic_client import AnthropicClient
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.enums import ProviderType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
if TYPE_CHECKING:
from letta.schemas.llm_config import LLMConfig
from letta.schemas.user import User
logger = get_logger(__name__)
class TokenCounter(ABC):
    """Abstract base class for token counting strategies.

    Concrete subclasses implement provider-specific counting (Anthropic API,
    Gemini API, tiktoken, or a byte-length heuristic) behind one async
    interface so callers can size context windows without knowing the provider.
    """

    @abstractmethod
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a raw text string."""

    @abstractmethod
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in a list of provider-formatted message dicts."""

    @abstractmethod
    async def count_tool_tokens(self, tools: List[Any]) -> int:
        """Count tokens consumed by tool/function definitions."""

    @abstractmethod
    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert messages to the appropriate format for this counter."""
class AnthropicTokenCounter(TokenCounter):
    """Token counter that delegates counting to Anthropic's API.

    Results are cached in Redis for an hour, keyed by model plus a truncated
    SHA-256 digest of the input, so repeated context-window calculations do
    not re-hit the API.
    """

    def __init__(self, anthropic_client: AnthropicClient, model: str):
        self.client = anthropic_client
        self.model = model

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a bare string by wrapping it as a single user message."""
        if not text:
            return 0
        wrapped = [{"role": "user", "content": text}]
        return await self.client.count_tokens(model=self.model, messages=wrapped)

    @trace_method
    @async_redis_cache(
        key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in Anthropic-format message dicts via the API."""
        if not messages:
            return 0
        return await self.client.count_tokens(model=self.model, messages=messages)

    @trace_method
    @async_redis_cache(
        key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count tokens for tool definitions via the API."""
        if not tools:
            return 0
        return await self.client.count_tokens(model=self.model, tools=tools)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to Anthropic's dict format."""
        return Message.to_anthropic_dicts_from_list(messages, current_model=self.model)
class ApproxTokenCounter(TokenCounter):
    """Fast approximate token counter using a byte-based heuristic (bytes / 4).

    Mirrors the codex-cli approach: assume roughly 4 UTF-8 bytes per token
    for English text. Much faster than tiktoken and loads no tokenizer
    models into memory — inputs are serialized to JSON and the byte length
    is divided by 4.
    """

    APPROX_BYTES_PER_TOKEN = 4

    def __init__(self, model: str | None = None):
        # The model name is purely informational: no tokenizer is ever loaded.
        self.model = model

    def _approx_token_count(self, text: str) -> int:
        """Approximate token count as ceil(utf-8 byte length / 4)."""
        if not text:
            return 0
        n_bytes = len(text.encode("utf-8"))
        # Ceiling division via negation; avoids importing math.
        return -(-n_bytes // self.APPROX_BYTES_PER_TOKEN)

    async def count_text_tokens(self, text: str) -> int:
        # Empty input is already handled inside the helper.
        return self._approx_token_count(text)

    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        return self._approx_token_count(json.dumps(messages)) if messages else 0

    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        if not tools:
            return 0
        serialized = json.dumps([tool.model_dump() for tool in tools])
        return self._approx_token_count(serialized)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        return Message.to_openai_dicts_from_list(messages)
class GeminiTokenCounter(TokenCounter):
    """Token counter backed by Google's Gemini token counting API.

    Results are cached in Redis for an hour, keyed by model plus a truncated
    SHA-256 digest of the input.
    """

    def __init__(self, gemini_client: GoogleVertexClient, model: str):
        self.client = gemini_client
        self.model = model

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"gemini_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a bare string via the Gemini API."""
        if not text:
            return 0
        # Google expects {"role", "parts": [{"text": ...}]}-shaped messages.
        wrapped = [{"role": "user", "parts": [{"text": text}]}]
        return await self.client.count_tokens(model=self.model, messages=wrapped)

    @trace_method
    @async_redis_cache(
        key_func=lambda self, messages: f"gemini_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in Google-format message dicts via the Gemini API."""
        if not messages:
            return 0
        return await self.client.count_tokens(model=self.model, messages=messages)

    @trace_method
    @async_redis_cache(
        key_func=lambda self, tools: f"gemini_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count tokens for tool definitions via the Gemini API."""
        if not tools:
            return 0
        return await self.client.count_tokens(model=self.model, tools=tools)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to Google's dict format."""
        return Message.to_google_dicts_from_list(messages, current_model=self.model)
class TiktokenCounter(TokenCounter):
    """Token counter using tiktoken (local tokenization, no API calls).

    Fix: the original re-imported ``get_logger`` and rebuilt a local
    ``logger`` on every call to ``count_text_tokens`` / ``count_message_tokens``,
    shadowing the module-level logger — the module-level logger is used
    directly now. Debug logging also uses lazy %-style arguments so the
    formatting cost is skipped when DEBUG is disabled.
    """

    def __init__(self, model: str):
        self.model = model

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count tokens in a text string using the model's tiktoken encoding.

        Falls back to the cl100k_base encoding when the model name is not
        known to tiktoken. Re-raises any tokenization failure after logging.
        """
        if not text:
            return 0
        text_length = len(text)
        text_preview = text[:100] + "..." if text_length > 100 else text
        logger.debug(
            "TiktokenCounter.count_text_tokens: model=%s, text_length=%d, preview=%r",
            self.model,
            text_length,
            text_preview,
        )
        try:
            import tiktoken  # imported lazily: loading encodings is expensive

            try:
                encoding = tiktoken.encoding_for_model(self.model)
            except KeyError:
                logger.debug("Model %s not found in tiktoken. Using cl100k_base encoding.", self.model)
                encoding = tiktoken.get_encoding("cl100k_base")
            result = len(encoding.encode(text))
            logger.debug("TiktokenCounter.count_text_tokens: completed successfully, tokens=%d", result)
            return result
        except Exception as e:
            logger.error(
                "TiktokenCounter.count_text_tokens: FAILED with %s: %s, text_length=%d",
                type(e).__name__,
                e,
                text_length,
            )
            raise

    @trace_method
    @async_redis_cache(
        key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens in OpenAI-format message dicts via tiktoken.

        Delegates to ``num_tokens_from_messages``; re-raises failures after logging.
        """
        if not messages:
            return 0
        num_messages = len(messages)
        total_content_length = sum(len(str(m.get("content", ""))) for m in messages)
        logger.debug(
            "TiktokenCounter.count_message_tokens: model=%s, num_messages=%d, total_content_length=%d",
            self.model,
            num_messages,
            total_content_length,
        )
        try:
            from letta.local_llm.utils import num_tokens_from_messages

            result = num_tokens_from_messages(messages=messages, model=self.model)
            logger.debug("TiktokenCounter.count_message_tokens: completed successfully, tokens=%d", result)
            return result
        except Exception as e:
            logger.error(
                "TiktokenCounter.count_message_tokens: FAILED with %s: %s, num_messages=%d",
                type(e).__name__,
                e,
                num_messages,
            )
            raise

    @trace_method
    @async_redis_cache(
        key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count tokens consumed by tool definitions via tiktoken."""
        if not tools:
            return 0
        from letta.local_llm.utils import num_tokens_from_functions

        # Extract function definitions from OpenAITool objects
        functions = [t.function.model_dump() for t in tools]
        return num_tokens_from_functions(functions=functions, model=self.model)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to OpenAI's dict format."""
        return Message.to_openai_dicts_from_list(messages)
def create_token_counter(
    model_endpoint_type: ProviderType,
    model: Optional[str] = None,
    actor: Optional["User"] = None,
    agent_id: Optional[str] = None,
) -> "TokenCounter":
    """
    Factory function to create the appropriate token counter based on model configuration.

    Args:
        model_endpoint_type: Provider type of the agent's LLM endpoint.
        model: Model name, forwarded to provider-backed counters.
        actor: User performing the request; passed through to LLM client construction.
        agent_id: Agent id, used only for log context.

    Returns:
        The appropriate TokenCounter instance
    """
    from letta.llm_api.llm_client import LLMClient
    from letta.settings import settings

    # Use Gemini token counter for Google Vertex and Google AI
    use_gemini = model_endpoint_type in ("google_vertex", "google_ai")
    # Use Anthropic token counter when the model endpoint type is anthropic
    use_anthropic = model_endpoint_type == "anthropic"
    if use_gemini:
        client = LLMClient.create(provider_type=model_endpoint_type, actor=actor)
        token_counter = GeminiTokenCounter(client, model)
        logger.info(
            f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, "
            f"model_endpoint_type={model_endpoint_type}, "
            f"environment={settings.environment}"
        )
    elif use_anthropic:
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
        # This branch is only reached when model_endpoint_type == "anthropic",
        # so the counter always uses the caller-supplied model name.
        token_counter = AnthropicTokenCounter(anthropic_client, model)
        logger.info(
            f"Using AnthropicTokenCounter for agent_id={agent_id}, model={model}, "
            f"model_endpoint_type={model_endpoint_type}, "
            f"environment={settings.environment}"
        )
    else:
        # No provider-backed counter for this endpoint type; fall back to the
        # fast byte-length heuristic.
        token_counter = ApproxTokenCounter()
        logger.info(
            f"Using ApproxTokenCounter for agent_id={agent_id}, model={model}, "
            f"model_endpoint_type={model_endpoint_type}, "
            f"environment={settings.environment}"
        )
    return token_counter