Files
letta-server/tests/integration_test_token_counters.py
2025-12-15 12:02:19 -08:00

236 lines
8.9 KiB
Python

"""
Integration tests for token counting APIs.
These tests verify that the token counting implementations actually hit the real APIs
for Anthropic, Google Gemini, and OpenAI (tiktoken) by calling get_context_window
on an imported agent.
"""
import json
import os
import pytest
from letta.config import LettaConfig
from letta.orm import Base
from letta.schemas.agent import UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.llm_config import LLMConfig
from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.server.server import SyncServer
# ============================================================================
# LLM Configs to test
# ============================================================================
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load an LLM configuration from a JSON file.

    Args:
        filename: Name of the JSON config file inside ``llm_config_dir``.
        llm_config_dir: Directory holding the config files. The default is a
            relative path, so it is resolved against the current working
            directory of the test run.

    Returns:
        An ``LLMConfig`` built by passing the JSON object's top-level keys as
        keyword arguments.
    """
    config_path = os.path.join(llm_config_dir, filename)
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)
# Config file names (looked up by get_llm_config) for each model under test.
LLM_CONFIG_FILES = [
    "openai-gpt-4o-mini.json",
    "claude-4-5-sonnet.json",
    "gemini-2.5-pro.json",
]

# One pytest parameter per config file; the test id is the file name with
# ".json" removed.
LLM_CONFIGS = [
    pytest.param(get_llm_config(config_file), id=config_file.replace(".json", ""))
    for config_file in LLM_CONFIG_FILES
]
# ============================================================================
# Fixtures
# ============================================================================
async def _clear_tables():
    """Delete all rows from every table registered on ``Base.metadata``."""
    from letta.server.db import db_registry

    async with db_registry.async_session() as db_session:
        # Visit the sorted tables in reverse order — presumably so that rows
        # referencing other tables are removed before the rows they point to.
        tables_in_delete_order = reversed(Base.metadata.sorted_tables)
        for tbl in tables_in_delete_order:
            await db_session.execute(tbl.delete())
        # Commit once, after every delete statement has been issued.
        await db_session.commit()
@pytest.fixture(autouse=True)
async def clear_tables() -> None:
    """Autouse fixture that empties the database tables as part of test setup."""
    await _clear_tables()
@pytest.fixture
async def server():
    """Yield an initialized ``SyncServer`` with base tools upserted for its default user."""
    # Load the Letta config and write it straight back before building the server.
    LettaConfig.load().save()

    sync_server = SyncServer(init_with_default_org_and_user=True)
    await sync_server.init_async()

    default_actor = sync_server.default_user
    await sync_server.tool_manager.upsert_base_tools_async(actor=default_actor)
    yield sync_server
@pytest.fixture
async def default_organization(server: SyncServer):
    """Create the default organization via the server's organization manager and yield it."""
    organization_manager = server.organization_manager
    default_org = await organization_manager.create_default_organization_async()
    yield default_org
@pytest.fixture
async def default_user(server: SyncServer, default_organization):
    """Create the default actor inside ``default_organization`` and yield it."""
    owning_org_id = default_organization.id
    default_actor = await server.user_manager.create_default_actor_async(org_id=owning_org_id)
    yield default_actor
@pytest.fixture
async def other_organization(server: SyncServer):
    """Create a second organization named ``test_org`` and yield it."""
    org_to_create = Organization(name="test_org")
    created_org = await server.organization_manager.create_organization_async(pydantic_org=org_to_create)
    yield created_org
@pytest.fixture
async def other_user(server: SyncServer, other_organization):
    """Create a user named ``test_user`` inside ``other_organization`` and yield it."""
    user_to_create = User(organization_id=other_organization.id, name="test_user")
    created_user = await server.user_manager.create_actor_async(pydantic_user=user_to_create)
    yield created_user
