feat: add tracking of advanced usage data (e.g. caching) [LET-6372] (#6449)

* feat: init refactor

* feat: add helper code

* fix: missing file + test

* fix: just state/publish api
This commit is contained in:
Charles Packer
2025-11-28 21:21:20 -08:00
committed by Caren Thomas
parent 807c5c18d9
commit 131891e05f
19 changed files with 895 additions and 9 deletions

View File

@@ -0,0 +1,29 @@
"""Add prompt_tokens_details to steps table
Revision ID: 175dd10fb916
Revises: b1c2d3e4f5a6
Create Date: 2025-11-28 12:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "175dd10fb916"
down_revision: Union[str, None] = "b1c2d3e4f5a6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the migration: add the nullable ``prompt_tokens_details`` JSON column."""
    # Add prompt_tokens_details JSON column to steps table
    # This stores detailed prompt token breakdown (cached_tokens, cache_read_tokens, cache_creation_tokens)
    # Nullable so existing rows need no backfill.
    op.add_column("steps", sa.Column("prompt_tokens_details", sa.JSON(), nullable=True))
def downgrade() -> None:
    """Revert the migration: drop ``prompt_tokens_details`` from the steps table."""
    op.drop_column("steps", "prompt_tokens_details")

View File

@@ -30590,11 +30590,29 @@
],
"title": "Run Ids",
"description": "The background task run IDs associated with the agent interaction"
},
"cached_input_tokens": {
"type": "integer",
"title": "Cached Input Tokens",
"description": "The number of input tokens served from cache.",
"default": 0
},
"cache_write_tokens": {
"type": "integer",
"title": "Cache Write Tokens",
"description": "The number of input tokens written to cache (Anthropic only).",
"default": 0
},
"reasoning_tokens": {
"type": "integer",
"title": "Reasoning Tokens",
"description": "The number of reasoning/thinking tokens generated.",
"default": 0
}
},
"type": "object",
"title": "LettaUsageStatistics",
"description": "Usage statistics for the agent interaction.\n\nAttributes:\n completion_tokens (int): The number of tokens generated by the agent.\n prompt_tokens (int): The number of tokens in the prompt.\n total_tokens (int): The total number of tokens processed by the agent.\n step_count (int): The number of steps taken by the agent."
"description": "Usage statistics for the agent interaction.\n\nAttributes:\n completion_tokens (int): The number of tokens generated by the agent.\n prompt_tokens (int): The number of tokens in the prompt.\n total_tokens (int): The total number of tokens processed by the agent.\n step_count (int): The number of steps taken by the agent.\n cached_input_tokens (int): The number of input tokens served from cache.\n cache_write_tokens (int): The number of input tokens written to cache (Anthropic only).\n reasoning_tokens (int): The number of reasoning/thinking tokens generated."
},
"ListDeploymentEntitiesResponse": {
"properties": {
@@ -35071,7 +35089,20 @@
}
],
"title": "Completion Tokens Details",
"description": "Metadata for the agent."
"description": "Detailed completion token breakdown (e.g., reasoning_tokens)."
},
"prompt_tokens_details": {
"anyOf": [
{
"additionalProperties": true,
"type": "object"
},
{
"type": "null"
}
],
"title": "Prompt Tokens Details",
"description": "Detailed prompt token breakdown (e.g., cached_tokens, cache_read_tokens, cache_creation_tokens)."
},
"stop_reason": {
"anyOf": [
@@ -38178,6 +38209,16 @@
"type": "integer",
"title": "Cached Tokens",
"default": 0
},
"cache_read_tokens": {
"type": "integer",
"title": "Cache Read Tokens",
"default": 0
},
"cache_creation_tokens": {
"type": "integer",
"title": "Cache Creation Tokens",
"default": 0
}
},
"type": "object",

View File

@@ -4,6 +4,7 @@ from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
@@ -85,6 +86,11 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
self.usage.prompt_tokens = self.chat_completions_response.usage.prompt_tokens
self.usage.total_tokens = self.chat_completions_response.usage.total_tokens
# Extract cache and reasoning token details using normalized helpers
usage = self.chat_completions_response.usage
self.usage.cached_input_tokens, self.usage.cache_write_tokens = normalize_cache_tokens(usage.prompt_tokens_details)
self.usage.reasoning_tokens = normalize_reasoning_tokens(usage.completion_tokens_details)
self.log_provider_trace(step_id=step_id, actor=actor)
yield None

View File

@@ -158,11 +158,34 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
if not output_tokens and hasattr(self.interface, "fallback_output_tokens"):
output_tokens = self.interface.fallback_output_tokens
# Extract cache token data (OpenAI/Gemini use cached_tokens)
cached_input_tokens = 0
if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens:
cached_input_tokens = self.interface.cached_tokens
# Anthropic uses cache_read_tokens for cache hits
elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens:
cached_input_tokens = self.interface.cache_read_tokens
# Extract cache write tokens (Anthropic only)
cache_write_tokens = 0
if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens:
cache_write_tokens = self.interface.cache_creation_tokens
# Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
reasoning_tokens = 0
if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens:
reasoning_tokens = self.interface.reasoning_tokens
elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens:
reasoning_tokens = self.interface.thinking_tokens
self.usage = LettaUsageStatistics(
step_count=1,
completion_tokens=output_tokens or 0,
prompt_tokens=input_tokens or 0,
total_tokens=(input_tokens or 0) + (output_tokens or 0),
cached_input_tokens=cached_input_tokens,
cache_write_tokens=cache_write_tokens,
reasoning_tokens=reasoning_tokens,
)
else:
# Default usage statistics if not available

