fix: patch counting of tokens for anthropic (#6458)

* fix: patch counting of tokens for anthropic

* fix: patch ui to be simpler

* fix: patch undercounting bug in anthropic when caching is on
This commit is contained in:
Charles Packer
2025-11-29 21:08:13 -08:00
committed by Caren Thomas
parent c0b422c4c6
commit 1f7165afc4
7 changed files with 125 additions and 50 deletions

View File

@@ -23,7 +23,7 @@ import uuid
import pytest
from letta_client import AsyncLetta
from letta_client.types import MessageCreate
from letta_client.types import MessageCreateParam
logger = logging.getLogger(__name__)
@@ -251,8 +251,7 @@ def base_url() -> str:
@pytest.fixture
async def async_client(base_url: str) -> AsyncLetta:
"""Create an async Letta client."""
token = os.getenv("LETTA_SERVER_TOKEN")
return AsyncLetta(base_url=base_url, token=token)
return AsyncLetta(base_url=base_url)
# ------------------------------
@@ -269,7 +268,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s
If tests fail, that reveals actual caching issues with production configurations.
"""
from letta_client.types import CreateBlock
from letta_client.types import CreateBlockParam
# Clean suffix to avoid invalid characters (e.g., dots in model names)
clean_suffix = suffix.replace(".", "-").replace("/", "-")
@@ -278,7 +277,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s
model=model,
embedding="openai/text-embedding-3-small",
memory_blocks=[
CreateBlock(
CreateBlockParam(
label="persona",
value=LARGE_MEMORY_BLOCK,
)
@@ -299,6 +298,52 @@ async def cleanup_agent(client: AsyncLetta, agent_id: str):
logger.warning(f"Failed to cleanup agent {agent_id}: {e}")
def assert_usage_sanity(usage, context: str = ""):
    """
    Assert that a usage record holds plausible token counts.

    Every failure message is prefixed with ``[context] `` when ``context`` is
    non-empty. Fields that are ``None`` are skipped rather than treated as
    failures; only ``usage`` itself is required to be present.

    Checks performed, in order:
    - ``usage`` is not None
    - ``prompt_tokens`` lies strictly between 0 and 500000
    - ``completion_tokens`` lies strictly between 1 and 50000
    - a positive ``cache_write_tokens`` does not exceed the summed input counts
    - a positive ``cached_input_tokens`` does not exceed prompt + cached input

    Raises:
        AssertionError: on the first check that fails.
    """
    tag = "" if not context else f"[{context}] "

    # The record itself must exist before any field can be inspected.
    assert usage is not None, f"{tag}Usage should not be None"

    # Prompt side: must be positive and below a generous upper bound.
    prompt = usage.prompt_tokens
    if prompt is not None:
        assert prompt > 0, f"{tag}prompt_tokens should be > 0, got {prompt}"
        assert prompt < 500000, f"{tag}prompt_tokens unreasonably high: {prompt}"

    # Completion side: a single output token is treated as a tracking bug,
    # and a very large count as accumulated (rather than assigned) values.
    completion = usage.completion_tokens
    if completion is not None:
        assert completion > 1, (
            f"{tag}completion_tokens={completion} is suspiciously low. "
            "A real response should have > 1 output token. This may indicate a usage tracking bug."
        )
        assert completion < 50000, (
            f"{tag}completion_tokens={completion} unreasonably high. "
            "This may indicate cumulative values being accumulated instead of assigned."
        )

    # Cache writes are only examined when a positive count was reported.
    cache_write = usage.cache_write_tokens
    if cache_write is not None:
        if cache_write > 0:
            # Missing counts contribute zero to the combined input total.
            input_parts = (usage.prompt_tokens, usage.cache_write_tokens, usage.cached_input_tokens)
            total_input = sum(part or 0 for part in input_parts)
            assert cache_write <= total_input, (
                f"{tag}cache_write_tokens ({cache_write}) > total input ({total_input})"
            )

    # Cache reads are only examined when a positive count was reported.
    cache_read = usage.cached_input_tokens
    if cache_read is not None:
        if cache_read > 0:
            # NOTE(review): this comparison reduces to prompt_tokens >= 0, which
            # the prompt check above already enforces, so it cannot fail here.
            # Confirm the intended upper bound for cached input tokens.
            total_input = (usage.prompt_tokens or 0) + (cache_read or 0)
            assert cache_read <= total_input, (
                f"{tag}cached_input_tokens ({cache_read}) exceeds reasonable bounds"
            )
# ------------------------------
# Prompt Caching Validation Tests
# ------------------------------
@@ -340,10 +385,11 @@ async def test_prompt_caching_cache_write_then_read(
# Message 1: First interaction should trigger cache WRITE
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello! Please introduce yourself.")],
messages=[MessageCreateParam(role="user", content="Hello! Please introduce yourself.")],
)
assert response1.usage is not None, "First message should have usage data"
assert_usage_sanity(response1.usage, f"{model} msg1")
logger.info(
f"[{model}] Message 1 usage: "
f"prompt={response1.usage.prompt_tokens}, "
@@ -371,10 +417,11 @@ async def test_prompt_caching_cache_write_then_read(
# Message 2: Follow-up with same agent/context should trigger cache HIT
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What are your main areas of expertise?")],
messages=[MessageCreateParam(role="user", content="What are your main areas of expertise?")],
)
assert response2.usage is not None, "Second message should have usage data"
assert_usage_sanity(response2.usage, f"{model} msg2")
logger.info(
f"[{model}] Message 2 usage: "
f"prompt={response2.usage.prompt_tokens}, "
@@ -444,7 +491,7 @@ async def test_prompt_caching_multiple_messages(
for i, message in enumerate(messages_to_send):
response = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content=message)],
messages=[MessageCreateParam(role="user", content=message)],
)
responses.append(response)
@@ -497,13 +544,13 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
# Message 1: Establish cache
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Message 2: Should hit cache
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="How are you?")],
messages=[MessageCreateParam(role="user", content="How are you?")],
)
read_tokens_before_update = response2.usage.cached_input_tokens if response2.usage else None
@@ -526,7 +573,7 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
# Message 3: After memory update, cache should MISS (then create new cache)
response3 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What changed?")],
messages=[MessageCreateParam(role="user", content="What changed?")],
)
# After memory update, we expect cache miss (low or zero cache hits)
@@ -572,13 +619,13 @@ async def test_anthropic_system_prompt_stability(async_client: AsyncLetta):
# Send message 1
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Send message 2
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Follow up!")],
messages=[MessageCreateParam(role="user", content="Follow up!")],
)
# Get provider traces from ACTUAL requests sent to Anthropic
@@ -660,7 +707,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta):
# Message 1
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Get step_id from message 1
@@ -686,7 +733,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta):
# Message 2 - this should hit the cache
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Follow up!")],
messages=[MessageCreateParam(role="user", content="Follow up!")],
)
# Get step_id from message 2
@@ -744,7 +791,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta):
# First message
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
assert response1.usage is not None, "Should have usage data"
@@ -768,7 +815,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta):
for i, msg in enumerate(follow_up_messages):
response = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content=msg)],
messages=[MessageCreateParam(role="user", content=msg)],
)
cache_read = response.usage.cached_input_tokens if response.usage else 0
cached_token_counts.append(cache_read)
@@ -805,7 +852,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta):
# First message
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# OpenAI doesn't charge for cache writes, so cached_input_tokens should be 0 or None on first message
@@ -815,7 +862,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta):
# Second message should show cached_input_tokens > 0
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What can you help with?")],
messages=[MessageCreateParam(role="user", content="What can you help with?")],
)
cached_tokens_2 = response2.usage.cached_input_tokens if response2.usage else None
@@ -846,7 +893,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta):
# First message
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
logger.info(f"[Gemini 2.5 Flash] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}")
@@ -854,7 +901,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta):
# Second message should show implicit cache hit
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="What are your capabilities?")],
messages=[MessageCreateParam(role="user", content="What are your capabilities?")],
)
# For Gemini, cached_input_tokens comes from cached_content_token_count (normalized in backend)
@@ -886,7 +933,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta):
# First message establishes the prompt
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
logger.info(f"[Gemini 3 Pro] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}")
@@ -902,7 +949,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta):
for i, msg in enumerate(follow_up_messages):
response = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content=msg)],
messages=[MessageCreateParam(role="user", content=msg)],
)
cached_tokens = response.usage.cached_input_tokens if response.usage else 0
cached_token_counts.append(cached_tokens)
@@ -948,13 +995,13 @@ async def test_gemini_request_prefix_stability(async_client: AsyncLetta):
# Send message 1
response1 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Hello!")],
messages=[MessageCreateParam(role="user", content="Hello!")],
)
# Send message 2
response2 = await async_client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="Follow up!")],
messages=[MessageCreateParam(role="user", content="Follow up!")],
)
# Get provider traces from ACTUAL requests sent to Gemini