feat: add tests for prompt caching + fix anthropic prompt caching [LET-6373] (#6454)
* feat: add tests for prompt caching
* fix: add cache control breakpoints for anthropic + fix tests
* fix: silence logging
* fix: patch token counting error
* fix: same patch on non-streaming path
This commit is contained in:
committed by
Caren Thomas
parent
e862bae524
commit
e67c98eedb
@@ -181,11 +181,16 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
|
||||
reasoning_tokens = self.interface.thinking_tokens
|
||||
|
||||
# Per Anthropic docs: "Total input tokens in a request is the summation of
|
||||
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
|
||||
# We need actual total for context window limit checks (summarization trigger).
|
||||
actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
|
||||
|
||||
self.usage = LettaUsageStatistics(
|
||||
step_count=1,
|
||||
completion_tokens=output_tokens or 0,
|
||||
prompt_tokens=input_tokens or 0,
|
||||
total_tokens=(input_tokens or 0) + (output_tokens or 0),
|
||||
total_tokens=actual_input_tokens + (output_tokens or 0),
|
||||
cached_input_tokens=cached_input_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
|
||||
@@ -381,6 +381,9 @@ class AnthropicClient(LLMClientBase):
|
||||
if tools_for_request and len(tools_for_request) > 0:
|
||||
# TODO eventually enable parallel tool use
|
||||
data["tools"] = convert_tools_to_anthropic_format(tools_for_request)
|
||||
# Add cache control to the last tool for caching tool definitions
|
||||
if len(data["tools"]) > 0:
|
||||
data["tools"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
# Messages
|
||||
inner_thoughts_xml_tag = "thinking"
|
||||
@@ -429,6 +432,22 @@ class AnthropicClient(LLMClientBase):
|
||||
# produce multiple tool_result blocks with the same id; consolidate them here.
|
||||
data["messages"] = dedupe_tool_results_in_user_messages(data["messages"])
|
||||
|
||||
# Add cache control to final message for incremental conversation caching
|
||||
# Per Anthropic docs: "During each turn, we mark the final block of the final message with
|
||||
# cache_control so the conversation can be incrementally cached."
|
||||
data["messages"] = self._add_cache_control_to_messages(data["messages"])
|
||||
|
||||
# Debug: Log cache control placement
|
||||
logger.debug(f"Anthropic request has {len(data.get('messages', []))} messages")
|
||||
if data.get("messages") and len(data["messages"]) > 0:
|
||||
last_msg = data["messages"][-1]
|
||||
logger.debug(f"Last message role: {last_msg.get('role')}, content type: {type(last_msg.get('content'))}")
|
||||
if isinstance(last_msg.get("content"), list) and len(last_msg["content"]) > 0:
|
||||
last_block = last_msg["content"][-1]
|
||||
logger.debug(f"Last content block type: {last_block.get('type')}, has cache_control: {'cache_control' in last_block}")
|
||||
if "cache_control" in last_block:
|
||||
logger.debug(f"Cache control value: {last_block['cache_control']}")
|
||||
|
||||
# Prefix fill
|
||||
# https://docs.anthropic.com/en/api/messages#body-messages
|
||||
# NOTE: cannot prefill with tools for opus:
|
||||
@@ -850,14 +869,22 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
# Build prompt tokens details with cache data if available
|
||||
prompt_tokens_details = None
|
||||
cache_read_tokens = 0
|
||||
cache_creation_tokens = 0
|
||||
if hasattr(response.usage, "cache_read_input_tokens") or hasattr(response.usage, "cache_creation_input_tokens"):
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
|
||||
|
||||
cache_read_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
|
||||
cache_creation_tokens = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
|
||||
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
|
||||
cache_read_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0,
|
||||
cache_creation_tokens=getattr(response.usage, "cache_creation_input_tokens", 0) or 0,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
)
|
||||
|
||||
# Per Anthropic docs: "Total input tokens in a request is the summation of
|
||||
# input_tokens, cache_creation_input_tokens, and cache_read_input_tokens."
|
||||
actual_input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
id=response.id,
|
||||
choices=[choice],
|
||||
@@ -866,7 +893,7 @@ class AnthropicClient(LLMClientBase):
|
||||
usage=UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_tokens=actual_input_tokens + completion_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
),
|
||||
)
|
||||
@@ -878,7 +905,7 @@ class AnthropicClient(LLMClientBase):
|
||||
return chat_completion_response
|
||||
|
||||
def _add_cache_control_to_system_message(self, system_content):
|
||||
"""Add cache control to system message content"""
|
||||
"""Add cache control to system message content."""
|
||||
if isinstance(system_content, str):
|
||||
# For string content, convert to list format with cache control
|
||||
return [{"type": "text", "text": system_content, "cache_control": {"type": "ephemeral"}}]
|
||||
@@ -893,6 +920,44 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return system_content
|
||||
|
||||
def _add_cache_control_to_messages(self, messages):
|
||||
"""
|
||||
Add cache control to the final content block of the final message.
|
||||
|
||||
This enables incremental conversation caching per Anthropic docs:
|
||||
"During each turn, we mark the final block of the final message with cache_control
|
||||
so the conversation can be incrementally cached."
|
||||
|
||||
Args:
|
||||
messages: List of Anthropic-formatted message dicts
|
||||
|
||||
Returns:
|
||||
Modified messages list with cache_control on final block
|
||||
"""
|
||||
if not messages or len(messages) == 0:
|
||||
return messages
|
||||
|
||||
# Work backwards to find the last message with content
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
message = messages[i]
|
||||
content = message.get("content")
|
||||
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Handle string content
|
||||
if isinstance(content, str):
|
||||
messages[i]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
return messages
|
||||
|
||||
# Handle list content - add cache_control to the last block
|
||||
if isinstance(content, list) and len(content) > 0:
|
||||
# Add cache_control to the last content block
|
||||
messages[i]["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
return messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def convert_tools_to_anthropic_format(tools: List[OpenAITool]) -> List[dict]:
|
||||
"""See: https://docs.anthropic.com/claude/docs/tool-use
|
||||
|
||||
1093
tests/test_prompt_caching.py
Normal file
1093
tests/test_prompt_caching.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user