View File

@@ -42,7 +42,13 @@ from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreateBase
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics
from letta.schemas.openai.chat_completion_response import (
FunctionCall,
ToolCall,
UsageStatistics,
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
@@ -1077,6 +1083,15 @@ class LettaAgent(BaseAgent):
usage.completion_tokens += interface.output_tokens
usage.prompt_tokens += interface.input_tokens
usage.total_tokens += interface.input_tokens + interface.output_tokens
# Aggregate cache and reasoning tokens if available from streaming interface
if hasattr(interface, "cached_tokens") and interface.cached_tokens:
usage.cached_input_tokens += interface.cached_tokens
if hasattr(interface, "cache_read_tokens") and interface.cache_read_tokens:
usage.cached_input_tokens += interface.cache_read_tokens
if hasattr(interface, "cache_creation_tokens") and interface.cache_creation_tokens:
usage.cache_write_tokens += interface.cache_creation_tokens
if hasattr(interface, "reasoning_tokens") and interface.reasoning_tokens:
usage.reasoning_tokens += interface.reasoning_tokens
MetricRegistry().message_output_tokens.record(
usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
@@ -1124,6 +1139,21 @@ class LettaAgent(BaseAgent):
# Update step with actual usage now that we have it (if step was created)
if logged_step:
# Build detailed token breakdowns from LettaUsageStatistics
prompt_details = None
if usage.cached_input_tokens or usage.cache_write_tokens:
prompt_details = UsageStatisticsPromptTokenDetails(
cached_tokens=usage.cached_input_tokens,
cache_read_tokens=usage.cached_input_tokens,
cache_creation_tokens=usage.cache_write_tokens,
)
completion_details = None
if usage.reasoning_tokens:
completion_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=usage.reasoning_tokens,
)
await self.step_manager.update_step_success_async(
self.actor,
step_id,
@@ -1131,6 +1161,8 @@ class LettaAgent(BaseAgent):
completion_tokens=usage.completion_tokens,
prompt_tokens=usage.prompt_tokens,
total_tokens=usage.total_tokens,
prompt_tokens_details=prompt_details,
completion_tokens_details=completion_details,
),
stop_reason,
)

View File

