fix: patch counting of tokens for Anthropic (#6458)

* fix: patch counting of tokens for Anthropic

* fix: patch UI to be simpler

* fix: patch undercounting bug in Anthropic when caching is on
This commit is contained in:
Charles Packer
2025-11-29 21:08:13 -08:00
committed by Caren Thomas
parent c0b422c4c6
commit 1f7165afc4
7 changed files with 125 additions and 50 deletions

View File

@@ -116,6 +116,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
if not output_tokens and hasattr(self.interface, "fallback_output_tokens"):
output_tokens = self.interface.fallback_output_tokens
# NOTE: For Anthropic, input_tokens is NON-cached only, so total_tokens here
# undercounts the actual total (missing cache_read + cache_creation tokens).
# For OpenAI/Gemini, input_tokens is already the total, so this is correct.
# See simple_llm_stream_adapter.py for the proper provider-aware calculation.
self.usage = LettaUsageStatistics(
step_count=1,
completion_tokens=output_tokens or 0,

View File

@@ -181,10 +181,19 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
reasoning_tokens = self.interface.thinking_tokens
# Per Anthropic docs: "Total input tokens in a request is the summation of
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
# We need actual total for context window limit checks (summarization trigger).
actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
# Calculate actual total input tokens for context window limit checks (summarization trigger).
#
# ANTHROPIC: input_tokens is NON-cached only, must add cache tokens
# Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
#
# OPENAI/GEMINI: input_tokens (prompt_tokens/prompt_token_count) is already TOTAL
# cached_tokens is a subset, NOT additive
# Total = input_tokens (don't add cached_tokens or it double-counts!)
is_anthropic = hasattr(self.interface, "cache_read_tokens") or hasattr(self.interface, "cache_creation_tokens")
if is_anthropic:
actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
else:
actual_input_tokens = input_tokens or 0
self.usage = LettaUsageStatistics(
step_count=1,

View File

@@ -629,6 +629,7 @@ class LettaAgentV2(BaseAgentV2):
self.should_continue = True
self.stop_reason = None
self.usage = LettaUsageStatistics()
self.last_step_usage: LettaUsageStatistics | None = None # Per-step usage for Step token details
self.job_update_metadata = None
self.last_function_response = None
self.response_messages = []
@@ -856,30 +857,34 @@ class LettaAgentV2(BaseAgentV2):
# Update step with actual usage now that we have it (if step was created)
if logged_step:
# Build detailed token breakdowns from LettaUsageStatistics
# Use per-step usage for Step token details (not accumulated self.usage)
# Each Step should store its own per-step values, not accumulated totals
step_usage = self.last_step_usage if self.last_step_usage else self.usage
# Build detailed token breakdowns from per-step LettaUsageStatistics
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached/reasoning tokens")
# Only include fields that were actually reported by the provider
prompt_details = None
if self.usage.cached_input_tokens is not None or self.usage.cache_write_tokens is not None:
if step_usage.cached_input_tokens is not None or step_usage.cache_write_tokens is not None:
prompt_details = UsageStatisticsPromptTokenDetails(
cached_tokens=self.usage.cached_input_tokens if self.usage.cached_input_tokens is not None else None,
cache_read_tokens=self.usage.cached_input_tokens if self.usage.cached_input_tokens is not None else None,
cache_creation_tokens=self.usage.cache_write_tokens if self.usage.cache_write_tokens is not None else None,
cached_tokens=step_usage.cached_input_tokens if step_usage.cached_input_tokens is not None else None,
cache_read_tokens=step_usage.cached_input_tokens if step_usage.cached_input_tokens is not None else None,
cache_creation_tokens=step_usage.cache_write_tokens if step_usage.cache_write_tokens is not None else None,
)
completion_details = None
if self.usage.reasoning_tokens is not None:
if step_usage.reasoning_tokens is not None:
completion_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=self.usage.reasoning_tokens,
reasoning_tokens=step_usage.reasoning_tokens,
)
await self.step_manager.update_step_success_async(
self.actor,
step_metrics.id,
UsageStatistics(
completion_tokens=self.usage.completion_tokens,
prompt_tokens=self.usage.prompt_tokens,
total_tokens=self.usage.total_tokens,
completion_tokens=step_usage.completion_tokens,
prompt_tokens=step_usage.prompt_tokens,
total_tokens=step_usage.total_tokens,
prompt_tokens_details=prompt_details,
completion_tokens_details=completion_details,
),
@@ -888,6 +893,10 @@ class LettaAgentV2(BaseAgentV2):
return StepProgression.FINISHED, step_metrics
def _update_global_usage_stats(self, step_usage_stats: LettaUsageStatistics):
# Save per-step usage for Step token details (before accumulating)
self.last_step_usage = step_usage_stats
# Accumulate into global usage
self.usage.step_count += step_usage_stats.step_count
self.usage.completion_tokens += step_usage_stats.completion_tokens
self.usage.prompt_tokens += step_usage_stats.prompt_tokens

View File

@@ -69,7 +69,6 @@ class LettaAgentV3(LettaAgentV2):
def _initialize_state(self):
super()._initialize_state()
self._require_tool_call = False
self.last_step_usage = None
self.response_messages_for_metadata = [] # Separate accumulator for streaming job metadata
def _compute_tool_return_truncation_chars(self) -> int:
@@ -84,11 +83,6 @@ class LettaAgentV3(LettaAgentV2):
cap = 5000
return max(5000, cap)
def _update_global_usage_stats(self, step_usage_stats: LettaUsageStatistics):
"""Override to track per-step usage for context limit checks"""
self.last_step_usage = step_usage_stats
super()._update_global_usage_stats(step_usage_stats)
@trace_method
async def step(
self,

View File

@@ -485,11 +485,19 @@ class SimpleAnthropicStreamingInterface:
self.raw_usage = None
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
# Per Anthropic docs: "The token counts shown in the usage field of the
# message_delta event are *cumulative*." So we assign, not accumulate.
self.output_tokens = event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
# Update raw_usage with final accumulated values for accurate provider trace logging
if self.raw_usage:
self.raw_usage["input_tokens"] = self.input_tokens
self.raw_usage["output_tokens"] = self.output_tokens
if self.cache_read_tokens:
self.raw_usage["cache_read_input_tokens"] = self.cache_read_tokens
if self.cache_creation_tokens:
self.raw_usage["cache_creation_input_tokens"] = self.cache_creation_tokens
elif isinstance(event, BetaRawContentBlockStopEvent):
# Finalize the tool_use block at this index using accumulated deltas

View File

@@ -546,7 +546,9 @@ class AnthropicStreamingInterface:
self.output_tokens += event.message.usage.output_tokens
self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
# Per Anthropic docs: "The token counts shown in the usage field of the
# message_delta event are *cumulative*." So we assign, not accumulate.
self.output_tokens = event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
@@ -941,7 +943,9 @@ class SimpleAnthropicStreamingInterface:
self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
# Per Anthropic docs: "The token counts shown in the usage field of the
# message_delta event are *cumulative*." So we assign, not accumulate.
self.output_tokens = event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.

View File

@@ -23,7 +23,7 @@ import uuid
import pytest
from letta_client import AsyncLetta
from letta_client.types import MessageCreate
from letta_client.types import MessageCreateParam
logger = logging.getLogger(__name__)
@@ -251,8 +251,7 @@ def base_url() -> str:
@pytest.fixture
async def async_client(base_url: str) -> AsyncLetta:
"""Create an async Letta client."""
token = os.getenv("LETTA_SERVER_TOKEN")
return AsyncLetta(base_url=base_url, token=token)
return AsyncLetta(base_url=base_url)
# ------------------------------
@@ -269,7 +268,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s
If tests fail, that reveals actual caching issues with production configurations.
"""
from letta_client.types import CreateBlock
from letta_client.types import CreateBlockParam
# Clean suffix to avoid invalid characters (e.g., dots in model names)
clean_suffix = suffix.replace(".", "-").replace("/", "-")
@@ -278,7 +277,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s
model=model,
embedding="openai/text-embedding-3-small",
memory_blocks=[
CreateBlock(
CreateBlockParam(
label="persona",
value=LARGE_MEMORY_BLOCK,
)
@@ -299,6 +298,52 @@ async def cleanup_agent(client: AsyncLetta, agent_id: str):
logger.warning(f"Failed to cleanup agent {agent_id}: {e}")
def assert_usage_sanity(usage, context: str = ""):
    """
    Sanity checks for usage data to catch obviously wrong values.

    These catch bugs like:
    - output_tokens=1 (impossible for real responses)
    - Cumulative values being accumulated instead of assigned
    - Token counts exceeding model limits
    - Negative cache token counts

    Args:
        usage: Object exposing ``prompt_tokens``, ``completion_tokens``,
            ``cache_write_tokens`` and ``cached_input_tokens`` attributes.
            Any attribute may be None, in which case its checks are skipped.
        context: Optional label prepended (in brackets) to every failure message.

    Raises:
        AssertionError: If ``usage`` is None or a reported count is outside
            its expected range.
    """
    prefix = f"[{context}] " if context else ""

    # Basic existence checks
    assert usage is not None, f"{prefix}Usage should not be None"

    prompt_tokens = usage.prompt_tokens
    completion_tokens = usage.completion_tokens
    cache_write_tokens = usage.cache_write_tokens
    cached_input_tokens = usage.cached_input_tokens

    # Prompt tokens sanity
    if prompt_tokens is not None:
        assert prompt_tokens > 0, f"{prefix}prompt_tokens should be > 0, got {prompt_tokens}"
        assert prompt_tokens < 500000, f"{prefix}prompt_tokens unreasonably high: {prompt_tokens}"

    # Completion tokens sanity - a real response should have more than 1 token
    if completion_tokens is not None:
        assert completion_tokens > 1, (
            f"{prefix}completion_tokens={completion_tokens} is suspiciously low. "
            "A real response should have > 1 output token. This may indicate a usage tracking bug."
        )
        assert completion_tokens < 50000, (
            f"{prefix}completion_tokens={completion_tokens} unreasonably high. "
            "This may indicate cumulative values being accumulated instead of assigned."
        )

    # Cache tokens sanity (if present).
    # Negative counts are never valid. They must be rejected explicitly because
    # the `> 0` guards below would otherwise skip them without any check.
    if cache_write_tokens is not None:
        assert cache_write_tokens >= 0, f"{prefix}cache_write_tokens should be >= 0, got {cache_write_tokens}"
    if cached_input_tokens is not None:
        assert cached_input_tokens >= 0, f"{prefix}cached_input_tokens should be >= 0, got {cached_input_tokens}"

    # NOTE(review): the two bounds below compare each cache count against a sum
    # that already includes it, so they are deliberately loose. Whether
    # prompt_tokens already includes cached tokens is not visible from this
    # helper - confirm before tightening these to `<= prompt_tokens`.
    if cache_write_tokens is not None and cache_write_tokens > 0:
        # Cache write shouldn't exceed total input
        total_input = (prompt_tokens or 0) + (cache_write_tokens or 0) + (cached_input_tokens or 0)
        assert cache_write_tokens <= total_input, (
            f"{prefix}cache_write_tokens ({cache_write_tokens}) > total input ({total_input})"
        )

    if cached_input_tokens is not None and cached_input_tokens > 0:
        # Cached input shouldn't exceed prompt tokens + cached
        total_input = (prompt_tokens or 0) + (cached_input_tokens or 0)
        assert cached_input_tokens <= total_input, (
            f"{prefix}cached_input_tokens ({cached_input_tokens}) exceeds reasonable bounds"
        )
# ------------------------------
# Prompt Caching Validation Tests
# ------------------------------
@@ -340,10 +385,11 @@ async def test_prompt_caching_cache_write_then_read(
# Message 1: First interaction should trigger cache WRITE
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello! Please introduce yourself.")],
messages=[MessageCreateParam(role="user", content="Hello! Please introduce yourself.")],
)
assert response1.usage is not None, "First message should have usage data"
assert_usage_sanity(response1.usage, f"{model} msg1")
logger.info(
f"[{model}] Message 1 usage: "
f"prompt={response1.usage.prompt_tokens}, "
@@ -371,10 +417,11 @@ async def test_prompt_caching_cache_write_then_read(
# Message 2: Follow-up with same agent/context should trigger cache HIT
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What are your main areas of expertise?")],
messages=[MessageCreateParam(role="user", content="What are your main areas of expertise?")],
)
assert response2.usage is not None, "Second message should have usage data"
assert_usage_sanity(response2.usage, f"{model} msg2")
logger.info(
f"[{model}] Message 2 usage: "
f"prompt={response2.usage.prompt_tokens}, "
@@ -444,7 +491,7 @@ async def test_prompt_caching_multiple_messages(
for i, message in enumerate(messages_to_send):
response = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content=message)],
messages=[MessageCreateParam(role="user", content=message)],
)
responses.append(response)
@@ -497,13 +544,13 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
# Message 1: Establish cache
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Message 2: Should hit cache
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="How are you?")],
messages=[MessageCreateParam(role="user", content="How are you?")],
)
read_tokens_before_update = response2.usage.cached_input_tokens if response2.usage else None
@@ -526,7 +573,7 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
# Message 3: After memory update, cache should MISS (then create new cache)
response3 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What changed?")],
messages=[MessageCreateParam(role="user", content="What changed?")],
)
# After memory update, we expect cache miss (low or zero cache hits)
@@ -572,13 +619,13 @@ async def test_anthropic_system_prompt_stability(async_client: AsyncLetta):
# Send message 1
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Send message 2
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Follow up!")],
messages=[MessageCreateParam(role="user", content="Follow up!")],
)
# Get provider traces from ACTUAL requests sent to Anthropic
@@ -660,7 +707,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta):
# Message 1
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Get step_id from message 1
@@ -686,7 +733,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta):
# Message 2 - this should hit the cache
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Follow up!")],
messages=[MessageCreateParam(role="user", content="Follow up!")],
)
# Get step_id from message 2
@@ -744,7 +791,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta):
# First message
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
assert response1.usage is not None, "Should have usage data"
@@ -768,7 +815,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta):
for i, msg in enumerate(follow_up_messages):
response = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content=msg)],
messages=[MessageCreateParam(role="user", content=msg)],
)
cache_read = response.usage.cached_input_tokens if response.usage else 0
cached_token_counts.append(cache_read)
@@ -805,7 +852,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta):
# First message
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# OpenAI doesn't charge for cache writes, so cached_input_tokens should be 0 or None on first message
@@ -815,7 +862,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta):
# Second message should show cached_input_tokens > 0
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What can you help with?")],
messages=[MessageCreateParam(role="user", content="What can you help with?")],
)
cached_tokens_2 = response2.usage.cached_input_tokens if response2.usage else None
@@ -846,7 +893,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta):
# First message
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
logger.info(f"[Gemini 2.5 Flash] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}")
@@ -854,7 +901,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta):
# Second message should show implicit cache hit
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What are your capabilities?")],
messages=[MessageCreateParam(role="user", content="What are your capabilities?")],
)
# For Gemini, cached_input_tokens comes from cached_content_token_count (normalized in backend)
@@ -886,7 +933,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta):
# First message establishes the prompt
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
logger.info(f"[Gemini 3 Pro] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}")
@@ -902,7 +949,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta):
for i, msg in enumerate(follow_up_messages):
response = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content=msg)],
messages=[MessageCreateParam(role="user", content=msg)],
)
cached_tokens = response.usage.cached_input_tokens if response.usage else 0
cached_token_counts.append(cached_tokens)
@@ -948,13 +995,13 @@ async def test_gemini_request_prefix_stability(async_client: AsyncLetta):
# Send message 1
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Send message 2
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Follow up!")],
messages=[MessageCreateParam(role="user", content="Follow up!")],
)
# Get provider traces from ACTUAL requests sent to Gemini