async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
    """
    Count tokens for the given messages and tools using the Gemini token counting API.

    Messages and tools are counted with separate requests and summed. Tools are
    serialized to JSON text first because the tools are not passed to the
    counting call as a dedicated parameter.

    Args:
        messages: List of message dicts in Google AI format (with 'role' and 'parts' keys).
        model: The model to use for token counting. Falls back to
            GOOGLE_MODEL_FOR_API_KEY_CHECK when omitted.
        tools: List of OpenAI-style Tool objects to include in the count.

    Returns:
        The total token count for the input (message tokens plus tool tokens).

    Raises:
        Exception: the error returned by ``self.handle_llm_error`` for any
            failure raised while counting; the original error is chained as
            ``__cause__``.
    """
    if model:
        count_model = model
    else:
        # Only pull in the fallback constant when the caller did not pick a model.
        from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK

        count_model = GOOGLE_MODEL_FOR_API_KEY_CHECK

    client = self._get_client()
    total_tokens = 0

    try:
        # Count the messages. When the caller supplied neither messages nor
        # tools, keep counting an empty string (like the API key check). When
        # only tools were supplied, skip this request: counting "" would cost an
        # extra round-trip and add nothing meaningful to the tool count.
        if messages or not tools:
            result = await client.aio.models.count_tokens(
                model=count_model,
                contents=messages if messages else "",
            )
            total_tokens += result.total_tokens

        # Count tool tokens separately by serializing the tools to JSON text.
        if tools:
            tools_text = json.dumps([t.model_dump() for t in tools])
            tools_result = await client.aio.models.count_tokens(
                model=count_model,
                contents=tools_text,
            )
            total_tokens += tools_result.total_tokens

    except Exception as e:
        # Chain explicitly so the underlying failure stays visible as __cause__.
        raise self.handle_llm_error(e) from e

    return total_tokens
