fix: Add additional testing for anthropic token counting (#2619)
This commit is contained in:
@@ -30,7 +30,7 @@ from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
@@ -199,10 +199,10 @@ class AnthropicClient(LLMClientBase):
|
||||
elif llm_config.enable_reasoner:
|
||||
# NOTE: reasoning models currently do not allow for `any`
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [Tool(function=f) for f in tools]
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools]
|
||||
elif force_tool_call is not None:
|
||||
tool_choice = {"type": "tool", "name": force_tool_call}
|
||||
tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call]
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools if f["name"] == force_tool_call]
|
||||
|
||||
# need to have this setting to be able to put inner thoughts in kwargs
|
||||
if not llm_config.put_inner_thoughts_in_kwargs:
|
||||
@@ -216,7 +216,7 @@ class AnthropicClient(LLMClientBase):
|
||||
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
|
||||
else:
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools] if tools is not None else None
|
||||
|
||||
# Add tool choice
|
||||
if tool_choice:
|
||||
@@ -230,7 +230,7 @@ class AnthropicClient(LLMClientBase):
|
||||
inner_thoughts_key=INNER_THOUGHTS_KWARG,
|
||||
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
|
||||
)
|
||||
tools_for_request = [Tool(function=f) for f in tools_with_inner_thoughts]
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools_with_inner_thoughts]
|
||||
|
||||
if tools_for_request and len(tools_for_request) > 0:
|
||||
# TODO eventually enable parallel tool use
|
||||
@@ -270,7 +270,7 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return data
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[Tool] = None) -> int:
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
client = anthropic.AsyncAnthropic()
|
||||
if messages and len(messages) == 0:
|
||||
messages = None
|
||||
@@ -278,11 +278,19 @@ class AnthropicClient(LLMClientBase):
|
||||
anthropic_tools = convert_tools_to_anthropic_format(tools)
|
||||
else:
|
||||
anthropic_tools = None
|
||||
result = await client.beta.messages.count_tokens(
|
||||
model=model or "claude-3-7-sonnet-20250219",
|
||||
messages=messages or [{"role": "user", "content": "hi"}],
|
||||
tools=anthropic_tools or [],
|
||||
)
|
||||
|
||||
try:
|
||||
result = await client.beta.messages.count_tokens(
|
||||
model=model or "claude-3-7-sonnet-20250219",
|
||||
messages=messages or [{"role": "user", "content": "hi"}],
|
||||
tools=anthropic_tools or [],
|
||||
)
|
||||
except:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
raise
|
||||
|
||||
token_count = result.input_tokens
|
||||
if messages is None:
|
||||
token_count -= 8
|
||||
@@ -477,7 +485,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
||||
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
|
||||
"""See: https://docs.anthropic.com/claude/docs/tool-use
|
||||
|
||||
OpenAI style:
|
||||
@@ -527,7 +535,7 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
||||
for tool in tools:
|
||||
formatted_tool = {
|
||||
"name": tool.function.name,
|
||||
"description": tool.function.description,
|
||||
"description": tool.function.description if tool.function.description else "",
|
||||
"input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.utils import count_tokens
|
||||
|
||||
|
||||
@@ -42,7 +43,7 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
return 0
|
||||
return await self.client.count_tokens(model=self.model, messages=messages)
|
||||
|
||||
async def count_tool_tokens(self, tools: List[Any]) -> int:
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
return await self.client.count_tokens(model=self.model, tools=tools)
|
||||
@@ -69,7 +70,7 @@ class TiktokenCounter(TokenCounter):
|
||||
|
||||
return num_tokens_from_messages(messages=messages, model=self.model)
|
||||
|
||||
async def count_tool_tokens(self, tools: List[Any]) -> int:
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
from letta.local_llm.utils import num_tokens_from_functions
|
||||
|
||||
@@ -164,7 +164,9 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
||||
assert agent.message_buffer_autoclear == request.message_buffer_autoclear
|
||||
|
||||
|
||||
def validate_context_window_overview(overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None) -> None:
|
||||
def validate_context_window_overview(
|
||||
agent_state: AgentState, overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None
|
||||
) -> None:
|
||||
"""Validate common sense assertions for ContextWindowOverview"""
|
||||
|
||||
# 1. Current context size should not exceed maximum
|
||||
@@ -238,3 +240,7 @@ def validate_context_window_overview(overview: ContextWindowOverview, attached_f
|
||||
assert attached_file.visible_content in overview.core_memory
|
||||
assert '<file status="open">' in overview.core_memory
|
||||
assert "</file>" in overview.core_memory
|
||||
|
||||
# Check for tools
|
||||
assert overview.num_tokens_functions_definitions > 0
|
||||
assert len(overview.functions_definitions) > 0
|
||||
|
||||
@@ -706,11 +706,24 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen
|
||||
assert len(list_agents) == 0
|
||||
|
||||
|
||||
@pytest.fixture(params=["", "PRODUCTION"])
|
||||
def set_letta_environment(request):
|
||||
original = os.environ.get("LETTA_ENVIRONMENT")
|
||||
os.environ["LETTA_ENVIRONMENT"] = request.param
|
||||
yield request.param
|
||||
# Restore original environment variable
|
||||
if original is not None:
|
||||
os.environ["LETTA_ENVIRONMENT"] = original
|
||||
else:
|
||||
os.environ.pop("LETTA_ENVIRONMENT", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_window_basic(server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop):
|
||||
async def test_get_context_window_basic(
|
||||
server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop, set_letta_environment
|
||||
):
|
||||
# Test agent creation
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user)
|
||||
|
||||
# Attach a file
|
||||
assoc = await server.file_agent_manager.attach_file(
|
||||
@@ -723,7 +736,7 @@ async def test_get_context_window_basic(server: SyncServer, comprehensive_test_a
|
||||
|
||||
# Get context window and check for basic appearances
|
||||
context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user)
|
||||
validate_context_window_overview(context_window_overview, assoc)
|
||||
validate_context_window_overview(created_agent, context_window_overview, assoc)
|
||||
|
||||
# Test deleting the agent
|
||||
server.agent_manager.delete_agent(created_agent.id, default_user)
|
||||
@@ -731,6 +744,24 @@ async def test_get_context_window_basic(server: SyncServer, comprehensive_test_a
|
||||
assert len(list_agents) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_window_composio_tool(
|
||||
server: SyncServer, comprehensive_test_agent_fixture, default_user, default_file, event_loop, set_letta_environment
|
||||
):
|
||||
# Test agent creation
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Attach a composio tool
|
||||
tool_create = ToolCreate.from_composio(action_name="GITHUB_GET_EMOJIS")
|
||||
tool = server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
|
||||
|
||||
created_agent = server.agent_manager.attach_tool(agent_id=created_agent.id, tool_id=tool.id, actor=default_user)
|
||||
|
||||
# Get context window and check for basic appearances
|
||||
context_window_overview = await server.agent_manager.get_context_window(agent_id=created_agent.id, actor=default_user)
|
||||
validate_context_window_overview(created_agent, context_window_overview)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block, event_loop):
|
||||
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
|
||||
|
||||
Reference in New Issue
Block a user