@@ -37,7 +37,13 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics
from letta.schemas.openai.chat_completion_response import (
FunctionCall,
ToolCall,
UsageStatistics,
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
from letta.schemas.step import Step, StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool import Tool
@@ -850,6 +856,21 @@ class LettaAgentV2(BaseAgentV2):
# Update step with actual usage now that we have it (if step was created)
if logged_step:
# Build detailed token breakdowns from LettaUsageStatistics
prompt_details = None
if self.usage.cached_input_tokens or self.usage.cache_write_tokens:
prompt_details = UsageStatisticsPromptTokenDetails(
cached_tokens=self.usage.cached_input_tokens,
cache_read_tokens=self.usage.cached_input_tokens, # Normalized from various providers
cache_creation_tokens=self.usage.cache_write_tokens,
)
completion_details = None
if self.usage.reasoning_tokens:
completion_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=self.usage.reasoning_tokens,
)
await self.step_manager.update_step_success_async(
self.actor,
step_metrics.id,
@@ -857,6 +878,8 @@ class LettaAgentV2(BaseAgentV2):
completion_tokens=self.usage.completion_tokens,
prompt_tokens=self.usage.prompt_tokens,
total_tokens=self.usage.total_tokens,
prompt_tokens_details=prompt_details,
completion_tokens_details=completion_details,
),
self.stop_reason,
)
@@ -867,6 +890,10 @@ class LettaAgentV2(BaseAgentV2):
self.usage.completion_tokens += step_usage_stats.completion_tokens
self.usage.prompt_tokens += step_usage_stats.prompt_tokens
self.usage.total_tokens += step_usage_stats.total_tokens
# Aggregate cache and reasoning token fields
self.usage.cached_input_tokens += step_usage_stats.cached_input_tokens
self.usage.cache_write_tokens += step_usage_stats.cache_write_tokens
self.usage.reasoning_tokens += step_usage_stats.reasoning_tokens
@trace_method
async def _handle_ai_response(

View File

@@ -94,6 +94,10 @@ class SimpleAnthropicStreamingInterface:
self.output_tokens = 0
self.model = None
# cache tracking (Anthropic-specific)
self.cache_read_tokens = 0
self.cache_creation_tokens = 0
# reasoning object trackers
self.reasoning_messages = []
@@ -463,6 +467,13 @@ class SimpleAnthropicStreamingInterface:
self.output_tokens += event.message.usage.output_tokens
self.model = event.message.model
# Capture cache data if available
usage = event.message.usage
if hasattr(usage, "cache_read_input_tokens") and usage.cache_read_input_tokens:
self.cache_read_tokens += usage.cache_read_input_tokens
if hasattr(usage, "cache_creation_input_tokens") and usage.cache_creation_input_tokens:
self.cache_creation_tokens += usage.cache_creation_input_tokens
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens

View File

@@ -74,6 +74,16 @@ class SimpleGeminiStreamingInterface:
# Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state
self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = []
# Token counters
self.input_tokens = 0
self.output_tokens = 0
# Cache token tracking (Gemini uses cached_content_token_count)
self.cached_tokens = 0
# Thinking/reasoning token tracking (Gemini uses thoughts_token_count)
self.thinking_tokens = 0
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
"""This is (unusually) in chunked format, instead of merged"""
for content in self.content_parts:
@@ -171,6 +181,12 @@ class SimpleGeminiStreamingInterface:
# includes thinking/reasoning tokens which can be 10-100x the actual output.
if usage_metadata.candidates_token_count:
self.output_tokens = usage_metadata.candidates_token_count
# Capture cache token data (Gemini uses cached_content_token_count)
if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count:
self.cached_tokens = usage_metadata.cached_content_token_count
# Capture thinking/reasoning token data (Gemini uses thoughts_token_count)
if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count:
self.thinking_tokens = usage_metadata.thoughts_token_count
if not event.candidates or len(event.candidates) == 0:
return

View File

@@ -537,6 +537,10 @@ class SimpleOpenAIStreamingInterface:
self.input_tokens = 0
self.output_tokens = 0
# Cache and reasoning token tracking
self.cached_tokens = 0
self.reasoning_tokens = 0
# Fallback token counters (tiktoken) — NOTE(review): "cl200k-base" is not a
# tiktoken encoding name (known encodings are cl100k_base / o200k_base); verify
# which encoding the fallback counter actually uses.
self.fallback_input_tokens = 0
self.fallback_output_tokens = 0
@@ -702,6 +706,16 @@ class SimpleOpenAIStreamingInterface:
if chunk.usage:
self.input_tokens += chunk.usage.prompt_tokens
self.output_tokens += chunk.usage.completion_tokens
# Capture cache token details (OpenAI)
if hasattr(chunk.usage, "prompt_tokens_details") and chunk.usage.prompt_tokens_details:
details = chunk.usage.prompt_tokens_details
if hasattr(details, "cached_tokens") and details.cached_tokens:
self.cached_tokens += details.cached_tokens
# Capture reasoning token details (OpenAI o1/o3)
if hasattr(chunk.usage, "completion_tokens_details") and chunk.usage.completion_tokens_details:
details = chunk.usage.completion_tokens_details
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens:
self.reasoning_tokens += details.reasoning_tokens
if chunk.choices:
choice = chunk.choices[0]
@@ -846,6 +860,14 @@ class SimpleOpenAIResponsesStreamingInterface:
self.model = model
self.final_response: Optional[ParsedResponse] = None
# Token counters
self.input_tokens = 0
self.output_tokens = 0
# Cache and reasoning token tracking
self.cached_tokens = 0
self.reasoning_tokens = 0
# -------- Mapping helpers (no broad try/except) --------
def _record_tool_mapping(self, event: object, item: object) -> tuple[str | None, str | None, int | None, str | None]:
"""Record call_id/name mapping for this tool-call using output_index and item.id if present.
@@ -1270,6 +1292,16 @@ class SimpleOpenAIResponsesStreamingInterface:
self.input_tokens = event.response.usage.input_tokens
self.output_tokens = event.response.usage.output_tokens
self.message_id = event.response.id
# Capture cache token details (Responses API uses input_tokens_details)
if hasattr(event.response.usage, "input_tokens_details") and event.response.usage.input_tokens_details:
details = event.response.usage.input_tokens_details
if hasattr(details, "cached_tokens") and details.cached_tokens:
self.cached_tokens = details.cached_tokens
# Capture reasoning token details (Responses API uses output_tokens_details)
if hasattr(event.response.usage, "output_tokens_details") and event.response.usage.output_tokens_details:
details = event.response.usage.output_tokens_details
if hasattr(details, "reasoning_tokens") and details.reasoning_tokens:
self.reasoning_tokens = details.reasoning_tokens
return
else:

View File

@@ -848,6 +848,16 @@ class AnthropicClient(LLMClientBase):
),
)
# Build prompt tokens details with cache data if available
prompt_tokens_details = None
if hasattr(response.usage, "cache_read_input_tokens") or hasattr(response.usage, "cache_creation_input_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cache_read_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0,
cache_creation_tokens=getattr(response.usage, "cache_creation_input_tokens", 0) or 0,
)
chat_completion_response = ChatCompletionResponse(
id=response.id,
choices=[choice],
@@ -857,6 +867,7 @@ class AnthropicClient(LLMClientBase):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=prompt_tokens_details,
),
)
if llm_config.put_inner_thoughts_in_kwargs:

View File

@@ -630,10 +630,30 @@ class GoogleVertexClient(LLMClientBase):
# "totalTokenCount": 36
# }
if response.usage_metadata:
# Extract cache token data if available (Gemini uses cached_content_token_count)
prompt_tokens_details = None
if hasattr(response.usage_metadata, "cached_content_token_count") and response.usage_metadata.cached_content_token_count:
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cached_tokens=response.usage_metadata.cached_content_token_count,
)
# Extract thinking/reasoning token data if available (Gemini uses thoughts_token_count)
completion_tokens_details = None
if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count:
from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails
completion_tokens_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=response.usage_metadata.thoughts_token_count,
)
usage = UsageStatistics(
prompt_tokens=response.usage_metadata.prompt_token_count,
completion_tokens=response.usage_metadata.candidates_token_count,
total_tokens=response.usage_metadata.total_token_count,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
)
else:
# Count it ourselves

View File

@@ -630,6 +630,25 @@ class OpenAIClient(LLMClientBase):
completion_tokens = usage.get("output_tokens") or 0
total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)
# Extract detailed token breakdowns (Responses API uses input_tokens_details/output_tokens_details)
prompt_tokens_details = None
input_details = usage.get("input_tokens_details", {}) or {}
if input_details.get("cached_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
prompt_tokens_details = UsageStatisticsPromptTokenDetails(
cached_tokens=input_details.get("cached_tokens") or 0,
)
completion_tokens_details = None
output_details = usage.get("output_tokens_details", {}) or {}
if output_details.get("reasoning_tokens"):
from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails
completion_tokens_details = UsageStatisticsCompletionTokenDetails(
reasoning_tokens=output_details.get("reasoning_tokens") or 0,
)
# Extract assistant message text from the outputs list
outputs = response_data.get("output") or []
assistant_text_parts = []
@@ -692,6 +711,8 @@ class OpenAIClient(LLMClientBase):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
),
)

View File

@@ -50,7 +50,12 @@ class Step(SqlalchemyBase, ProjectMixin):
completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(
JSON, nullable=True, doc="Detailed completion token breakdown (e.g., reasoning_tokens)."
)
prompt_tokens_details: Mapped[Optional[Dict]] = mapped_column(
JSON, nullable=True, doc="Detailed prompt token breakdown (e.g., cached_tokens, cache_read_tokens, cache_creation_tokens)."
)
stop_reason: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")

View File

@@ -78,13 +78,17 @@ class Choice(BaseModel):
class UsageStatisticsPromptTokenDetails(BaseModel):
cached_tokens: int = 0
cached_tokens: int = 0 # OpenAI/Gemini: tokens served from cache
cache_read_tokens: int = 0 # Anthropic: tokens read from cache
cache_creation_tokens: int = 0 # Anthropic: tokens written to cache
# NOTE: OAI specific
# audio_tokens: int = 0
def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails":
return UsageStatisticsPromptTokenDetails(
cached_tokens=self.cached_tokens + other.cached_tokens,
cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
cache_creation_tokens=self.cache_creation_tokens + other.cache_creation_tokens,
)

View File

@@ -30,7 +30,10 @@ class Step(StepBase):
completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.")
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
completion_tokens_details: Optional[Dict] = Field(None, description="Detailed completion token breakdown (e.g., reasoning_tokens).")
prompt_tokens_details: Optional[Dict] = Field(
None, description="Detailed prompt token breakdown (e.g., cached_tokens, cache_read_tokens, cache_creation_tokens)."
)
stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.")
tags: List[str] = Field([], description="Metadata tags.")
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")

View File

@@ -1,9 +1,94 @@
from typing import List, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field
from letta.schemas.message import Message
if TYPE_CHECKING:
from letta.schemas.openai.chat_completion_response import (
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
def normalize_cache_tokens(
    prompt_details: Union["UsageStatisticsPromptTokenDetails", Dict[str, Any], None],
) -> Tuple[int, int]:
    """Normalize provider-specific prompt token details into cache counts.

    Accepts either a Pydantic model (as produced by the LLM adapters) or a
    plain dict (as loaded back from database JSON columns).

    Provider mappings:
        - OpenAI/Gemini: cached_tokens -> cached_input_tokens
        - Anthropic: cache_read_tokens -> cached_input_tokens,
          cache_creation_tokens -> cache_write_tokens

    Args:
        prompt_details: Provider-specific prompt token details (model, dict, or None)

    Returns:
        Tuple of (cached_input_tokens, cache_write_tokens)
    """
    if prompt_details is None:
        return 0, 0

    if isinstance(prompt_details, dict):
        # Dict form comes from JSON storage: prefer the OpenAI/Gemini field,
        # fall back to the Anthropic cache-read field, and coerce None -> 0.
        cached_input = prompt_details.get("cached_tokens") or prompt_details.get("cache_read_tokens") or 0
        cache_write = prompt_details.get("cache_creation_tokens") or 0
        return cached_input, cache_write

    # Pydantic-model form from the adapters; getattr keeps this tolerant of
    # schema variants that lack some of the cache fields.
    cached_input = getattr(prompt_details, "cached_tokens", 0) or getattr(prompt_details, "cache_read_tokens", 0) or 0
    cache_write = getattr(prompt_details, "cache_creation_tokens", 0) or 0
    return cached_input, cache_write
def normalize_reasoning_tokens(
    completion_details: Union["UsageStatisticsCompletionTokenDetails", Dict[str, Any], None],
) -> int:
    """Normalize provider-specific completion token details into a reasoning count.

    Accepts either a Pydantic model (from the adapters) or a plain dict
    (from database JSON columns).

    Provider mappings:
        - OpenAI: completion_tokens_details.reasoning_tokens
        - Gemini: thoughts_token_count (mapped to reasoning_tokens in UsageStatistics)
        - Anthropic: thinking tokens are included in completion_tokens, not separately tracked

    Args:
        completion_details: Provider-specific completion token details (model or dict)

    Returns:
        The reasoning token count (0 when unavailable)
    """
    if completion_details is None:
        return 0
    if isinstance(completion_details, dict):
        # Dict form from JSON storage; coerce a missing/None value to 0.
        return completion_details.get("reasoning_tokens") or 0
    # Pydantic-model form from the adapters.
    return getattr(completion_details, "reasoning_tokens", 0) or 0
class LettaUsageStatistics(BaseModel):
"""
@@ -14,6 +99,9 @@ class LettaUsageStatistics(BaseModel):
prompt_tokens (int): The number of tokens in the prompt.
total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent.
cached_input_tokens (int): The number of input tokens served from cache.
cache_write_tokens (int): The number of input tokens written to cache (Anthropic only).
reasoning_tokens (int): The number of reasoning/thinking tokens generated.
"""
message_type: Literal["usage_statistics"] = "usage_statistics"
@@ -23,3 +111,10 @@ class LettaUsageStatistics(BaseModel):
step_count: int = Field(0, description="The number of steps taken by the agent.")
# TODO: Optional for now. This field makes everyone's lives easier
run_ids: Optional[List[str]] = Field(None, description="The background task run IDs associated with the agent interaction")
# Cache tracking (common across providers)
cached_input_tokens: int = Field(0, description="The number of input tokens served from cache.")
cache_write_tokens: int = Field(0, description="The number of input tokens written to cache (Anthropic only).")
# Reasoning token tracking
reasoning_tokens: int = Field(0, description="The number of reasoning/thinking tokens generated.")

View File

@@ -23,7 +23,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.usage import LettaUsageStatistics, normalize_cache_tokens, normalize_reasoning_tokens
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.agent_manager import AgentManager
@@ -469,6 +469,13 @@ class RunManager:
total_usage.completion_tokens += step.completion_tokens
total_usage.total_tokens += step.total_tokens
total_usage.step_count += 1
# Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers
cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details)
total_usage.cached_input_tokens += cached_input
total_usage.cache_write_tokens += cache_write
total_usage.reasoning_tokens += normalize_reasoning_tokens(step.completion_tokens_details)
return total_usage
@enforce_types

View File

@@ -408,6 +408,12 @@ class StepManager:
if stop_reason:
step.stop_reason = stop_reason.stop_reason
# Persist detailed token breakdowns if available
if usage.prompt_tokens_details:
step.prompt_tokens_details = usage.prompt_tokens_details.model_dump()
if usage.completion_tokens_details:
step.completion_tokens_details = usage.completion_tokens_details.model_dump()
await session.commit()
pydantic_step = step.to_pydantic()

View File

@@ -0,0 +1,497 @@
"""
Integration tests for advanced usage tracking (cache tokens, reasoning tokens).
These tests verify that:
1. Cache token data (cached_input_tokens, cache_write_tokens) is captured from providers
2. Reasoning token data is captured from reasoning models
3. The data flows correctly through streaming and non-streaming paths
4. Step-level and run-level aggregation works correctly
Provider-specific cache field mappings:
- Anthropic: cache_read_input_tokens, cache_creation_input_tokens
- OpenAI: prompt_tokens_details.cached_tokens, completion_tokens_details.reasoning_tokens
- Gemini: cached_content_token_count
"""
import json
import logging
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple
import pytest
from dotenv import load_dotenv
from letta_client import AsyncLetta
from letta_client.types import (
AgentState,
MessageCreateParam,
)
from letta_client.types.agents import Run
from letta_client.types.agents.letta_streaming_response import LettaUsageStatistics
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
# ------------------------------
# Test Configuration
# ------------------------------
# Model configs for testing - these models should support caching or reasoning
CACHE_TEST_CONFIGS = [
# Anthropic Sonnet 4.5 with prompt caching
("anthropic/claude-sonnet-4-5-20250514", {"provider_type": "anthropic"}),
# OpenAI gpt-4o with prompt caching (Chat Completions API)
("openai/gpt-4o", {"provider_type": "openai"}),
# Gemini 3 Pro Preview with context caching
("google_ai/gemini-3-pro-preview", {"provider_type": "google_ai"}),
]
REASONING_TEST_CONFIGS = [
# Anthropic Sonnet 4.5 with thinking enabled
(
"anthropic/claude-sonnet-4-5-20250514",
{"provider_type": "anthropic", "thinking": {"type": "enabled", "budget_tokens": 1024}},
),
# OpenAI gpt-5.1 reasoning model (Responses API)
("openai/gpt-5.1", {"provider_type": "openai", "reasoning": {"reasoning_effort": "low"}}),
# Gemini 3 Pro Preview with thinking enabled
(
"google_ai/gemini-3-pro-preview",
{"provider_type": "google_ai", "thinking_config": {"include_thoughts": True, "thinking_budget": 1024}},
),
]
# Filter based on environment variable if set
requested = os.getenv("USAGE_TEST_CONFIG")
if requested:
# Filter configs to only include the requested one
CACHE_TEST_CONFIGS = [(h, s) for h, s in CACHE_TEST_CONFIGS if requested in h]
REASONING_TEST_CONFIGS = [(h, s) for h, s in REASONING_TEST_CONFIGS if requested in h]
def get_model_config(filename: str, model_settings_dir: str = "tests/model_settings") -> Tuple[str, dict]:
    """Load a model_settings JSON file and return its handle and settings dict.

    Args:
        filename: Name of the JSON config file inside *model_settings_dir*.
        model_settings_dir: Directory containing the model settings files.

    Returns:
        Tuple of (model handle, model_settings dict); settings default to {}.
    """
    config_path = os.path.join(model_settings_dir, filename)
    with open(config_path, "r") as config_file:
        payload = json.loads(config_file.read())
    return payload["handle"], payload.get("model_settings", {})
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture
def base_url() -> str:
    """Get the Letta server URL from environment or use default."""
    # LETTA_SERVER_URL lets CI point at a remote server; default is a local dev server.
    return os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
@pytest.fixture
async def async_client(base_url: str) -> AsyncLetta:
    """Create an async Letta client pointed at the test server."""
    # Token is optional; local dev servers typically run without auth.
    token = os.getenv("LETTA_SERVER_TOKEN")
    return AsyncLetta(base_url=base_url, token=token)
# ------------------------------
# Helper Functions
# ------------------------------
async def create_test_agent(
    client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
    name_suffix: str = "",
) -> AgentState:
    """Provision a throwaway agent configured with the given model and settings.

    The agent name embeds ``name_suffix`` plus a random hex tag so that
    concurrent test runs never collide on names.
    """
    unique_tag = uuid.uuid4().hex[:8]
    return await client.agents.create(
        name=f"usage-test-agent-{name_suffix}-{unique_tag}",
        model=model_handle,
        model_settings=model_settings,
        include_base_tools=False,  # Keep it simple for usage testing
    )
async def cleanup_agent(client: AsyncLetta, agent_id: str) -> None:
    """Best-effort deletion of a test agent; failures are logged, never raised."""
    try:
        await client.agents.delete(agent_id)
    except Exception as exc:
        logger.warning(f"Failed to cleanup agent {agent_id}: {exc}")
def extract_usage_from_stream(messages: List[Any]) -> Optional[LettaUsageStatistics]:
    """Return the final LettaUsageStatistics chunk in a stream, or None."""
    usage_chunks = [m for m in messages if isinstance(m, LettaUsageStatistics)]
    return usage_chunks[-1] if usage_chunks else None
# ------------------------------
# Cache Token Tests
# ------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS)
async def test_cache_tokens_streaming(
    async_client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
) -> None:
    """Verify cache-token usage fields are reported in streaming mode.

    A cache hit generally needs a warm cache, so two messages go to the same
    agent: the first seeds the cache, the second should read from it.
    """
    agent = await create_test_agent(async_client, model_handle, model_settings, "cache-stream")
    try:
        # Call 1: expected to write the cache (cache_creation_tokens on Anthropic).
        first_stream: List[Any] = []
        stream = async_client.agents.messages.send_message_streaming(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Hello, this is a test message for caching.")],
        )
        async for event in stream:
            first_stream.append(event)

        usage1 = extract_usage_from_stream(first_stream)
        assert usage1 is not None, "Should receive usage statistics in stream"
        assert usage1.prompt_tokens > 0, "Should have prompt tokens"
        # Record first-call usage for debugging.
        logger.info(
            f"First call usage ({model_handle}): prompt={usage1.prompt_tokens}, "
            f"cached_input={usage1.cached_input_tokens}, cache_write={usage1.cache_write_tokens}"
        )

        # Call 2: same agent/context, so the provider may serve cached tokens.
        second_stream: List[Any] = []
        stream = async_client.agents.messages.send_message_streaming(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="This is a follow-up message.")],
        )
        async for event in stream:
            second_stream.append(event)

        usage2 = extract_usage_from_stream(second_stream)
        assert usage2 is not None, "Should receive usage statistics in stream"
        logger.info(
            f"Second call usage ({model_handle}): prompt={usage2.prompt_tokens}, "
            f"cached_input={usage2.cached_input_tokens}, cache_write={usage2.cache_write_tokens}"
        )

        # Field presence is the contract; values may legitimately be 0 when
        # caching is unavailable for this model/config.
        assert hasattr(usage2, "cached_input_tokens"), "Should have cached_input_tokens field"
        assert hasattr(usage2, "cache_write_tokens"), "Should have cache_write_tokens field"
        # With caching active we would expect cache_write_tokens > 0 on the
        # first call and/or cached_input_tokens > 0 on the second, but not
        # every provider returns cache data consistently, so only field
        # presence is asserted here.
    finally:
        await cleanup_agent(async_client, agent.id)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS)
async def test_cache_tokens_non_streaming(
    async_client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
) -> None:
    """Verify cache-token usage fields are reported in blocking (non-streaming) mode."""
    agent = await create_test_agent(async_client, model_handle, model_settings, "cache-blocking")
    try:
        # Call 1: seeds the provider-side prompt cache.
        run1: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Hello, this is a test message for caching.")],
        )
        assert run1.usage is not None, "Should have usage in response"
        logger.info(
            f"First call usage ({model_handle}): prompt={run1.usage.prompt_tokens}, "
            f"cached_input={run1.usage.cached_input_tokens}, cache_write={run1.usage.cache_write_tokens}"
        )

        # Call 2: same context, so a cache hit is expected where supported.
        run2: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="This is a follow-up message.")],
        )
        assert run2.usage is not None, "Should have usage in response"
        logger.info(
            f"Second call usage ({model_handle}): prompt={run2.usage.prompt_tokens}, "
            f"cached_input={run2.usage.cached_input_tokens}, cache_write={run2.usage.cache_write_tokens}"
        )

        # Field presence is the contract; counts may be 0 when nothing was cached.
        assert hasattr(run2.usage, "cached_input_tokens"), "Should have cached_input_tokens field"
        assert hasattr(run2.usage, "cache_write_tokens"), "Should have cache_write_tokens field"
    finally:
        await cleanup_agent(async_client, agent.id)
# ------------------------------
# Reasoning Token Tests
# ------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("model_handle,model_settings", REASONING_TEST_CONFIGS)
async def test_reasoning_tokens_streaming(
    async_client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
) -> None:
    """Verify reasoning/thinking token counts are surfaced in streaming mode."""
    agent = await create_test_agent(async_client, model_handle, model_settings, "reasoning-stream")
    try:
        chunks: List[Any] = []
        async for event in async_client.agents.messages.send_message_streaming(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Think step by step: what is 2 + 2? Explain your reasoning.")],
        ):
            chunks.append(event)

        usage = extract_usage_from_stream(chunks)
        assert usage is not None, "Should receive usage statistics in stream"
        logger.info(
            f"Reasoning usage ({model_handle}): prompt={usage.prompt_tokens}, "
            f"completion={usage.completion_tokens}, reasoning={usage.reasoning_tokens}"
        )

        # The field must always exist in the schema.
        assert hasattr(usage, "reasoning_tokens"), "Should have reasoning_tokens field"
        # OpenAI reasoning models reliably report reasoning tokens; other
        # providers may omit the count, so only those get the strict check.
        if any(marker in model_handle for marker in ("gpt-5", "o3", "o1")):
            assert usage.reasoning_tokens > 0, f"OpenAI reasoning model {model_handle} should have reasoning_tokens > 0"
    finally:
        await cleanup_agent(async_client, agent.id)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_handle,model_settings", REASONING_TEST_CONFIGS)
async def test_reasoning_tokens_non_streaming(
    async_client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
) -> None:
    """Verify reasoning/thinking token counts are surfaced in blocking mode."""
    agent = await create_test_agent(async_client, model_handle, model_settings, "reasoning-blocking")
    try:
        run: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Think step by step: what is 2 + 2? Explain your reasoning.")],
        )
        assert run.usage is not None, "Should have usage in response"
        logger.info(
            f"Reasoning usage ({model_handle}): prompt={run.usage.prompt_tokens}, "
            f"completion={run.usage.completion_tokens}, reasoning={run.usage.reasoning_tokens}"
        )

        # The field must always exist in the schema.
        assert hasattr(run.usage, "reasoning_tokens"), "Should have reasoning_tokens field"
        # Only OpenAI reasoning models get the strict non-zero check.
        if any(marker in model_handle for marker in ("gpt-5", "o3", "o1")):
            assert run.usage.reasoning_tokens > 0, f"OpenAI reasoning model {model_handle} should have reasoning_tokens > 0"
    finally:
        await cleanup_agent(async_client, agent.id)
# ------------------------------
# Step-Level Usage Tests
# ------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS[:1])  # one config is enough here
async def test_step_level_usage_details(
    async_client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
) -> None:
    """Verify step-level usage details (prompt_tokens_details,
    completion_tokens_details) are persisted and can be read back."""
    agent = await create_test_agent(async_client, model_handle, model_settings, "step-details")
    try:
        # Produce a step by sending a single message.
        run: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Hello!")],
        )

        # Fetch the steps recorded for that run.
        steps = await async_client.runs.list_steps(run_id=run.id)
        assert len(steps) > 0, "Should have at least one step"

        first_step = steps[0]
        logger.info(
            f"Step usage ({model_handle}): prompt_tokens={first_step.prompt_tokens}, "
            f"prompt_tokens_details={first_step.prompt_tokens_details}, "
            f"completion_tokens_details={first_step.completion_tokens_details}"
        )

        # Core counters must be populated on the step.
        assert first_step.prompt_tokens > 0, "Step should have prompt_tokens"
        assert first_step.completion_tokens >= 0, "Step should have completion_tokens"
        assert first_step.total_tokens > 0, "Step should have total_tokens"
        # The *_details fields may be None when no cache/reasoning activity
        # occurred; this test mainly verifies the schema can carry the data.
    finally:
        await cleanup_agent(async_client, agent.id)
# ------------------------------
# Run-Level Aggregation Tests
# ------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS[:1])  # one config is enough here
async def test_run_level_usage_aggregation(
    async_client: AsyncLetta,
    model_handle: str,
    model_settings: dict,
) -> None:
    """Verify run-level usage rolls up cache/reasoning token counts from steps."""
    agent = await create_test_agent(async_client, model_handle, model_settings, "run-aggregation")
    try:
        # Two messages -> multiple steps, giving aggregation something to sum.
        await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Message 1")],
        )
        second_run: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Message 2")],
        )

        # Aggregated usage for the second run (carries accumulated context).
        run_usage = await async_client.runs.get_run_usage(run_id=second_run.id)
        logger.info(
            f"Run usage ({model_handle}): prompt={run_usage.prompt_tokens}, "
            f"completion={run_usage.completion_tokens}, total={run_usage.total_tokens}, "
            f"cached_input={run_usage.cached_input_tokens}, cache_write={run_usage.cache_write_tokens}, "
            f"reasoning={run_usage.reasoning_tokens}"
        )

        # All aggregate fields must exist; counts are non-negative.
        assert run_usage.prompt_tokens >= 0, "Run should have prompt_tokens"
        assert run_usage.completion_tokens >= 0, "Run should have completion_tokens"
        assert run_usage.total_tokens >= 0, "Run should have total_tokens"
        assert hasattr(run_usage, "cached_input_tokens"), "Run should have cached_input_tokens"
        assert hasattr(run_usage, "cache_write_tokens"), "Run should have cache_write_tokens"
        assert hasattr(run_usage, "reasoning_tokens"), "Run should have reasoning_tokens"
    finally:
        await cleanup_agent(async_client, agent.id)
# ------------------------------
# Comprehensive End-to-End Test
# ------------------------------
@pytest.mark.asyncio
async def test_usage_tracking_end_to_end(async_client: AsyncLetta) -> None:
    """Exercise the full usage-tracking pipeline in one pass.

    Flow:
      1. Create an agent on a cache-capable model.
      2. Send two messages (expected cache write, then cache read).
      3. Confirm step records exist for the second run.
      4. Confirm run-level aggregation returns totals.
    """
    # Anthropic Sonnet 4.5 offers the most comprehensive caching behavior.
    model_handle = "anthropic/claude-sonnet-4-5-20250514"
    model_settings = {"provider_type": "anthropic"}

    agent = await create_test_agent(async_client, model_handle, model_settings, "e2e")
    try:
        # Message 1: long enough to be worth caching (should trigger a cache write).
        first_run: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="This is a longer message to ensure there's enough content to cache. " * 5)],
        )
        logger.info(f"E2E Test - First message usage: {first_run.usage}")

        # Message 2: short follow-up (should trigger a cache read).
        second_run: Run = await async_client.agents.messages.send_message(
            agent_id=agent.id,
            messages=[MessageCreateParam(role="user", content="Short follow-up")],
        )
        logger.info(f"E2E Test - Second message usage: {second_run.usage}")

        # Basic usage must be tracked on both calls.
        assert first_run.usage is not None
        assert second_run.usage is not None
        assert first_run.usage.prompt_tokens > 0
        assert second_run.usage.prompt_tokens > 0

        # Step records must exist for the second run.
        steps = await async_client.runs.list_steps(run_id=second_run.id)
        assert len(steps) > 0, "Should have steps for the run"

        # Run-level aggregation must produce totals.
        run_usage = await async_client.runs.get_run_usage(run_id=second_run.id)
        assert run_usage.total_tokens > 0, "Run should have total tokens"
        logger.info(
            f"E2E Test - Run usage: prompt={run_usage.prompt_tokens}, "
            f"completion={run_usage.completion_tokens}, "
            f"cached_input={run_usage.cached_input_tokens}, "
            f"cache_write={run_usage.cache_write_tokens}"
        )
        # Reaching this point is the pass condition: cache counters may be
        # zero when the provider chose not to cache this content.
    finally:
        await cleanup_agent(async_client, agent.id)