diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py index 23800b23..1ae25f21 100644 --- a/letta/adapters/letta_llm_stream_adapter.py +++ b/letta/adapters/letta_llm_stream_adapter.py @@ -116,6 +116,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): if not output_tokens and hasattr(self.interface, "fallback_output_tokens"): output_tokens = self.interface.fallback_output_tokens + # NOTE: For Anthropic, input_tokens is NON-cached only, so total_tokens here + # undercounts the actual total (missing cache_read + cache_creation tokens). + # For OpenAI/Gemini, input_tokens is already the total, so this is correct. + # See simple_llm_stream_adapter.py for the proper provider-aware calculation. self.usage = LettaUsageStatistics( step_count=1, completion_tokens=output_tokens or 0, diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index 213b115c..b4155fb3 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -181,10 +181,19 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None: reasoning_tokens = self.interface.thinking_tokens - # Per Anthropic docs: "Total input tokens in a request is the summation of - # input_tokens, cache_creation_input_tokens, and cache_read_input_tokens." - # We need actual total for context window limit checks (summarization trigger). - actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0) + # Calculate actual total input tokens for context window limit checks (summarization trigger). + # + # ANTHROPIC: input_tokens is NON-cached only, must add cache tokens + # Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens + # + # OPENAI/GEMINI: input_tokens (prompt_tokens/prompt_token_count) is already TOTAL + # cached_tokens is a subset, NOT additive + # Total = input_tokens (don't add cached_tokens or it double-counts!) + is_anthropic = hasattr(self.interface, "cache_read_tokens") or hasattr(self.interface, "cache_creation_tokens") + if is_anthropic: + actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0) + else: + actual_input_tokens = input_tokens or 0 self.usage = LettaUsageStatistics( step_count=1, diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index aa380393..7b2723b3 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -629,6 +629,7 @@ class LettaAgentV2(BaseAgentV2): self.should_continue = True self.stop_reason = None self.usage = LettaUsageStatistics() + self.last_step_usage: LettaUsageStatistics | None = None # Per-step usage for Step token details self.job_update_metadata = None self.last_function_response = None self.response_messages = [] @@ -856,30 +857,34 @@ class LettaAgentV2(BaseAgentV2): # Update step with actual usage now that we have it (if step was created) if logged_step: - # Build detailed token breakdowns from LettaUsageStatistics + # Use per-step usage for Step token details (not accumulated self.usage) + # Each Step should store its own per-step values, not accumulated totals + step_usage = self.last_step_usage if self.last_step_usage else self.usage + + # Build detailed token breakdowns from per-step LettaUsageStatistics # Use `is not None` to capture 0 values (meaning "provider reported 0 cached/reasoning tokens") # Only include fields that were actually reported by the provider prompt_details = None - if self.usage.cached_input_tokens is not None or self.usage.cache_write_tokens is not None: + if step_usage.cached_input_tokens is not None or step_usage.cache_write_tokens is not None: prompt_details = UsageStatisticsPromptTokenDetails( - cached_tokens=self.usage.cached_input_tokens if self.usage.cached_input_tokens is not None else None, - cache_read_tokens=self.usage.cached_input_tokens if self.usage.cached_input_tokens is not None else None, - cache_creation_tokens=self.usage.cache_write_tokens if self.usage.cache_write_tokens is not None else None, + cached_tokens=step_usage.cached_input_tokens if step_usage.cached_input_tokens is not None else None, + cache_read_tokens=step_usage.cached_input_tokens if step_usage.cached_input_tokens is not None else None, + cache_creation_tokens=step_usage.cache_write_tokens if step_usage.cache_write_tokens is not None else None, ) completion_details = None - if self.usage.reasoning_tokens is not None: + if step_usage.reasoning_tokens is not None: completion_details = UsageStatisticsCompletionTokenDetails( - reasoning_tokens=self.usage.reasoning_tokens, + reasoning_tokens=step_usage.reasoning_tokens, ) await self.step_manager.update_step_success_async( self.actor, step_metrics.id, UsageStatistics( - completion_tokens=self.usage.completion_tokens, - prompt_tokens=self.usage.prompt_tokens, - total_tokens=self.usage.total_tokens, + completion_tokens=step_usage.completion_tokens, + prompt_tokens=step_usage.prompt_tokens, + total_tokens=step_usage.total_tokens, prompt_tokens_details=prompt_details, completion_tokens_details=completion_details, ), @@ -888,6 +893,10 @@ class LettaAgentV2(BaseAgentV2): return StepProgression.FINISHED, step_metrics def _update_global_usage_stats(self, step_usage_stats: LettaUsageStatistics): + # Save per-step usage for Step token details (before accumulating) + self.last_step_usage = step_usage_stats + + # Accumulate into global usage self.usage.step_count += step_usage_stats.step_count self.usage.completion_tokens += step_usage_stats.completion_tokens self.usage.prompt_tokens += step_usage_stats.prompt_tokens diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index dc1bf19b..3af7fad0 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -69,7 +69,6 @@ class LettaAgentV3(LettaAgentV2): def _initialize_state(self): super()._initialize_state() self._require_tool_call = False - self.last_step_usage = None self.response_messages_for_metadata = [] # Separate accumulator for streaming job metadata def _compute_tool_return_truncation_chars(self) -> int: @@ -84,11 +83,6 @@ class LettaAgentV3(LettaAgentV2): cap = 5000 return max(5000, cap) - def _update_global_usage_stats(self, step_usage_stats: LettaUsageStatistics): - """Override to track per-step usage for context limit checks""" - self.last_step_usage = step_usage_stats - super()._update_global_usage_stats(step_usage_stats) - @trace_method async def step( self, diff --git a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py index e35d8425..58d84517 100644 --- a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +++ b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py @@ -485,11 +485,19 @@ class SimpleAnthropicStreamingInterface: self.raw_usage = None elif isinstance(event, BetaRawMessageDeltaEvent): - self.output_tokens += event.usage.output_tokens + # Per Anthropic docs: "The token counts shown in the usage field of the + # message_delta event are *cumulative*." So we assign, not accumulate. + self.output_tokens = event.usage.output_tokens elif isinstance(event, BetaRawMessageStopEvent): - # Don't do anything here! We don't want to stop the stream. - pass + # Update raw_usage with final accumulated values for accurate provider trace logging + if self.raw_usage: + self.raw_usage["input_tokens"] = self.input_tokens + self.raw_usage["output_tokens"] = self.output_tokens + if self.cache_read_tokens: + self.raw_usage["cache_read_input_tokens"] = self.cache_read_tokens + if self.cache_creation_tokens: + self.raw_usage["cache_creation_input_tokens"] = self.cache_creation_tokens elif isinstance(event, BetaRawContentBlockStopEvent): # Finalize the tool_use block at this index using accumulated deltas diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 970f9e7a..a02d0edc 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -546,7 +546,9 @@ class AnthropicStreamingInterface: self.output_tokens += event.message.usage.output_tokens self.model = event.message.model elif isinstance(event, BetaRawMessageDeltaEvent): - self.output_tokens += event.usage.output_tokens + # Per Anthropic docs: "The token counts shown in the usage field of the + # message_delta event are *cumulative*." So we assign, not accumulate. + self.output_tokens = event.usage.output_tokens elif isinstance(event, BetaRawMessageStopEvent): # Don't do anything here! We don't want to stop the stream. pass @@ -941,7 +943,9 @@ class SimpleAnthropicStreamingInterface: self.model = event.message.model elif isinstance(event, BetaRawMessageDeltaEvent): - self.output_tokens += event.usage.output_tokens + # Per Anthropic docs: "The token counts shown in the usage field of the + # message_delta event are *cumulative*." So we assign, not accumulate. + self.output_tokens = event.usage.output_tokens elif isinstance(event, BetaRawMessageStopEvent): # Don't do anything here! We don't want to stop the stream. diff --git a/tests/test_prompt_caching.py b/tests/test_prompt_caching.py index 0ac587e5..9f5ac8ab 100644 --- a/tests/test_prompt_caching.py +++ b/tests/test_prompt_caching.py @@ -23,7 +23,7 @@ import uuid import pytest from letta_client import AsyncLetta -from letta_client.types import MessageCreate +from letta_client.types import MessageCreateParam logger = logging.getLogger(__name__) @@ -251,8 +251,7 @@ def base_url() -> str: @pytest.fixture async def async_client(base_url: str) -> AsyncLetta: """Create an async Letta client.""" - token = os.getenv("LETTA_SERVER_TOKEN") - return AsyncLetta(base_url=base_url, token=token) + return AsyncLetta(base_url=base_url) # ------------------------------ @@ -269,7 +268,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s If tests fail, that reveals actual caching issues with production configurations. """ - from letta_client.types import CreateBlock + from letta_client.types import CreateBlockParam # Clean suffix to avoid invalid characters (e.g., dots in model names) clean_suffix = suffix.replace(".", "-").replace("/", "-") @@ -278,7 +277,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s model=model, embedding="openai/text-embedding-3-small", memory_blocks=[ - CreateBlock( + CreateBlockParam( label="persona", value=LARGE_MEMORY_BLOCK, ) @@ -299,6 +298,52 @@ async def cleanup_agent(client: AsyncLetta, agent_id: str): logger.warning(f"Failed to cleanup agent {agent_id}: {e}") +def assert_usage_sanity(usage, context: str = ""): + """ + Sanity checks for usage data to catch obviously wrong values. + + These catch bugs like: + - output_tokens=1 (impossible for real responses) + - Cumulative values being accumulated instead of assigned + - Token counts exceeding model limits + """ + prefix = f"[{context}] " if context else "" + + # Basic existence checks + assert usage is not None, f"{prefix}Usage should not be None" + + # Prompt tokens sanity + if usage.prompt_tokens is not None: + assert usage.prompt_tokens > 0, f"{prefix}prompt_tokens should be > 0, got {usage.prompt_tokens}" + assert usage.prompt_tokens < 500000, f"{prefix}prompt_tokens unreasonably high: {usage.prompt_tokens}" + + # Completion tokens sanity - a real response should have more than 1 token + if usage.completion_tokens is not None: + assert usage.completion_tokens > 1, ( + f"{prefix}completion_tokens={usage.completion_tokens} is suspiciously low. " + "A real response should have > 1 output token. This may indicate a usage tracking bug." + ) + assert usage.completion_tokens < 50000, ( + f"{prefix}completion_tokens={usage.completion_tokens} unreasonably high. " + "This may indicate cumulative values being accumulated instead of assigned." + ) + + # Cache tokens sanity (if present) + if usage.cache_write_tokens is not None and usage.cache_write_tokens > 0: + # Cache write shouldn't exceed total input + total_input = (usage.prompt_tokens or 0) + (usage.cache_write_tokens or 0) + (usage.cached_input_tokens or 0) + assert usage.cache_write_tokens <= total_input, ( + f"{prefix}cache_write_tokens ({usage.cache_write_tokens}) > total input ({total_input})" + ) + + if usage.cached_input_tokens is not None and usage.cached_input_tokens > 0: + # Cached input shouldn't exceed prompt tokens + cached + total_input = (usage.prompt_tokens or 0) + (usage.cached_input_tokens or 0) + assert usage.cached_input_tokens <= total_input, ( + f"{prefix}cached_input_tokens ({usage.cached_input_tokens}) exceeds reasonable bounds" + ) + + # ------------------------------ # Prompt Caching Validation Tests # ------------------------------ @@ -340,10 +385,11 @@ async def test_prompt_caching_cache_write_then_read( # Message 1: First interaction should trigger cache WRITE response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello! Please introduce yourself.")], + messages=[MessageCreateParam(role="user", content="Hello! Please introduce yourself.")], ) assert response1.usage is not None, "First message should have usage data" + assert_usage_sanity(response1.usage, f"{model} msg1") logger.info( f"[{model}] Message 1 usage: " f"prompt={response1.usage.prompt_tokens}, " @@ -371,10 +417,11 @@ async def test_prompt_caching_cache_write_then_read( # Message 2: Follow-up with same agent/context should trigger cache HIT response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="What are your main areas of expertise?")], + messages=[MessageCreateParam(role="user", content="What are your main areas of expertise?")], ) assert response2.usage is not None, "Second message should have usage data" + assert_usage_sanity(response2.usage, f"{model} msg2") logger.info( f"[{model}] Message 2 usage: " f"prompt={response2.usage.prompt_tokens}, " @@ -444,7 +491,7 @@ async def test_prompt_caching_multiple_messages( for i, message in enumerate(messages_to_send): response = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content=message)], + messages=[MessageCreateParam(role="user", content=message)], ) responses.append(response) @@ -497,13 +544,13 @@ async def test_prompt_caching_cache_invalidation_on_memory_update( # Message 1: Establish cache response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) # Message 2: Should hit cache response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="How are you?")], + messages=[MessageCreateParam(role="user", content="How are you?")], ) read_tokens_before_update = response2.usage.cached_input_tokens if response2.usage else None @@ -526,7 +573,7 @@ async def test_prompt_caching_cache_invalidation_on_memory_update( # Message 3: After memory update, cache should MISS (then create new cache) response3 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="What changed?")], + messages=[MessageCreateParam(role="user", content="What changed?")], ) # After memory update, we expect cache miss (low or zero cache hits) @@ -572,13 +619,13 @@ async def test_anthropic_system_prompt_stability(async_client: AsyncLetta): # Send message 1 response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) # Send message 2 response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Follow up!")], + messages=[MessageCreateParam(role="user", content="Follow up!")], ) # Get provider traces from ACTUAL requests sent to Anthropic @@ -660,7 +707,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta): # Message 1 response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) # Get step_id from message 1 @@ -686,7 +733,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta): # Message 2 - this should hit the cache response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Follow up!")], + messages=[MessageCreateParam(role="user", content="Follow up!")], ) # Get step_id from message 2 @@ -744,7 +791,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta): # First message response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) assert response1.usage is not None, "Should have usage data" @@ -768,7 +815,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta): for i, msg in enumerate(follow_up_messages): response = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content=msg)], + messages=[MessageCreateParam(role="user", content=msg)], ) cache_read = response.usage.cached_input_tokens if response.usage else 0 cached_token_counts.append(cache_read) @@ -805,7 +852,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta): # First message response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) # OpenAI doesn't charge for cache writes, so cached_input_tokens should be 0 or None on first message @@ -815,7 +862,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta): # Second message should show cached_input_tokens > 0 response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="What can you help with?")], + messages=[MessageCreateParam(role="user", content="What can you help with?")], ) cached_tokens_2 = response2.usage.cached_input_tokens if response2.usage else None @@ -846,7 +893,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta): # First message response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) logger.info(f"[Gemini 2.5 Flash] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}") @@ -854,7 +901,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta): # Second message should show implicit cache hit response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="What are your capabilities?")], + messages=[MessageCreateParam(role="user", content="What are your capabilities?")], ) # For Gemini, cached_input_tokens comes from cached_content_token_count (normalized in backend) @@ -886,7 +933,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta): # First message establishes the prompt response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) logger.info(f"[Gemini 3 Pro] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}") @@ -902,7 +949,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta): for i, msg in enumerate(follow_up_messages): response = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content=msg)], + messages=[MessageCreateParam(role="user", content=msg)], ) cached_tokens = response.usage.cached_input_tokens if response.usage else 0 cached_token_counts.append(cached_tokens) @@ -948,13 +995,13 @@ async def test_gemini_request_prefix_stability(async_client: AsyncLetta): # Send message 1 response1 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Hello!")], + messages=[MessageCreateParam(role="user", content="Hello!")], ) # Send message 2 response2 = await async_client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreate(role="user", content="Follow up!")], + messages=[MessageCreateParam(role="user", content="Follow up!")], ) # Get provider traces from ACTUAL requests sent to Gemini