parts.append({"text": text_content}) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3f2d0e2d..194afff5 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -80,7 +80,7 @@ from letta.server.db import db_registry from letta.services.archive_manager import ArchiveManager from letta.services.block_manager import BlockManager, validate_block_limit_constraint from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator -from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter +from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, GeminiTokenCounter, TiktokenCounter from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.files_agents_manager import FileAgentManager from letta.services.helpers.agent_manager_helper import ( @@ -3286,37 +3286,60 @@ class AgentManager: ) calculator = ContextWindowCalculator() + # Determine which token counter to use based on provider + model_endpoint_type = agent_state.llm_config.model_endpoint_type + + # Use Gemini token counter for Google Vertex and Google AI + use_gemini = model_endpoint_type in ("google_vertex", "google_ai") + # Use Anthropic token counter if: # 1. The model endpoint type is anthropic, OR - # 2. We're in PRODUCTION and anthropic_api_key is available - use_anthropic = agent_state.llm_config.model_endpoint_type == "anthropic" or ( - settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None + # 2. 
We're in PRODUCTION and anthropic_api_key is available (and not using Gemini) + use_anthropic = model_endpoint_type == "anthropic" or ( + not use_gemini and settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None ) - if use_anthropic: + if use_gemini: + # Use native Gemini token counting API + + client = LLMClient.create(provider_type=agent_state.llm_config.model_endpoint_type, actor=actor) + model = agent_state.llm_config.model + + token_counter = GeminiTokenCounter(client, model) + logger.info( + f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, " + f"model_endpoint_type={model_endpoint_type}, " + f"environment={settings.environment}" + ) + elif use_anthropic: anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) - model = agent_state.llm_config.model if agent_state.llm_config.model_endpoint_type == "anthropic" else None + model = agent_state.llm_config.model if model_endpoint_type == "anthropic" else None token_counter = AnthropicTokenCounter(anthropic_client, model) # noqa logger.info( f"Using AnthropicTokenCounter for agent_id={agent_id}, model={model}, " - f"model_endpoint_type={agent_state.llm_config.model_endpoint_type}, " + f"model_endpoint_type={model_endpoint_type}, " f"environment={settings.environment}" ) else: token_counter = TiktokenCounter(agent_state.llm_config.model) logger.info( f"Using TiktokenCounter for agent_id={agent_id}, model={agent_state.llm_config.model}, " - f"model_endpoint_type={agent_state.llm_config.model_endpoint_type}, " + f"model_endpoint_type={model_endpoint_type}, " f"environment={settings.environment}" ) - return await calculator.calculate_context_window( - agent_state=agent_state, - actor=actor, - token_counter=token_counter, - message_manager=self.message_manager, - system_message_compiled=system_message, - num_archival_memories=num_archival_memories, - num_messages=num_messages, - ) + try: + result = await calculator.calculate_context_window( + 
class GeminiTokenCounter(TokenCounter):
    """Token counter that delegates to the Gemini token counting API via a Google client."""

    def __init__(self, gemini_client: GoogleVertexClient, model: str):
        self.model = model
        self.client = gemini_client

    @staticmethod
    def _digest(payload: str) -> str:
        """Return the first 16 hex characters of the SHA-256 digest of ``payload``."""
        return hashlib.sha256(payload.encode()).hexdigest()[:16]

    @trace_method
    @async_redis_cache(
        key_func=lambda self, text: f"gemini_text_tokens:{self.model}:{self._digest(text)}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_text_tokens(self, text: str) -> int:
        """Count the tokens of a plain string; empty text is 0 without calling the client."""
        if text:
            # The client expects Google-format messages, so present the text as one user turn.
            as_user_turn = [{"role": "user", "parts": [{"text": text}]}]
            return await self.client.count_tokens(model=self.model, messages=as_user_turn)
        return 0

    @trace_method
    @async_redis_cache(
        key_func=lambda self, messages: f"gemini_message_tokens:{self.model}:{self._digest(json.dumps(messages, sort_keys=True))}",
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count the tokens of Google-format messages; an empty list is 0 without calling the client."""
        if messages:
            return await self.client.count_tokens(model=self.model, messages=messages)
        return 0

    @trace_method
    @async_redis_cache(
        key_func=lambda self, tools: (
            f"gemini_tool_tokens:{self.model}:{self._digest(json.dumps([t.model_dump() for t in tools], sort_keys=True))}"
        ),
        prefix="token_counter",
        ttl_s=3600,  # cache for 1 hour
    )
    async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
        """Count the tokens of tool definitions; an empty list is 0 without calling the client."""
        if tools:
            return await self.client.count_tokens(model=self.model, tools=tools)
        return 0

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert messages to Google-format dicts via ``Message.to_google_dicts_from_list``."""
        return Message.to_google_dicts_from_list(messages, current_model=self.model)
"""
Integration tests for token counting APIs.

These tests verify that the token counting implementations actually hit the real APIs
for Anthropic, Google Gemini, and OpenAI (tiktoken) by calling get_context_window
on an imported agent.
"""

import json
import os

import pytest

from letta.config import LettaConfig
from letta.orm import Base
from letta.schemas.agent import UpdateAgent
from letta.schemas.agent_file import AgentFileSchema
from letta.schemas.llm_config import LLMConfig
from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.server.server import SyncServer

# ============================================================================
# LLM Configs to test
# ============================================================================


def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load LLM configuration from a JSON file located in ``llm_config_dir``."""
    config_path = os.path.join(llm_config_dir, filename)
    # Explicit encoding so the result does not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)


LLM_CONFIG_FILES = [
    "openai-gpt-4o-mini.json",
    "claude-4-5-sonnet.json",
    "gemini-2.5-pro.json",
]

LLM_CONFIGS = [pytest.param(get_llm_config(f), id=f.replace(".json", "")) for f in LLM_CONFIG_FILES]


def _build_token_counter(llm_config: LLMConfig):
    """Create the token counter matching ``llm_config.model_endpoint_type``.

    Anthropic configs get an AnthropicTokenCounter, Google AI / Vertex configs get a
    GeminiTokenCounter backed by the matching client, and everything else falls back
    to TiktokenCounter. Imports stay local, as in the individual tests this helper
    replaces.
    """
    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.llm_api.google_ai_client import GoogleAIClient
    from letta.llm_api.google_vertex_client import GoogleVertexClient
    from letta.services.context_window_calculator.token_counter import (
        AnthropicTokenCounter,
        GeminiTokenCounter,
        TiktokenCounter,
    )

    endpoint_type = llm_config.model_endpoint_type
    if endpoint_type == "anthropic":
        return AnthropicTokenCounter(AnthropicClient(), llm_config.model)
    if endpoint_type in ("google_vertex", "google_ai"):
        client = GoogleAIClient() if endpoint_type == "google_ai" else GoogleVertexClient()
        return GeminiTokenCounter(client, llm_config.model)
    return TiktokenCounter(llm_config.model)


# ============================================================================
# Fixtures
# ============================================================================


async def _clear_tables():
    """Delete all rows from every ORM table, children first (reverse dependency order)."""
    from letta.server.db import db_registry

    async with db_registry.async_session() as session:
        for table in reversed(Base.metadata.sorted_tables):
            await session.execute(table.delete())
        await session.commit()


@pytest.fixture(autouse=True)
async def clear_tables():
    """Start every test from an empty database."""
    await _clear_tables()


@pytest.fixture
async def server():
    """Create an initialized SyncServer with the base tools upserted for the default user."""
    config = LettaConfig.load()
    config.save()
    server = SyncServer(init_with_default_org_and_user=True)
    await server.init_async()
    await server.tool_manager.upsert_base_tools_async(actor=server.default_user)
    yield server


@pytest.fixture
async def default_organization(server: SyncServer):
    """Fixture to create and return the default organization."""
    org = await server.organization_manager.create_default_organization_async()
    yield org


@pytest.fixture
async def default_user(server: SyncServer, default_organization):
    """Fixture to create and return the default user within the default organization."""
    user = await server.user_manager.create_default_actor_async(org_id=default_organization.id)
    yield user


@pytest.fixture
async def other_organization(server: SyncServer):
    """Fixture to create and return another organization."""
    org = await server.organization_manager.create_organization_async(pydantic_org=Organization(name="test_org"))
    yield org


@pytest.fixture
async def other_user(server: SyncServer, other_organization):
    """Fixture to create and return another user within the other organization."""
    user = await server.user_manager.create_actor_async(pydantic_user=User(organization_id=other_organization.id, name="test_user"))
    yield user


@pytest.fixture
async def imported_agent_id(server: SyncServer, other_user):
    """Import the test agent from the .af file and return the agent ID."""
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent.af")

    with open(file_path, "r", encoding="utf-8") as f:
        agent_file_json = json.load(f)

    agent_schema = AgentFileSchema.model_validate(agent_file_json)

    import_result = await server.agent_serialization_manager.import_file(
        schema=agent_schema,
        actor=other_user,
        append_copy_suffix=False,
        override_existing_tools=True,
    )

    assert import_result.success, f"Failed to import agent: {import_result.message}"

    # Get the imported agent ID. Use a default so a missing mapping fails with a
    # readable assertion instead of a bare StopIteration inside the fixture.
    agent_id = next(
        (db_id for file_id, db_id in import_result.id_mappings.items() if file_id.startswith("agent-")),
        None,
    )
    assert agent_id is not None, "Import succeeded but id_mappings contains no agent entry"
    yield agent_id


# ============================================================================
# Token Counter Integration Test
# ============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_get_context_window(server: SyncServer, imported_agent_id: str, other_user, llm_config: LLMConfig):
    """Test get_context_window with different LLM providers."""
    # Update the agent to use the specified LLM config
    await server.agent_manager.update_agent_async(
        agent_id=imported_agent_id,
        agent_update=UpdateAgent(llm_config=llm_config),
        actor=other_user,
    )

    # Call get_context_window which will use the appropriate token counting API
    context_window = await server.agent_manager.get_context_window(agent_id=imported_agent_id, actor=other_user)

    # Verify we got valid token counts
    assert context_window.context_window_size_current > 0
    assert context_window.num_tokens_system >= 0
    assert context_window.num_tokens_messages >= 0
    assert context_window.num_tokens_functions_definitions >= 0

    print(f"{llm_config.model_endpoint_type} ({llm_config.model}) context window:")
    print(f" Total tokens: {context_window.context_window_size_current}")
    print(f" System tokens: {context_window.num_tokens_system}")
    print(f" Message tokens: {context_window.num_tokens_messages}")
    print(f" Function tokens: {context_window.num_tokens_functions_definitions}")


# ============================================================================
# Edge Case Tests
# ============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_text_tokens(llm_config: LLMConfig):
    """Test that empty text returns 0 tokens for all providers."""
    token_counter = _build_token_counter(llm_config)

    token_count = await token_counter.count_text_tokens("")
    assert token_count == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_messages_tokens(llm_config: LLMConfig):
    """Test that empty message list returns 0 tokens for all providers."""
    token_counter = _build_token_counter(llm_config)

    token_count = await token_counter.count_message_tokens([])
    assert token_count == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", LLM_CONFIGS)
async def test_count_empty_tools_tokens(llm_config: LLMConfig):
    """Test that empty tools list returns 0 tokens for all providers."""
    token_counter = _build_token_counter(llm_config)

    token_count = await token_counter.count_tool_tokens([])
    assert token_count == 0