fix: Add additional testing for anthropic token counting (#2619)

This commit is contained in:
Matthew Zhou
2025-06-03 20:56:39 -07:00
committed by GitHub
parent 3d8704395b
commit ebccd8176a
4 changed files with 65 additions and 19 deletions

View File

@@ -30,7 +30,7 @@ from letta.log import get_logger
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
@@ -199,10 +199,10 @@ class AnthropicClient(LLMClientBase):
elif llm_config.enable_reasoner:
# NOTE: reasoning models currently do not allow for `any`
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [Tool(function=f) for f in tools]
tools_for_request = [OpenAITool(function=f) for f in tools]
elif force_tool_call is not None:
tool_choice = {"type": "tool", "name": force_tool_call}
tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call]
tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
# need to have this setting to be able to put inner thoughts in kwargs
if not llm_config.put_inner_thoughts_in_kwargs:
@@ -216,7 +216,7 @@ class AnthropicClient(LLMClientBase):
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
else:
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
tools_for_request = [OpenAITool(function=f) for f in tools] if tools is not None else None
# Add tool choice
if tool_choice:
@@ -230,7 +230,7 @@ class AnthropicClient(LLMClientBase):
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
)
tools_for_request = [Tool(function=f) for f in tools_with_inner_thoughts]
tools_for_request = [OpenAITool(function=f) for f in tools_with_inner_thoughts]
if tools_for_request and len(tools_for_request) > 0:
# TODO eventually enable parallel tool use
@@ -270,7 +270,7 @@ class AnthropicClient(LLMClientBase):
return data
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[Tool] = None) -> int:
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
client = anthropic.AsyncAnthropic()
if messages and len(messages) == 0:
messages = None
@@ -278,11 +278,19 @@ class AnthropicClient(LLMClientBase):
anthropic_tools = convert_tools_to_anthropic_format(tools)
else:
anthropic_tools = None
result = await client.beta.messages.count_tokens(
model=model or "claude-3-7-sonnet-20250219",
messages=messages or [{"role": "user", "content": "hi"}],
tools=anthropic_tools or [],
)
try:
result = await client.beta.messages.count_tokens(
model=model or "claude-3-7-sonnet-20250219",
messages=messages or [{"role": "user", "content": "hi"}],
tools=anthropic_tools or [],
)
except:
import ipdb
ipdb.set_trace()
raise
token_count = result.input_tokens
if messages is None:
token_count -= 8
@@ -477,7 +485,7 @@ class AnthropicClient(LLMClientBase):
return chat_completion_response
def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
"""See: https://docs.anthropic.com/claude/docs/tool-use
OpenAI style:
@@ -527,7 +535,7 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
for tool in tools:
formatted_tool = {
"name": tool.function.name,
"description": tool.function.description,
"description": tool.function.description if tool.function.description else "",
"input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
}
formatted_tools.append(formatted_tool)

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List
from letta.llm_api.anthropic_client import AnthropicClient
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.utils import count_tokens
@@ -42,7 +43,7 @@ class AnthropicTokenCounter(TokenCounter):
return 0
return await self.client.count_tokens(model=self.model, messages=messages)
async def count_tool_tokens(self, tools: List[Any]) -> int:
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0
return await self.client.count_tokens(model=self.model, tools=tools)
@@ -69,7 +70,7 @@ class TiktokenCounter(TokenCounter):
return num_tokens_from_messages(messages=messages, model=self.model)
async def count_tool_tokens(self, tools: List[Any]) -> int:
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0
from letta.local_llm.utils import num_tokens_from_functions

View File

@@ -164,7 +164,9 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
assert agent.message_buffer_autoclear == request.message_buffer_autoclear
def validate_context_window_overview(overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None) -> None:
def validate_context_window_overview(
agent_state: AgentState, overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None
) -> None:
"""Validate common sense assertions for ContextWindowOverview"""
# 1. Current context size should not exceed maximum
@@ -238,3 +240,7 @@ def validate_context_window_overview(overview: ContextWindowOverview, attached_f
assert attached_file.visible_content in overview.core_memory
assert '<file status="open">' in overview.core_memory
assert "</file>" in overview.core_memory
# Check for tools
assert overview.num_tokens_functions_definitions > 0
assert len(overview.functions_definitions) > 0

View File

@@ -706,11 +706,24 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen
assert len(list_agents) == 0
@pytest.fixture(params=["", "PRODUCTION"])
def set_letta_environment(request):
original = os.environ.get("LETTA_ENVIRONMENT")
os.environ["LETTA_ENVIRONMENT"] = request.param
yield request.param
# Restore original environment variable
if original is not None:
os.environ["LETTA_ENVIRONMENT"] = original
else:
os.environ.pop("LETTA_ENVIRONMENT", None)
@pytest.mark.asyncio
async def test_get_context_window_basic(server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop):
async def test_get_context_window_basic(
server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop, set_letta_environment
):
# Test agent creation
created_agent, create_agent_request = comprehensive_test_agent_fixture
comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user)
# Attach a file
assoc = await server.file_agent_manager.attach_file(
@@ -723,7 +736,7 @@ async def test_get_context_window_basic(server: SyncServer, comprehensive_test_a
# Get context window and check for basic appearances
context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user)
validate_context_window_overview(context_window_overview, assoc)
validate_context_window_overview(created_agent, context_window_overview, assoc)
# Test deleting the agent
server.agent_manager.delete_agent(created_agent.id, default_user)
@@ -731,6 +744,24 @@ async def test_get_context_window_basic(server: SyncServer, comprehensive_test_a
assert len(list_agents) == 0
@pytest.mark.asyncio
async def test_get_context_window_composio_tool(
server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop, set_letta_environment
):
# Test agent creation
created_agent, create_agent_request = comprehensive_test_agent_fixture
# Attach a composio tool
tool_create = ToolCreate.from_composio(action_name="GITHUB_GET_EMOJIS")
tool = server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
created_agent = server.agent_manager.attach_tool(agent_id=created_agent.id, tool_id=tool.id, actor=default_user)
# Get context window and check for basic appearances
context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user)
validate_context_window_overview(created_agent, context_window_overview)
@pytest.mark.asyncio
async def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block, event_loop):
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]