feat: Cache anthropic tokenizer (#3286)

This commit is contained in:
Matthew Zhou
2025-07-10 18:12:23 -07:00
committed by GitHub
parent c94b227a32
commit fc98ff3cc6
2 changed files with 35 additions and 1 deletions

View File

@@ -4,6 +4,7 @@ from typing import Any, List, Optional, Tuple
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview
@@ -56,7 +57,7 @@ class ContextWindowCalculator:
return None, 1
async def calculate_context_window(
self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
self, agent_state: AgentState, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
) -> ContextWindowOverview:
"""Calculate context window information using the provided token counter"""

View File

@@ -1,6 +1,9 @@
import hashlib
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.anthropic_client import AnthropicClient
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.utils import count_tokens
@@ -33,16 +36,31 @@ class AnthropicTokenCounter(TokenCounter):
self.client = anthropic_client
self.model = model
@async_redis_cache(
    key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
    prefix="token_counter",
    ttl_s=3600,  # entries expire after one hour
)
async def count_text_tokens(self, text: str) -> int:
    """Count the tokens in *text* via the Anthropic count-tokens API.

    The result is memoized in Redis, keyed by model plus a truncated
    SHA-256 of the text, so repeated counts of identical text skip the
    network round-trip. Empty text short-circuits to 0.
    """
    if text:
        # Anthropic's counter takes a message list, so wrap the raw text
        # as a single user message.
        payload = [{"role": "user", "content": text}]
        return await self.client.count_tokens(model=self.model, messages=payload)
    return 0
@async_redis_cache(
    key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
    prefix="token_counter",
    ttl_s=3600,  # entries expire after one hour
)
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
    """Count the tokens in a message list via the Anthropic count-tokens API.

    Cached in Redis keyed by model plus a truncated SHA-256 of the
    canonical (sort_keys) JSON serialization of the messages, so the key
    is stable regardless of dict key ordering. An empty list is 0.
    """
    if not messages:
        return 0
    total = await self.client.count_tokens(model=self.model, messages=messages)
    return total
# Count tokens for a list of tool definitions, memoized in Redis. The cache
# key hashes the canonical JSON of each tool's model_dump(), so equivalent
# tool sets share a cache entry regardless of dict ordering.
# NOTE(review): this diff hunk ends mid-method — the non-empty success path
# is not visible here; only the empty-input guard is shown.
@async_redis_cache(
key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
prefix="token_counter",
ttl_s=3600,  # cache for 1 hour
)
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0
@@ -58,11 +76,21 @@ class TiktokenCounter(TokenCounter):
def __init__(self, model: str):
self.model = model
@async_redis_cache(
    key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
    prefix="token_counter",
    ttl_s=3600,  # entries expire after one hour
)
async def count_text_tokens(self, text: str) -> int:
    """Count the tokens in *text* using the local tiktoken-based counter.

    The count itself is computed locally by ``count_tokens``; the Redis
    cache (keyed by model + truncated SHA-256 of the text) avoids
    recomputing it for repeated identical inputs. Empty text is 0.
    """
    return count_tokens(text) if text else 0
# Count tokens for a message list with the local tiktoken-based counter,
# memoized in Redis. The key hashes the canonical (sort_keys) JSON of the
# messages so dict ordering does not fragment the cache.
# NOTE(review): this method is split across a diff hunk boundary — the
# success-path return (num_tokens_from_messages) appears after the next
# hunk header, outside this span.
@async_redis_cache(
key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
prefix="token_counter",
ttl_s=3600,  # cache for 1 hour
)
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
if not messages:
return 0
@@ -70,6 +98,11 @@ class TiktokenCounter(TokenCounter):
return num_tokens_from_messages(messages=messages, model=self.model)
# Count tokens for a list of tool definitions with the local tiktoken-based
# counter, memoized in Redis. The key hashes the canonical JSON of each
# tool's model_dump() so equivalent tool sets hit the same cache entry.
# NOTE(review): the diff is truncated here — the non-empty success path is
# not visible; only the empty-input guard is shown.
@async_redis_cache(
key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
prefix="token_counter",
ttl_s=3600,  # cache for 1 hour
)
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0