diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py
index 47a9aacd..b3a89028 100644
--- a/letta/services/context_window_calculator/context_window_calculator.py
+++ b/letta/services/context_window_calculator/context_window_calculator.py
@@ -4,6 +4,7 @@ from typing import Any, List, Optional, Tuple
 from openai.types.beta.function_tool import FunctionTool as OpenAITool
 
 from letta.log import get_logger
+from letta.schemas.agent import AgentState
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message_content import TextContent
 from letta.schemas.memory import ContextWindowOverview
@@ -56,7 +57,7 @@ class ContextWindowCalculator:
         return None, 1
 
     async def calculate_context_window(
-        self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
+        self, agent_state: AgentState, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
     ) -> ContextWindowOverview:
         """Calculate context window information using the provided token counter"""
 
diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py
index 3e1de4f7..1ec4a3fc 100644
--- a/letta/services/context_window_calculator/token_counter.py
+++ b/letta/services/context_window_calculator/token_counter.py
@@ -1,6 +1,9 @@
+import hashlib
+import json
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
+from letta.helpers.decorators import async_redis_cache
 from letta.llm_api.anthropic_client import AnthropicClient
 from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
 from letta.utils import count_tokens
@@ -33,16 +36,31 @@ class AnthropicTokenCounter(TokenCounter):
         self.client = anthropic_client
         self.model = model
 
+    @async_redis_cache(
+        key_func=lambda self, text: f"anthropic_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
+        prefix="token_counter",
+        ttl_s=3600,  # cache for 1 hour
+    )
     async def count_text_tokens(self, text: str) -> int:
         if not text:
             return 0
         return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}])
 
+    @async_redis_cache(
+        key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
+        prefix="token_counter",
+        ttl_s=3600,  # cache for 1 hour
+    )
     async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
         if not messages:
             return 0
         return await self.client.count_tokens(model=self.model, messages=messages)
 
+    @async_redis_cache(
+        key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
+        prefix="token_counter",
+        ttl_s=3600,  # cache for 1 hour
+    )
     async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
         if not tools:
             return 0
@@ -58,11 +76,21 @@ class TiktokenCounter(TokenCounter):
     def __init__(self, model: str):
         self.model = model
 
+    @async_redis_cache(
+        key_func=lambda self, text: f"tiktoken_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}",
+        prefix="token_counter",
+        ttl_s=3600,  # cache for 1 hour
+    )
     async def count_text_tokens(self, text: str) -> int:
         if not text:
             return 0
         return count_tokens(text)
 
+    @async_redis_cache(
+        key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
+        prefix="token_counter",
+        ttl_s=3600,  # cache for 1 hour
+    )
     async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
         if not messages:
             return 0
@@ -70,6 +98,11 @@
 
         return num_tokens_from_messages(messages=messages, model=self.model)
 
+    @async_redis_cache(
+        key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
+        prefix="token_counter",
+        ttl_s=3600,  # cache for 1 hour
+    )
     async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
         if not tools:
             return 0