feat: add gemini token counting [LET-6371] (#6444)
committed by Caren Thomas
parent d3f5307789
commit 807c5c18d9
(3 binary files changed; contents not shown)
tests/integration_test_token_counters.py (new file, 235 lines)
@@ -0,0 +1,235 @@
"""
Integration tests for token counting APIs.

These tests verify that the token counting implementations actually hit the real APIs
for Anthropic, Google Gemini, and OpenAI (tiktoken) by calling get_context_window
on an imported agent.
"""

import json
import os

import pytest

from letta.config import LettaConfig
from letta.orm import Base
from letta.schemas.agent import UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.llm_config import LLMConfig
from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.server.server import SyncServer


# ============================================================================
# LLM Configs to test
# ============================================================================


def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load LLM configuration from JSON file."""
    filename = os.path.join(llm_config_dir, filename)
    with open(filename, "r") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)
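

# For reference, a minimal config file is assumed to look roughly like this
# (illustrative only -- field names beyond model / model_endpoint_type /
# context_window are a guess, not confirmed by this commit):
#   {
#       "model": "gemini-2.5-pro",
#       "model_endpoint_type": "google_ai",
#       "context_window": 1048576
#   }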

LLM_CONFIG_FILES = [
    "openai-gpt-4o-mini.json",
    "claude-4-5-sonnet.json",
    "gemini-2.5-pro.json",
]

LLM_CONFIGS = [pytest.param(get_llm_config(f), id=f.replace(".json", "")) for f in LLM_CONFIG_FILES]
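# The explicit pytest.param ids make the parametrized test names readable,
# e.g. test_get_context_window[gemini-2.5-pro] instead of an opaque object repr.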


# ============================================================================
# Fixtures
# ============================================================================


async def _clear_tables():
    from letta.server.db import db_registry

    async with db_registry.async_session() as session:
        for table in reversed(Base.metadata.sorted_tables):
            await session.execute(table.delete())
        await session.commit()


@pytest.fixture(autouse=True)
async def clear_tables():
    await _clear_tables()


@pytest.fixture
async def server():
    config = LettaConfig.load()
    config.save()
    server = SyncServer(init_with_default_org_and_user=True)
    await server.init_async()
    await server.tool_manager.upsert_base_tools_async(actor=server.default_user)
    yield server


@pytest.fixture
async def default_organization(server: SyncServer):
    """Fixture to create and return the default organization."""
    org = await server.organization_manager.create_default_organization_async()
    yield org


@pytest.fixture
async def default_user(server: SyncServer, default_organization):
    """Fixture to create and return the default user within the default organization."""
    user = await server.user_manager.create_default_actor_async(org_id=default_organization.id)
    yield user


@pytest.fixture
async def other_organization(server: SyncServer):
    """Fixture to create and return another organization."""
    org = await server.organization_manager.create_organization_async(pydantic_org=Organization(name="test_org"))
    yield org


@pytest.fixture
async def other_user(server: SyncServer, other_organization):
    """Fixture to create and return another user within the other organization."""
    user = await server.user_manager.create_actor_async(pydantic_user=User(organization_id=other_organization.id, name="test_user"))
    yield user


@pytest.fixture
async def imported_agent_id(server: SyncServer, other_user):
    """Import the test agent from the .af file and return the agent ID."""
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent.af")

    with open(file_path, "r") as f:
        agent_file_json = json.load(f)

    agent_schema = AgentFileSchema.model_validate(agent_file_json)

    import_result = await server.agent_serialization_manager.import_file(
        schema=agent_schema,
        actor=other_user,
        append_copy_suffix=False,
        override_existing_tools=True,
    )

    assert import_result.success, f"Failed to import agent: {import_result.message}"
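
    # id_mappings pairs ids as written in the .af file with the ids created in
    # the database; the entry whose file-side id starts with "agent-" is the agent.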
    # Get the imported agent ID
    agent_id = next(db_id for file_id, db_id in import_result.id_mappings.items() if file_id.startswith("agent-"))
    yield agent_id


# ============================================================================
# Token Counter Integration Test
# ============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_get_context_window(server: SyncServer, imported_agent_id: str, other_user, llm_config: LLMConfig):
    """Test get_context_window with different LLM providers."""
    # Update the agent to use the specified LLM config
    await server.agent_manager.update_agent_async(
        agent_id=imported_agent_id,
        agent_update=UpdateAgent(llm_config=llm_config),
        actor=other_user,
    )

    # Call get_context_window which will use the appropriate token counting API
    context_window = await server.agent_manager.get_context_window(agent_id=imported_agent_id, actor=other_user)

    # Verify we got valid token counts
    assert context_window.context_window_size_current > 0
    assert context_window.num_tokens_system >= 0
    assert context_window.num_tokens_messages >= 0
    assert context_window.num_tokens_functions_definitions >= 0

    print(f"{llm_config.model_endpoint_type} ({llm_config.model}) context window:")
    print(f"  Total tokens: {context_window.context_window_size_current}")
    print(f"  System tokens: {context_window.num_tokens_system}")
    print(f"  Message tokens: {context_window.num_tokens_messages}")
    print(f"  Function tokens: {context_window.num_tokens_functions_definitions}")


# ============================================================================
# Edge Case Tests
# ============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_text_tokens(llm_config: LLMConfig):
    """Test that empty text returns 0 tokens for all providers."""
    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.llm_api.google_ai_client import GoogleAIClient
    from letta.llm_api.google_vertex_client import GoogleVertexClient
    from letta.services.context_window_calculator.token_counter import (
        AnthropicTokenCounter,
        GeminiTokenCounter,
        TiktokenCounter,
    )

    if llm_config.model_endpoint_type == "anthropic":
        token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model)
    elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
        client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
        token_counter = GeminiTokenCounter(client, llm_config.model)
    else:
        token_counter = TiktokenCounter(llm_config.model)

    token_count = await token_counter.count_text_tokens("")
    assert token_count == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_messages_tokens(llm_config: LLMConfig):
    """Test that empty message list returns 0 tokens for all providers."""
    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.llm_api.google_ai_client import GoogleAIClient
    from letta.llm_api.google_vertex_client import GoogleVertexClient
    from letta.services.context_window_calculator.token_counter import (
        AnthropicTokenCounter,
        GeminiTokenCounter,
        TiktokenCounter,
    )

    if llm_config.model_endpoint_type == "anthropic":
        token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model)
    elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
        client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
        token_counter = GeminiTokenCounter(client, llm_config.model)
    else:
        token_counter = TiktokenCounter(llm_config.model)

    token_count = await token_counter.count_message_tokens([])
    assert token_count == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_tools_tokens(llm_config: LLMConfig):
    """Test that empty tools list returns 0 tokens for all providers."""
    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.llm_api.google_ai_client import GoogleAIClient
    from letta.llm_api.google_vertex_client import GoogleVertexClient
    from letta.services.context_window_calculator.token_counter import (
        AnthropicTokenCounter,
        GeminiTokenCounter,
        TiktokenCounter,
    )

    if llm_config.model_endpoint_type == "anthropic":
        token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model)
    elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
        client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
        token_counter = GeminiTokenCounter(client, llm_config.model)
    else:
        token_counter = TiktokenCounter(llm_config.model)

    token_count = await token_counter.count_tool_tokens([])
    assert token_count == 0
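

# ----------------------------------------------------------------------------
# NOTE: illustrative sketch only -- not part of this commit. The three
# edge-case tests above repeat the same provider-selection block; a shared
# helper like the one below (hypothetical name: _make_token_counter) could
# consolidate it.
# ----------------------------------------------------------------------------
def _make_token_counter(llm_config: LLMConfig):
    """Build the provider-appropriate token counter for a config (hypothetical helper)."""
    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.llm_api.google_ai_client import GoogleAIClient
    from letta.llm_api.google_vertex_client import GoogleVertexClient
    from letta.services.context_window_calculator.token_counter import (
        AnthropicTokenCounter,
        GeminiTokenCounter,
        TiktokenCounter,
    )

    if llm_config.model_endpoint_type == "anthropic":
        return AnthropicTokenCounter(AnthropicClient(), llm_config.model)
    if llm_config.model_endpoint_type in ("google_vertex", "google_ai"):
        # google_ai and google_vertex share the Gemini counter but use different clients
        client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
        return GeminiTokenCounter(client, llm_config.model)
    # Default: local tiktoken counting (OpenAI-style models)
    return TiktokenCounter(llm_config.model)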