feat: Cache anthropic tokenizer (#3286)
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Any, List, Optional, Tuple
|
||||
from openai.types.beta.function_tool import FunctionTool as OpenAITool
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import ContextWindowOverview
|
||||
@@ -56,7 +57,7 @@ class ContextWindowCalculator:
|
||||
return None, 1
|
||||
|
||||
async def calculate_context_window(
|
||||
self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
|
||||
self, agent_state: AgentState, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
|
||||
) -> ContextWindowOverview:
|
||||
"""Calculate context window information using the provided token counter"""
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import hashlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.utils import count_tokens
|
||||
@@ -33,16 +36,31 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
self.client = anthropic_client
|
||||
self.model = model
|
||||
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_text_tokens(self, text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}])
|
||||
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
|
||||
if not messages:
|
||||
return 0
|
||||
return await self.client.count_tokens(model=self.model, messages=messages)
|
||||
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
@@ -58,11 +76,21 @@ class TiktokenCounter(TokenCounter):
|
||||
def __init__(self, model: str):
|
||||
self.model = model
|
||||
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_text_tokens(self, text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return count_tokens(text)
|
||||
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
|
||||
if not messages:
|
||||
return 0
|
||||
@@ -70,6 +98,11 @@ class TiktokenCounter(TokenCounter):
|
||||
|
||||
return num_tokens_from_messages(messages=messages, model=self.model)
|
||||
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user