fix(core): distinguish between null and 0 for prompt caching (#6451)

* fix(core): distinguish between null and 0 for prompt caching

* fix: runtime errors

* fix: just publish just sgate
This commit is contained in:
Charles Packer
2025-11-29 00:09:43 -08:00
committed by Caren Thomas
parent 131891e05f
commit 88a3743cc8
10 changed files with 182 additions and 84 deletions

View File

@@ -30592,27 +30592,45 @@
"description": "The background task run IDs associated with the agent interaction"
},
"cached_input_tokens": {
"type": "integer",
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Cached Input Tokens",
"description": "The number of input tokens served from cache.",
"default": 0
"description": "The number of input tokens served from cache. None if not reported by provider."
},
"cache_write_tokens": {
"type": "integer",
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Cache Write Tokens",
"description": "The number of input tokens written to cache (Anthropic only).",
"default": 0
"description": "The number of input tokens written to cache (Anthropic only). None if not reported by provider."
},
"reasoning_tokens": {
"type": "integer",
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Reasoning Tokens",
"description": "The number of reasoning/thinking tokens generated.",
"default": 0
"description": "The number of reasoning/thinking tokens generated. None if not reported by provider."
}
},
"type": "object",
"title": "LettaUsageStatistics",
"description": "Usage statistics for the agent interaction.\n\nAttributes:\n completion_tokens (int): The number of tokens generated by the agent.\n prompt_tokens (int): The number of tokens in the prompt.\n total_tokens (int): The total number of tokens processed by the agent.\n step_count (int): The number of steps taken by the agent.\n cached_input_tokens (int): The number of input tokens served from cache.\n cache_write_tokens (int): The number of input tokens written to cache (Anthropic only).\n reasoning_tokens (int): The number of reasoning/thinking tokens generated."
"description": "Usage statistics for the agent interaction.\n\nAttributes:\n completion_tokens (int): The number of tokens generated by the agent.\n prompt_tokens (int): The number of tokens in the prompt.\n total_tokens (int): The total number of tokens processed by the agent.\n step_count (int): The number of steps taken by the agent.\n cached_input_tokens (Optional[int]): The number of input tokens served from cache. None if not reported.\n cache_write_tokens (Optional[int]): The number of input tokens written to cache. None if not reported.\n reasoning_tokens (Optional[int]): The number of reasoning/thinking tokens generated. None if not reported."
},
"ListDeploymentEntitiesResponse": {
"properties": {
@@ -38195,9 +38213,15 @@
"UsageStatisticsCompletionTokenDetails": {
"properties": {
"reasoning_tokens": {
"type": "integer",
"title": "Reasoning Tokens",
"default": 0
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Reasoning Tokens"
}
},
"type": "object",
@@ -38206,19 +38230,37 @@
"UsageStatisticsPromptTokenDetails": {
"properties": {
"cached_tokens": {
"type": "integer",
"title": "Cached Tokens",
"default": 0
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Cached Tokens"
},
"cache_read_tokens": {
"type": "integer",
"title": "Cache Read Tokens",
"default": 0
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Cache Read Tokens"
},
"cache_creation_tokens": {
"type": "integer",
"title": "Cache Creation Tokens",
"default": 0
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Cache Creation Tokens"
}
},
"type": "object",

View File

@@ -159,23 +159,26 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
output_tokens = self.interface.fallback_output_tokens
# Extract cache token data (OpenAI/Gemini use cached_tokens)
cached_input_tokens = 0
if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens:
# None means provider didn't report, 0 means provider reported 0
cached_input_tokens = None
if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens is not None:
cached_input_tokens = self.interface.cached_tokens
# Anthropic uses cache_read_tokens for cache hits
elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens:
elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens is not None:
cached_input_tokens = self.interface.cache_read_tokens
# Extract cache write tokens (Anthropic only)
cache_write_tokens = 0
if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens:
# None means provider didn't report, 0 means provider reported 0
cache_write_tokens = None
if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens is not None:
cache_write_tokens = self.interface.cache_creation_tokens
# Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
reasoning_tokens = 0
if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens:
# None means provider didn't report, 0 means provider reported 0
reasoning_tokens = None
if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens is not None:
reasoning_tokens = self.interface.reasoning_tokens
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens:
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens is not None:
reasoning_tokens = self.interface.thinking_tokens
self.usage = LettaUsageStatistics(

View File

@@ -1083,15 +1083,15 @@ class LettaAgent(BaseAgent):
usage.completion_tokens += interface.output_tokens
usage.prompt_tokens += interface.input_tokens
usage.total_tokens += interface.input_tokens + interface.output_tokens
# Aggregate cache and reasoning tokens if available from streaming interface
if hasattr(interface, "cached_tokens") and interface.cached_tokens:
usage.cached_input_tokens += interface.cached_tokens
if hasattr(interface, "cache_read_tokens") and interface.cache_read_tokens:
usage.cached_input_tokens += interface.cache_read_tokens
if hasattr(interface, "cache_creation_tokens") and interface.cache_creation_tokens:
usage.cache_write_tokens += interface.cache_creation_tokens
if hasattr(interface, "reasoning_tokens") and interface.reasoning_tokens:
usage.reasoning_tokens += interface.reasoning_tokens
# Aggregate cache and reasoning tokens if available from streaming interface (handle None defaults)
if hasattr(interface, "cached_tokens") and interface.cached_tokens is not None:
usage.cached_input_tokens = (usage.cached_input_tokens or 0) + interface.cached_tokens
if hasattr(interface, "cache_read_tokens") and interface.cache_read_tokens is not None:
usage.cached_input_tokens = (usage.cached_input_tokens or 0) + interface.cache_read_tokens
if hasattr(interface, "cache_creation_tokens") and interface.cache_creation_tokens is not None:
usage.cache_write_tokens = (usage.cache_write_tokens or 0) + interface.cache_creation_tokens
if hasattr(interface, "reasoning_tokens") and interface.reasoning_tokens is not None:
usage.reasoning_tokens = (usage.reasoning_tokens or 0) + interface.reasoning_tokens
MetricRegistry().message_output_tokens.record(
usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
@@ -1140,16 +1140,18 @@ class LettaAgent(BaseAgent):
# Update step with actual usage now that we have it (if step was created)
if logged_step:
# Build detailed token breakdowns from LettaUsageStatistics
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached/reasoning tokens")
# Only include fields that were actually reported by the provider
prompt_details = None
if usage.cached_input_tokens or usage.cache_write_tokens:
if usage.cached_input_tokens is not None or usage.cache_write_tokens is not None:
prompt_details = UsageStatisticsPromptTokenDetails(
cached_tokens=usage.cached_input_tokens,
cache_read_tokens=usage.cached_input_tokens,
cache_creation_tokens=usage.cache_write_tokens,
cached_tokens=usage.cached_input_tokens if usage.cached_input_tokens is not None else None,
cache_read_tokens=usage.cached_input_tokens if usage.cached_input_tokens is not None else None,
cache_creation_tokens=usage.cache_write_tokens if usage.cache_write_tokens is not None else None,
)
completion_details = None
if usage.reasoning_tokens:
if usage.reasoning_tokens is not None:
completion_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=usage.reasoning_tokens,
)

View File

@@ -857,16 +857,18 @@ class LettaAgentV2(BaseAgentV2):
# Update step with actual usage now that we have it (if step was created)
if logged_step:
# Build detailed token breakdowns from LettaUsageStatistics
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached/reasoning tokens")
# Only include fields that were actually reported by the provider
prompt_details = None
if self.usage.cached_input_tokens or self.usage.cache_write_tokens:
if self.usage.cached_input_tokens is not None or self.usage.cache_write_tokens is not None:
prompt_details = UsageStatisticsPromptTokenDetails(
cached_tokens=self.usage.cached_input_tokens,
cache_read_tokens=self.usage.cached_input_tokens, # Normalized from various providers
cache_creation_tokens=self.usage.cache_write_tokens,
cached_tokens=self.usage.cached_input_tokens if self.usage.cached_input_tokens is not None else None,
cache_read_tokens=self.usage.cached_input_tokens if self.usage.cached_input_tokens is not None else None,
cache_creation_tokens=self.usage.cache_write_tokens if self.usage.cache_write_tokens is not None else None,
)
completion_details = None
if self.usage.reasoning_tokens:
if self.usage.reasoning_tokens is not None:
completion_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=self.usage.reasoning_tokens,
)
@@ -890,10 +892,13 @@ class LettaAgentV2(BaseAgentV2):
self.usage.completion_tokens += step_usage_stats.completion_tokens
self.usage.prompt_tokens += step_usage_stats.prompt_tokens
self.usage.total_tokens += step_usage_stats.total_tokens
# Aggregate cache and reasoning token fields
self.usage.cached_input_tokens += step_usage_stats.cached_input_tokens
self.usage.cache_write_tokens += step_usage_stats.cache_write_tokens
self.usage.reasoning_tokens += step_usage_stats.reasoning_tokens
# Aggregate cache and reasoning token fields (handle None values)
if step_usage_stats.cached_input_tokens is not None:
self.usage.cached_input_tokens = (self.usage.cached_input_tokens or 0) + step_usage_stats.cached_input_tokens
if step_usage_stats.cache_write_tokens is not None:
self.usage.cache_write_tokens = (self.usage.cache_write_tokens or 0) + step_usage_stats.cache_write_tokens
if step_usage_stats.reasoning_tokens is not None:
self.usage.reasoning_tokens = (self.usage.reasoning_tokens or 0) + step_usage_stats.reasoning_tokens
@trace_method
async def _handle_ai_response(

View File

@@ -79,10 +79,12 @@ class SimpleGeminiStreamingInterface:
self.output_tokens = 0
# Cache token tracking (Gemini uses cached_content_token_count)
self.cached_tokens = 0
# None means "not reported by provider", 0 means "provider reported 0"
self.cached_tokens: int | None = None
# Thinking/reasoning token tracking (Gemini uses thoughts_token_count)
self.thinking_tokens = 0
# None means "not reported by provider", 0 means "provider reported 0"
self.thinking_tokens: int | None = None
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
"""This is (unusually) in chunked format, instead of merged"""
@@ -182,10 +184,12 @@ class SimpleGeminiStreamingInterface:
if usage_metadata.candidates_token_count:
self.output_tokens = usage_metadata.candidates_token_count
# Capture cache token data (Gemini uses cached_content_token_count)
if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count:
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count is not None:
self.cached_tokens = usage_metadata.cached_content_token_count
# Capture thinking/reasoning token data (Gemini uses thoughts_token_count)
if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count:
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count is not None:
self.thinking_tokens = usage_metadata.thoughts_token_count
if not event.candidates or len(event.candidates) == 0:

View File

@@ -538,8 +538,9 @@ class SimpleOpenAIStreamingInterface:
self.output_tokens = 0
# Cache and reasoning token tracking
self.cached_tokens = 0
self.reasoning_tokens = 0
# None means "not reported by provider", 0 means "provider reported 0"
self.cached_tokens: int | None = None
self.reasoning_tokens: int | None = None
# Fallback token counters (using tiktoken cl200k-base)
self.fallback_input_tokens = 0
@@ -707,14 +708,20 @@ class SimpleOpenAIStreamingInterface:
self.input_tokens += chunk.usage.prompt_tokens
self.output_tokens += chunk.usage.completion_tokens
# Capture cache token details (OpenAI)
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
if hasattr(chunk.usage, "prompt_tokens_details") and chunk.usage.prompt_tokens_details:
details = chunk.usage.prompt_tokens_details
if hasattr(details, "cached_tokens") and details.cached_tokens:
if hasattr(details, "cached_tokens") and details.cached_tokens is not None:
if self.cached_tokens is None:
self.cached_tokens = 0
self.cached_tokens += details.cached_tokens
# Capture reasoning token details (OpenAI o1/o3)
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
if hasattr(chunk.usage, "completion_tokens_details") and chunk.usage.completion_tokens_details:
details = chunk.usage.completion_tokens_details
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens:
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens is not None:
if self.reasoning_tokens is None:
self.reasoning_tokens = 0
self.reasoning_tokens += details.reasoning_tokens
if chunk.choices:
@@ -865,8 +872,9 @@ class SimpleOpenAIResponsesStreamingInterface:
self.output_tokens = 0
# Cache and reasoning token tracking
self.cached_tokens = 0
self.reasoning_tokens = 0
# None means "not reported by provider", 0 means "provider reported 0"
self.cached_tokens: int | None = None
self.reasoning_tokens: int | None = None
# -------- Mapping helpers (no broad try/except) --------
def _record_tool_mapping(self, event: object, item: object) -> tuple[str | None, str | None, int | None, str | None]:
@@ -1293,14 +1301,16 @@ class SimpleOpenAIResponsesStreamingInterface:
self.output_tokens = event.response.usage.output_tokens
self.message_id = event.response.id
# Capture cache token details (Responses API uses input_tokens_details)
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
if hasattr(event.response.usage, "input_tokens_details") and event.response.usage.input_tokens_details:
details = event.response.usage.input_tokens_details
if hasattr(details, "cached_tokens") and details.cached_tokens:
if hasattr(details, "cached_tokens") and details.cached_tokens is not None:
self.cached_tokens = details.cached_tokens
# Capture reasoning token details (Responses API uses output_tokens_details)
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
if hasattr(event.response.usage, "output_tokens_details") and event.response.usage.output_tokens_details:
details = event.response.usage.output_tokens_details
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens:
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens is not None:
self.reasoning_tokens = details.reasoning_tokens
return

View File

@@ -631,8 +631,12 @@ class GoogleVertexClient(LLMClientBase):
# }
if response.usage_metadata:
# Extract cache token data if available (Gemini uses cached_content_token_count)
# Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
prompt_tokens_details = None
if hasattr(response.usage_metadata, "cached_content_token_count") and response.usage_metadata.cached_content_token_count:
if (
hasattr(response.usage_metadata, "cached_content_token_count")
and response.usage_metadata.cached_content_token_count is not None
):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
@@ -640,8 +644,9 @@ class GoogleVertexClient(LLMClientBase):
)
# Extract thinking/reasoning token data if available (Gemini uses thoughts_token_count)
# Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
completion_tokens_details = None
if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count:
if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails
completion_tokens_details = UsageStatisticsCompletionTokenDetails(

View File

@@ -78,30 +78,44 @@ class Choice(BaseModel):
class UsageStatisticsPromptTokenDetails(BaseModel):
    # Breakdown of prompt-side token counters.
    # Each counter is tri-state: None means the provider didn't report this
    # field, 0 means the provider reported 0.
    # (Kept as `#` comments rather than a class docstring so nothing new is
    # introduced into the model beyond its fields.)
    cached_tokens: Optional[int] = None  # OpenAI/Gemini: tokens served from cache
    cache_read_tokens: Optional[int] = None  # Anthropic: tokens read from cache
    cache_creation_tokens: Optional[int] = None  # Anthropic: tokens written to cache

    # NOTE: OAI specific
    # audio_tokens: int = 0

    def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails":
        """Return a new instance whose counters are the field-wise sums.

        A counter stays ``None`` only when it is ``None`` on both operands;
        if either side reported a value, the missing side counts as 0.

        Args:
            other: The prompt-token details to add to this one.

        Returns:
            A new ``UsageStatisticsPromptTokenDetails``; neither operand is
            modified.
        """

        # Helper to add optional ints (None + None = None, None + N = N, N + M = N+M)
        def add_optional(a: Optional[int], b: Optional[int]) -> Optional[int]:
            if a is None and b is None:
                return None
            return (a or 0) + (b or 0)

        return UsageStatisticsPromptTokenDetails(
            cached_tokens=add_optional(self.cached_tokens, other.cached_tokens),
            cache_read_tokens=add_optional(self.cache_read_tokens, other.cache_read_tokens),
            cache_creation_tokens=add_optional(self.cache_creation_tokens, other.cache_creation_tokens),
        )
class UsageStatisticsCompletionTokenDetails(BaseModel):
    # Breakdown of completion-side token counters.
    # None means the provider didn't report this field, 0 means the provider
    # reported 0.
    reasoning_tokens: Optional[int] = None

    # NOTE: OAI specific
    # audio_tokens: int = 0
    # accepted_prediction_tokens: int = 0
    # rejected_prediction_tokens: int = 0

    def __add__(self, other: "UsageStatisticsCompletionTokenDetails") -> "UsageStatisticsCompletionTokenDetails":
        """Return a new instance whose ``reasoning_tokens`` is the sum of both operands.

        The result is ``None`` only when both operands are ``None``; if just
        one side reported a value, the missing side counts as 0.

        Args:
            other: The completion-token details to add to this one.

        Returns:
            A new ``UsageStatisticsCompletionTokenDetails``; neither operand
            is modified.
        """

        # Helper to add optional ints (None + None = None, None + N = N, N + M = N+M)
        def add_optional(a: Optional[int], b: Optional[int]) -> Optional[int]:
            if a is None and b is None:
                return None
            return (a or 0) + (b or 0)

        return UsageStatisticsCompletionTokenDetails(
            reasoning_tokens=add_optional(self.reasoning_tokens, other.reasoning_tokens),
        )

View File

@@ -99,9 +99,9 @@ class LettaUsageStatistics(BaseModel):
prompt_tokens (int): The number of tokens in the prompt.
total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent.
cached_input_tokens (int): The number of input tokens served from cache.
cache_write_tokens (int): The number of input tokens written to cache (Anthropic only).
reasoning_tokens (int): The number of reasoning/thinking tokens generated.
cached_input_tokens (Optional[int]): The number of input tokens served from cache. None if not reported.
cache_write_tokens (Optional[int]): The number of input tokens written to cache. None if not reported.
reasoning_tokens (Optional[int]): The number of reasoning/thinking tokens generated. None if not reported.
"""
message_type: Literal["usage_statistics"] = "usage_statistics"
@@ -113,8 +113,16 @@ class LettaUsageStatistics(BaseModel):
run_ids: Optional[List[str]] = Field(None, description="The background task run IDs associated with the agent interaction")
# Cache tracking (common across providers)
cached_input_tokens: int = Field(0, description="The number of input tokens served from cache.")
cache_write_tokens: int = Field(0, description="The number of input tokens written to cache (Anthropic only).")
# None means provider didn't report this data, 0 means provider reported 0
cached_input_tokens: Optional[int] = Field(
None, description="The number of input tokens served from cache. None if not reported by provider."
)
cache_write_tokens: Optional[int] = Field(
None, description="The number of input tokens written to cache (Anthropic only). None if not reported by provider."
)
# Reasoning token tracking
reasoning_tokens: int = Field(0, description="The number of reasoning/thinking tokens generated.")
# None means provider didn't report this data, 0 means provider reported 0
reasoning_tokens: Optional[int] = Field(
None, description="The number of reasoning/thinking tokens generated. None if not reported by provider."
)

View File

@@ -471,10 +471,15 @@ class RunManager:
total_usage.step_count += 1
# Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers
# Handle None defaults: only set if we have data, accumulate if already set
cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details)
total_usage.cached_input_tokens += cached_input
total_usage.cache_write_tokens += cache_write
total_usage.reasoning_tokens += normalize_reasoning_tokens(step.completion_tokens_details)
if cached_input > 0 or total_usage.cached_input_tokens is not None:
total_usage.cached_input_tokens = (total_usage.cached_input_tokens or 0) + cached_input
if cache_write > 0 or total_usage.cache_write_tokens is not None:
total_usage.cache_write_tokens = (total_usage.cache_write_tokens or 0) + cache_write
reasoning = normalize_reasoning_tokens(step.completion_tokens_details)
if reasoning > 0 or total_usage.reasoning_tokens is not None:
total_usage.reasoning_tokens = (total_usage.reasoning_tokens or 0) + reasoning
return total_usage