feat: add gemini token counting [LET-6371] (#6444)

This commit is contained in:
Sarah Wooders
2025-11-28 19:49:02 -08:00
committed by Caren Thomas
parent d3f5307789
commit 807c5c18d9
8 changed files with 379 additions and 19 deletions

View File

@@ -0,0 +1,235 @@
"""
Integration tests for token counting APIs.
These tests verify that the token counting implementations actually hit the real APIs
for Anthropic, Google Gemini, and OpenAI (tiktoken) by calling get_context_window
on an imported agent.
"""
import json
import os
import pytest
from letta.config import LettaConfig
from letta.orm import Base
from letta.schemas.agent import UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.llm_config import LLMConfig
from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.server.server import SyncServer
# ============================================================================
# LLM Configs to test
# ============================================================================
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
"""Load LLM configuration from JSON file."""
filename = os.path.join(llm_config_dir, filename)
with open(filename, "r") as f:
config_data = json.load(f)
return LLMConfig(**config_data)
LLM_CONFIG_FILES = [
"openai-gpt-4o-mini.json",
"claude-4-5-sonnet.json",
"gemini-2.5-pro.json",
]
LLM_CONFIGS = [pytest.param(get_llm_config(f), id=f.replace(".json", "")) for f in LLM_CONFIG_FILES]
# ============================================================================
# Fixtures
# ============================================================================
async def _clear_tables():
from letta.server.db import db_registry
async with db_registry.async_session() as session:
for table in reversed(Base.metadata.sorted_tables):
await session.execute(table.delete())
await session.commit()
@pytest.fixture(autouse=True)
async def clear_tables():
await _clear_tables()
@pytest.fixture
async def server():
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
await server.init_async()
await server.tool_manager.upsert_base_tools_async(actor=server.default_user)
yield server
@pytest.fixture
async def default_organization(server: SyncServer):
"""Fixture to create and return the default organization."""
org = await server.organization_manager.create_default_organization_async()
yield org
@pytest.fixture
async def default_user(server: SyncServer, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = await server.user_manager.create_default_actor_async(org_id=default_organization.id)
yield user
@pytest.fixture
async def other_organization(server: SyncServer):
"""Fixture to create and return another organization."""
org = await server.organization_manager.create_organization_async(pydantic_org=Organization(name="test_org"))
yield org
@pytest.fixture
async def other_user(server: SyncServer, other_organization):
"""Fixture to create and return another user within the other organization."""
user = await server.user_manager.create_actor_async(pydantic_user=User(organization_id=other_organization.id, name="test_user"))
yield user
@pytest.fixture
async def imported_agent_id(server: SyncServer, other_user):
"""Import the test agent from the .af file and return the agent ID."""
file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent.af")
with open(file_path, "r") as f:
agent_file_json = json.load(f)
agent_schema = AgentFileSchema.model_validate(agent_file_json)
import_result = await server.agent_serialization_manager.import_file(
schema=agent_schema,
actor=other_user,
append_copy_suffix=False,
override_existing_tools=True,
)
assert import_result.success, f"Failed to import agent: {import_result.message}"
# Get the imported agent ID
agent_id = next(db_id for file_id, db_id in import_result.id_mappings.items() if file_id.startswith("agent-"))
yield agent_id
# ============================================================================
# Token Counter Integration Test
# ============================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_get_context_window(server: SyncServer, imported_agent_id: str, other_user, llm_config: LLMConfig):
"""Test get_context_window with different LLM providers."""
# Update the agent to use the specified LLM config
await server.agent_manager.update_agent_async(
agent_id=imported_agent_id,
agent_update=UpdateAgent(llm_config=llm_config),
actor=other_user,
)
# Call get_context_window which will use the appropriate token counting API
context_window = await server.agent_manager.get_context_window(agent_id=imported_agent_id, actor=other_user)
# Verify we got valid token counts
assert context_window.context_window_size_current > 0
assert context_window.num_tokens_system >= 0
assert context_window.num_tokens_messages >= 0
assert context_window.num_tokens_functions_definitions >= 0
print(f"{llm_config.model_endpoint_type} ({llm_config.model}) context window:")
print(f" Total tokens: {context_window.context_window_size_current}")
print(f" System tokens: {context_window.num_tokens_system}")
print(f" Message tokens: {context_window.num_tokens_messages}")
print(f" Function tokens: {context_window.num_tokens_functions_definitions}")
# ============================================================================
# Edge Case Tests
# ============================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_text_tokens(llm_config: LLMConfig):
"""Test that empty text returns 0 tokens for all providers."""
from letta.llm_api.anthropic_client import AnthropicClient
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.services.context_window_calculator.token_counter import (
AnthropicTokenCounter,
GeminiTokenCounter,
TiktokenCounter,
)
if llm_config.model_endpoint_type == "anthropic":
token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model)
elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
token_counter = GeminiTokenCounter(client, llm_config.model)
else:
token_counter = TiktokenCounter(llm_config.model)
token_count = await token_counter.count_text_tokens("")
assert token_count == 0
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_messages_tokens(llm_config: LLMConfig):
"""Test that empty message list returns 0 tokens for all providers."""
from letta.llm_api.anthropic_client import AnthropicClient
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.services.context_window_calculator.token_counter import (
AnthropicTokenCounter,
GeminiTokenCounter,
TiktokenCounter,
)
if llm_config.model_endpoint_type == "anthropic":
token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model)
elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
token_counter = GeminiTokenCounter(client, llm_config.model)
else:
token_counter = TiktokenCounter(llm_config.model)
token_count = await token_counter.count_message_tokens([])
assert token_count == 0
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_tools_tokens(llm_config: LLMConfig):
"""Test that empty tools list returns 0 tokens for all providers."""
from letta.llm_api.anthropic_client import AnthropicClient
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.services.context_window_calculator.token_counter import (
AnthropicTokenCounter,
GeminiTokenCounter,
TiktokenCounter,
)
if llm_config.model_endpoint_type == "anthropic":
token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model)
elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
token_counter = GeminiTokenCounter(client, llm_config.model)
else:
token_counter = TiktokenCounter(llm_config.model)
token_count = await token_counter.count_tool_tokens([])
assert token_count == 0