feat: add tests for prompt caching + fix anthropic prompt caching [LET-6373] (#6454)

* feat: add tests for prompt caching

* fix: add cache control breakpoints for anthropic + fix tests

* fix: silence logging

* fix: patch token counting error

* fix: same patch on non-streaming path
This commit is contained in:
Charles Packer
2025-11-29 17:44:18 -08:00
committed by Caren Thomas
parent e862bae524
commit e67c98eedb
3 changed files with 1168 additions and 5 deletions

View File

@@ -181,11 +181,16 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
reasoning_tokens = self.interface.thinking_tokens
# Per Anthropic docs: "Total input tokens in a request is the summation of
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
# We need actual total for context window limit checks (summarization trigger).
actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
self.usage = LettaUsageStatistics(
step_count=1,
completion_tokens=output_tokens or 0,
prompt_tokens=input_tokens or 0,
total_tokens=(input_tokens or 0) + (output_tokens or 0),
total_tokens=actual_input_tokens + (output_tokens or 0),
cached_input_tokens=cached_input_tokens,
cache_write_tokens=cache_write_tokens,
reasoning_tokens=reasoning_tokens,

View File

@@ -381,6 +381,9 @@ class AnthropicClient(LLMClientBase):
if tools_for_request and len(tools_for_request) > 0:
# TODO eventually enable parallel tool use
data["tools"] = convert_tools_to_anthropic_format(tools_for_request)
# Add cache control to the last tool for caching tool definitions
if len(data["tools"]) > 0:
data["tools"][-1]["cache_control"] = {"type": "ephemeral"}
# Messages
inner_thoughts_xml_tag = "thinking"
@@ -429,6 +432,22 @@ class AnthropicClient(LLMClientBase):
# produce multiple tool_result blocks with the same id; consolidate them here.
data["messages"] = dedupe_tool_results_in_user_messages(data["messages"])
# Add cache control to final message for incremental conversation caching
# Per Anthropic docs: "During each turn, we mark the final block of the final message with
# cache_control so the conversation can be incrementally cached."
data["messages"] = self._add_cache_control_to_messages(data["messages"])
# Debug: Log cache control placement
logger.debug(f"Anthropic request has {len(data.get('messages', []))} messages")
if data.get("messages") and len(data["messages"]) > 0:
last_msg = data["messages"][-1]
logger.debug(f"Last message role: {last_msg.get('role')}, content type: {type(last_msg.get('content'))}")
if isinstance(last_msg.get("content"), list) and len(last_msg["content"]) > 0:
last_block = last_msg["content"][-1]
logger.debug(f"Last content block type: {last_block.get('type')}, has cache_control: {'cache_control' in last_block}")
if "cache_control" in last_block:
logger.debug(f"Cache control value: {last_block['cache_control']}")
# Prefix fill
# https://docs.anthropic.com/en/api/messages#body-messages
# NOTE: cannot prefill with tools for opus:
@@ -850,14 +869,22 @@ class AnthropicClient(LLMClientBase):
# Build prompt tokens details with cache data if available
prompt_tokens_details = None
cache_read_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "cache_read_input_tokens") or hasattr(response.usage, "cache_creation_input_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
cache_read_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
cache_creation_tokens = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cache_read_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0,
cache_creation_tokens=getattr(response.usage, "cache_creation_input_tokens", 0) or 0,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
# Per Anthropic docs: "Total input tokens in a request is the summation of
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
actual_input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
chat_completion_response = ChatCompletionResponse(
id=response.id,
choices=[choice],
@@ -866,7 +893,7 @@ class AnthropicClient(LLMClientBase):
usage=UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
total_tokens=actual_input_tokens + completion_tokens,
prompt_tokens_details=prompt_tokens_details,
),
)
@@ -878,7 +905,7 @@ class AnthropicClient(LLMClientBase):
return chat_completion_response
def _add_cache_control_to_system_message(self, system_content):
"""Add cache control to system message content"""
"""Add cache control to system message content."""
if isinstance(system_content, str):
# For string content, convert to list format with cache control
return [{"type": "text", "text": system_content, "cache_control": {"type": "ephemeral"}}]
@@ -893,6 +920,44 @@ class AnthropicClient(LLMClientBase):
return system_content
def _add_cache_control_to_messages(self, messages):
"""
Add cache control to the final content block of the final message.
This enables incremental conversation caching per Anthropic docs:
"During each turn, we mark the final block of the final message with cache_control
so the conversation can be incrementally cached."
Args:
messages: List of Anthropic-formatted message dicts
Returns:
Modified messages list with cache_control on final block
"""
if not messages or len(messages) == 0:
return messages
# Work backwards to find the last message with content
for i in range(len(messages) - 1, -1, -1):
message = messages[i]
content = message.get("content")
if not content:
continue
# Handle string content
if isinstance(content, str):
messages[i]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
return messages
# Handle list content - add cache_control to the last block
if isinstance(content, list) and len(content) > 0:
# Add cache_control to the last content block
messages[i]["content"][-1]["cache_control"] = {"type": "ephemeral"}
return messages
return messages
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
"""See: https://docs.anthropic.com/claude/docs/tool-use

1093
tests/test_prompt_caching.py Normal file

File diff suppressed because it is too large Load Diff