From ebccd8176a957bb0a807fed13d174f4f6196565a Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Tue, 3 Jun 2025 20:56:39 -0700
Subject: [PATCH] fix: Add additional testing for anthropic token counting
 (#2619)

Review fix folded in: dropped the leftover debugging hunk that wrapped the
Anthropic count_tokens call in a bare `except:` with `import ipdb;
ipdb.set_trace()` — debug-only code must not ship (bare except swallows
everything, and ipdb is an undeclared dev dependency). Subsequent hunk
offsets and the diffstat are adjusted accordingly.
---
 letta/llm_api/anthropic_client.py             | 16 ++++++++--------
 .../token_counter.py                          |  5 ++-
 tests/helpers/utils.py                        |  8 +++-
 tests/test_managers.py                        | 37 +++++++++++++++++--
 4 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py
index eefee965..3a60199a 100644
--- a/letta/llm_api/anthropic_client.py
+++ b/letta/llm_api/anthropic_client.py
@@ -30,7 +30,7 @@ from letta.log import get_logger
 from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
-from letta.schemas.openai.chat_completion_request import Tool
+from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
 from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
 from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
@@ -199,10 +199,10 @@ class AnthropicClient(LLMClientBase):
         elif llm_config.enable_reasoner:
             # NOTE: reasoning models currently do not allow for `any`
             tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
-            tools_for_request = [Tool(function=f) for f in tools]
+            tools_for_request = [OpenAITool(function=f) for f in tools]
         elif force_tool_call is not None:
             tool_choice = {"type": "tool", "name": force_tool_call}
-            tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call]
+            tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
 
             # need to have this setting to be able to put inner thoughts in kwargs
             if not llm_config.put_inner_thoughts_in_kwargs:
@@ -216,7 +216,7 @@ class AnthropicClient(LLMClientBase):
                 tool_choice = {"type": "any", "disable_parallel_tool_use": True}
             else:
                 tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
-        tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
+        tools_for_request = [OpenAITool(function=f) for f in tools] if tools is not None else None
 
         # Add tool choice
         if tool_choice:
@@ -230,7 +230,7 @@ class AnthropicClient(LLMClientBase):
                 inner_thoughts_key=INNER_THOUGHTS_KWARG,
                 inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
             )
-            tools_for_request = [Tool(function=f) for f in tools_with_inner_thoughts]
+            tools_for_request = [OpenAITool(function=f) for f in tools_with_inner_thoughts]
 
         if tools_for_request and len(tools_for_request) > 0:
             # TODO eventually enable parallel tool use
@@ -270,7 +270,7 @@ class AnthropicClient(LLMClientBase):
 
         return data
 
-    async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[Tool] = None) -> int:
+    async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
         client = anthropic.AsyncAnthropic()
         if messages and len(messages) == 0:
             messages = None
@@ -477,7 +477,7 @@ class AnthropicClient(LLMClientBase):
         return chat_completion_response
 
 
-def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
+def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
     """See: https://docs.anthropic.com/claude/docs/tool-use
 
     OpenAI style:
@@ -527,7 +527,7 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
     for tool in tools:
         formatted_tool = {
             "name": tool.function.name,
-            "description": tool.function.description,
+            "description": tool.function.description if tool.function.description else "",
             "input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
         }
         formatted_tools.append(formatted_tool)
diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py
index 764b71c3..3e1de4f7 100644
--- a/letta/services/context_window_calculator/token_counter.py
+++ b/letta/services/context_window_calculator/token_counter.py
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
 from letta.llm_api.anthropic_client import AnthropicClient
+from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
 from letta.utils import count_tokens
 
 
@@ -42,7 +43,7 @@
             return 0
         return await self.client.count_tokens(model=self.model, messages=messages)
 
-    async def count_tool_tokens(self, tools: List[Any]) -> int:
+    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
         if not tools:
             return 0
         return await self.client.count_tokens(model=self.model, tools=tools)
@@ -69,7 +70,7 @@
 
         return num_tokens_from_messages(messages=messages, model=self.model)
 
-    async def count_tool_tokens(self, tools: List[Any]) -> int:
+    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
         if not tools:
             return 0
         from letta.local_llm.utils import num_tokens_from_functions
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 2dfa5cca..2a7ec229 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -164,7 +164,9 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
     assert agent.message_buffer_autoclear == request.message_buffer_autoclear
 
 
-def validate_context_window_overview(overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None) -> None:
+def validate_context_window_overview(
+    agent_state: AgentState, overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None
+) -> None:
     """Validate common sense assertions for ContextWindowOverview"""
 
     # 1. Current context size should not exceed maximum
@@ -238,3 +240,7 @@ def validate_context_window_overview(overview: ContextWindowOverview, attached_f
         assert attached_file.visible_content in overview.core_memory
         assert '' in overview.core_memory
         assert "" in overview.core_memory
+
+    # Check for tools
+    assert overview.num_tokens_functions_definitions > 0
+    assert len(overview.functions_definitions) > 0
diff --git a/tests/test_managers.py b/tests/test_managers.py
index 007f5993..6331e266 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -706,11 +706,24 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen
     assert len(list_agents) == 0
 
 
+@pytest.fixture(params=["", "PRODUCTION"])
+def set_letta_environment(request):
+    original = os.environ.get("LETTA_ENVIRONMENT")
+    os.environ["LETTA_ENVIRONMENT"] = request.param
+    yield request.param
+    # Restore original environment variable
+    if original is not None:
+        os.environ["LETTA_ENVIRONMENT"] = original
+    else:
+        os.environ.pop("LETTA_ENVIRONMENT", None)
+
+
 @pytest.mark.asyncio
-async def test_get_context_window_basic(server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop):
+async def test_get_context_window_basic(
+    server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop, set_letta_environment
+):
     # Test agent creation
     created_agent, create_agent_request = comprehensive_test_agent_fixture
-    comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user)
 
     # Attach a file
     assoc = await server.file_agent_manager.attach_file(
@@ -723,7 +736,7 @@ async def test_get_context_window_basic(server: SyncServer, comprehensive_test_a
 
     # Get context window and check for basic appearances
     context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user)
-    validate_context_window_overview(context_window_overview, assoc)
+    validate_context_window_overview(created_agent, context_window_overview, assoc)
 
     # Test deleting the agent
     server.agent_manager.delete_agent(created_agent.id, default_user)
@@ -731,6 +744,24 @@ async def test_get_context_window_basic(server: SyncServer, comprehensive_test_a
     assert len(list_agents) == 0
 
 
+@pytest.mark.asyncio
+async def test_get_context_window_composio_tool(
+    server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop, set_letta_environment
+):
+    # Test agent creation
+    created_agent, create_agent_request = comprehensive_test_agent_fixture
+
+    # Attach a composio tool
+    tool_create = ToolCreate.from_composio(action_name="GITHUB_GET_EMOJIS")
+    tool = server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
+
+    created_agent = server.agent_manager.attach_tool(agent_id=created_agent.id, tool_id=tool.id, actor=default_user)
+
+    # Get context window and check for basic appearances
+    context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user)
+    validate_context_window_overview(created_agent, context_window_overview)
+
+
 @pytest.mark.asyncio
 async def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block, event_loop):
     memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]