@pytest.fixture
async def imported_agent_id(server: SyncServer, other_user):
    """Import the test agent from the .af file and yield the imported agent's ID.

    The agent file is read from ``test_agent_files/test_agent.af`` next to this
    test module, validated as an ``AgentFileSchema``, and imported with
    ``other_user`` as the actor.

    Fails with an assertion error if the import reports failure or if the
    import result's ID mappings contain no key starting with ``"agent-"``.
    """
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent.af")
    # The file is parsed as JSON; read it as UTF-8 rather than the locale default.
    with open(file_path, "r", encoding="utf-8") as f:
        agent_file_json = json.load(f)

    agent_schema = AgentFileSchema.model_validate(agent_file_json)
    import_result = await server.agent_serialization_manager.import_file(
        schema=agent_schema,
        actor=other_user,
        append_copy_suffix=False,
        override_existing_tools=True,
    )
    assert import_result.success, f"Failed to import agent: {import_result.message}"

    # Take the first mapped value whose key starts with "agent-". A default is
    # passed to next() because a bare StopIteration raised inside a generator is
    # converted to an opaque RuntimeError; the assertion below reports the real
    # problem instead.
    agent_id = next(
        (db_id for file_id, db_id in import_result.id_mappings.items() if file_id.startswith("agent-")),
        None,
    )
    assert agent_id is not None, "Import succeeded but id_mappings has no key starting with 'agent-'"
    yield agent_id
# ============================================================================
# Token Counter Integration Test
# ============================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_get_context_window(server: SyncServer, imported_agent_id: str, other_user, llm_config: LLMConfig):
    """Check that get_context_window returns sane token counts for each LLM config."""
    agent_manager = server.agent_manager

    # Switch the imported agent over to the LLM config under test.
    await agent_manager.update_agent_async(
        agent_id=imported_agent_id,
        agent_update=UpdateAgent(llm_config=llm_config),
        actor=other_user,
    )

    # Compute the context window for the agent with its updated config.
    window = await agent_manager.get_context_window(agent_id=imported_agent_id, actor=other_user)

    # The total must be positive; the per-section counts must be non-negative.
    assert window.context_window_size_current > 0
    assert window.num_tokens_system >= 0
    assert window.num_tokens_messages >= 0
    assert window.num_tokens_functions_definitions >= 0

    # Emit the breakdown for debugging (visible when pytest output capture is off).
    report_lines = [
        f"{llm_config.model_endpoint_type} ({llm_config.model}) context window:",
        f" Total tokens: {window.context_window_size_current}",
        f" System tokens: {window.num_tokens_system}",
        f" Message tokens: {window.num_tokens_messages}",
        f" Function tokens: {window.num_tokens_functions_definitions}",
    ]
    print("\n".join(report_lines))
# ============================================================================
# Edge Case Tests
# ============================================================================
def _build_token_counter(llm_config: LLMConfig):
    """Return the token counter matching ``llm_config.model_endpoint_type``.

    - ``"anthropic"``: ``AnthropicTokenCounter`` wrapping an ``AnthropicClient``.
    - ``"google_ai"``: ``GeminiTokenCounter`` wrapping a ``GoogleAIClient``.
    - ``"google_vertex"``: ``GeminiTokenCounter`` wrapping a ``GoogleVertexClient``.
    - anything else: ``ApproxTokenCounter``.

    Shared by the empty-input tests below so the provider selection lives in
    one place instead of being repeated in every test.
    """
    # Imported lazily, as the individual tests originally did.
    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.llm_api.google_ai_client import GoogleAIClient
    from letta.llm_api.google_vertex_client import GoogleVertexClient
    from letta.services.context_window_calculator.token_counter import (
        AnthropicTokenCounter,
        ApproxTokenCounter,
        GeminiTokenCounter,
    )

    endpoint_type = llm_config.model_endpoint_type
    if endpoint_type == "anthropic":
        return AnthropicTokenCounter(AnthropicClient(), llm_config.model)
    if endpoint_type == "google_ai":
        return GeminiTokenCounter(GoogleAIClient(), llm_config.model)
    if endpoint_type == "google_vertex":
        return GeminiTokenCounter(GoogleVertexClient(), llm_config.model)
    return ApproxTokenCounter()


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_text_tokens(llm_config: LLMConfig):
    """Test that empty text returns 0 tokens for all providers."""
    token_counter = _build_token_counter(llm_config)
    token_count = await token_counter.count_text_tokens("")
    assert token_count == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_messages_tokens(llm_config: LLMConfig):
    """Test that empty message list returns 0 tokens for all providers."""
    token_counter = _build_token_counter(llm_config)
    token_count = await token_counter.count_message_tokens([])
    assert token_count == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_tools_tokens(llm_config: LLMConfig):
    """Test that empty tools list returns 0 tokens for all providers."""
    token_counter = _build_token_counter(llm_config)
    token_count = await token_counter.count_tool_tokens([])
    assert token_count == 0