fix: patch counting of tokens for anthropic (#6458)
* fix: patch counting of tokens for anthropic
* fix: patch ui to be simpler
* fix: patch undercounting bug in anthropic when caching is on
committed by Caren Thomas
parent c0b422c4c6
commit 1f7165afc4
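Context for the undercounting fix: when prompt caching is enabled, Anthropic reports input usage across separate fields (`input_tokens`, `cache_creation_input_tokens`, `cache_read_input_tokens`), so counting only `input_tokens` understates the prompt size. A minimal sketch of the intended accounting follows; the helper name and the bare `usage` argument are illustrative, not part of the patched code.

# Illustrative sketch only (assumption): shows why summing a single field undercounts
# when Anthropic prompt caching is active. Field names are from the Anthropic Messages API.
def total_prompt_tokens(usage) -> int:
    """Total input tokens, including tokens written to and read from the prompt cache."""
    return (
        (getattr(usage, "input_tokens", 0) or 0)
        + (getattr(usage, "cache_creation_input_tokens", 0) or 0)  # tokens written to the cache
        + (getattr(usage, "cache_read_input_tokens", 0) or 0)      # tokens served from the cache
    )

The tests below exercise exactly this accounting through the normalized prompt/cache usage fields exposed by the Letta client.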
@@ -23,7 +23,7 @@ import uuid

 import pytest
 from letta_client import AsyncLetta
-from letta_client.types import MessageCreate
+from letta_client.types import MessageCreateParam

 logger = logging.getLogger(__name__)

@@ -251,8 +251,7 @@ def base_url() -> str:
 @pytest.fixture
 async def async_client(base_url: str) -> AsyncLetta:
     """Create an async Letta client."""
-    token = os.getenv("LETTA_SERVER_TOKEN")
-    return AsyncLetta(base_url=base_url, token=token)
+    return AsyncLetta(base_url=base_url)


 # ------------------------------
@@ -269,7 +268,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s

     If tests fail, that reveals actual caching issues with production configurations.
     """
-    from letta_client.types import CreateBlock
+    from letta_client.types import CreateBlockParam

     # Clean suffix to avoid invalid characters (e.g., dots in model names)
     clean_suffix = suffix.replace(".", "-").replace("/", "-")
@@ -278,7 +277,7 @@ async def create_agent_with_large_memory(client: AsyncLetta, model: str, model_s
         model=model,
         embedding="openai/text-embedding-3-small",
         memory_blocks=[
-            CreateBlock(
+            CreateBlockParam(
                 label="persona",
                 value=LARGE_MEMORY_BLOCK,
             )
@@ -299,6 +298,52 @@ async def cleanup_agent(client: AsyncLetta, agent_id: str):
         logger.warning(f"Failed to cleanup agent {agent_id}: {e}")


+def assert_usage_sanity(usage, context: str = ""):
+    """
+    Sanity checks for usage data to catch obviously wrong values.
+
+    These catch bugs like:
+    - output_tokens=1 (impossible for real responses)
+    - Cumulative values being accumulated instead of assigned
+    - Token counts exceeding model limits
+    """
+    prefix = f"[{context}] " if context else ""
+
+    # Basic existence checks
+    assert usage is not None, f"{prefix}Usage should not be None"
+
+    # Prompt tokens sanity
+    if usage.prompt_tokens is not None:
+        assert usage.prompt_tokens > 0, f"{prefix}prompt_tokens should be > 0, got {usage.prompt_tokens}"
+        assert usage.prompt_tokens < 500000, f"{prefix}prompt_tokens unreasonably high: {usage.prompt_tokens}"
+
+    # Completion tokens sanity - a real response should have more than 1 token
+    if usage.completion_tokens is not None:
+        assert usage.completion_tokens > 1, (
+            f"{prefix}completion_tokens={usage.completion_tokens} is suspiciously low. "
+            "A real response should have > 1 output token. This may indicate a usage tracking bug."
+        )
+        assert usage.completion_tokens < 50000, (
+            f"{prefix}completion_tokens={usage.completion_tokens} unreasonably high. "
+            "This may indicate cumulative values being accumulated instead of assigned."
+        )
+
+    # Cache tokens sanity (if present)
+    if usage.cache_write_tokens is not None and usage.cache_write_tokens > 0:
+        # Cache write shouldn't exceed total input
+        total_input = (usage.prompt_tokens or 0) + (usage.cache_write_tokens or 0) + (usage.cached_input_tokens or 0)
+        assert usage.cache_write_tokens <= total_input, (
+            f"{prefix}cache_write_tokens ({usage.cache_write_tokens}) > total input ({total_input})"
+        )
+
+    if usage.cached_input_tokens is not None and usage.cached_input_tokens > 0:
+        # Cached input shouldn't exceed prompt tokens + cached
+        total_input = (usage.prompt_tokens or 0) + (usage.cached_input_tokens or 0)
+        assert usage.cached_input_tokens <= total_input, (
+            f"{prefix}cached_input_tokens ({usage.cached_input_tokens}) exceeds reasonable bounds"
+        )
+
+
 # ------------------------------
 # Prompt Caching Validation Tests
 # ------------------------------
@@ -340,10 +385,11 @@ async def test_prompt_caching_cache_write_then_read(
     # Message 1: First interaction should trigger cache WRITE
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello! Please introduce yourself.")],
+        messages=[MessageCreateParam(role="user", content="Hello! Please introduce yourself.")],
     )

     assert response1.usage is not None, "First message should have usage data"
+    assert_usage_sanity(response1.usage, f"{model} msg1")
     logger.info(
         f"[{model}] Message 1 usage: "
         f"prompt={response1.usage.prompt_tokens}, "
@@ -371,10 +417,11 @@ async def test_prompt_caching_cache_write_then_read(
     # Message 2: Follow-up with same agent/context should trigger cache HIT
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="What are your main areas of expertise?")],
+        messages=[MessageCreateParam(role="user", content="What are your main areas of expertise?")],
     )

     assert response2.usage is not None, "Second message should have usage data"
+    assert_usage_sanity(response2.usage, f"{model} msg2")
     logger.info(
         f"[{model}] Message 2 usage: "
         f"prompt={response2.usage.prompt_tokens}, "
@@ -444,7 +491,7 @@ async def test_prompt_caching_multiple_messages(
     for i, message in enumerate(messages_to_send):
         response = await async_client.agents.messages.create(
             agent_id=agent.id,
-            messages=[MessageCreate(role="user", content=message)],
+            messages=[MessageCreateParam(role="user", content=message)],
         )
         responses.append(response)

@@ -497,13 +544,13 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
     # Message 1: Establish cache
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     # Message 2: Should hit cache
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="How are you?")],
+        messages=[MessageCreateParam(role="user", content="How are you?")],
     )

     read_tokens_before_update = response2.usage.cached_input_tokens if response2.usage else None
@@ -526,7 +573,7 @@ async def test_prompt_caching_cache_invalidation_on_memory_update(
     # Message 3: After memory update, cache should MISS (then create new cache)
     response3 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="What changed?")],
+        messages=[MessageCreateParam(role="user", content="What changed?")],
     )

     # After memory update, we expect cache miss (low or zero cache hits)
@@ -572,13 +619,13 @@ async def test_anthropic_system_prompt_stability(async_client: AsyncLetta):
     # Send message 1
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     # Send message 2
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Follow up!")],
+        messages=[MessageCreateParam(role="user", content="Follow up!")],
     )

     # Get provider traces from ACTUAL requests sent to Anthropic
@@ -660,7 +707,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta):
     # Message 1
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     # Get step_id from message 1
@@ -686,7 +733,7 @@ async def test_anthropic_inspect_raw_request(async_client: AsyncLetta):
     # Message 2 - this should hit the cache
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Follow up!")],
+        messages=[MessageCreateParam(role="user", content="Follow up!")],
     )

     # Get step_id from message 2
@@ -744,7 +791,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta):
     # First message
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     assert response1.usage is not None, "Should have usage data"
@@ -768,7 +815,7 @@ async def test_anthropic_cache_control_breakpoints(async_client: AsyncLetta):
     for i, msg in enumerate(follow_up_messages):
         response = await async_client.agents.messages.create(
             agent_id=agent.id,
-            messages=[MessageCreate(role="user", content=msg)],
+            messages=[MessageCreateParam(role="user", content=msg)],
         )
         cache_read = response.usage.cached_input_tokens if response.usage else 0
         cached_token_counts.append(cache_read)
@@ -805,7 +852,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta):
     # First message
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     # OpenAI doesn't charge for cache writes, so cached_input_tokens should be 0 or None on first message
@@ -815,7 +862,7 @@ async def test_openai_automatic_caching(async_client: AsyncLetta):
     # Second message should show cached_input_tokens > 0
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="What can you help with?")],
+        messages=[MessageCreateParam(role="user", content="What can you help with?")],
     )

     cached_tokens_2 = response2.usage.cached_input_tokens if response2.usage else None
@@ -846,7 +893,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta):
     # First message
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     logger.info(f"[Gemini 2.5 Flash] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}")
@@ -854,7 +901,7 @@ async def test_gemini_2_5_flash_implicit_caching(async_client: AsyncLetta):
     # Second message should show implicit cache hit
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="What are your capabilities?")],
+        messages=[MessageCreateParam(role="user", content="What are your capabilities?")],
     )

     # For Gemini, cached_input_tokens comes from cached_content_token_count (normalized in backend)
@@ -886,7 +933,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta):
     # First message establishes the prompt
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     logger.info(f"[Gemini 3 Pro] First message prompt_tokens: {response1.usage.prompt_tokens if response1.usage else 'N/A'}")
@@ -902,7 +949,7 @@ async def test_gemini_3_pro_preview_implicit_caching(async_client: AsyncLetta):
     for i, msg in enumerate(follow_up_messages):
         response = await async_client.agents.messages.create(
             agent_id=agent.id,
-            messages=[MessageCreate(role="user", content=msg)],
+            messages=[MessageCreateParam(role="user", content=msg)],
         )
         cached_tokens = response.usage.cached_input_tokens if response.usage else 0
         cached_token_counts.append(cached_tokens)
@@ -948,13 +995,13 @@ async def test_gemini_request_prefix_stability(async_client: AsyncLetta):
     # Send message 1
     response1 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Hello!")],
+        messages=[MessageCreateParam(role="user", content="Hello!")],
     )

     # Send message 2
     response2 = await async_client.agents.messages.create(
         agent_id=agent.id,
-        messages=[MessageCreate(role="user", content="Follow up!")],
+        messages=[MessageCreateParam(role="user", content="Follow up!")],
     )

     # Get provider traces from ACTUAL requests sent to Gemini