diff --git a/alembic/versions/175dd10fb916_add_prompt_tokens_details_to_steps.py b/alembic/versions/175dd10fb916_add_prompt_tokens_details_to_steps.py new file mode 100644 index 00000000..a796657a --- /dev/null +++ b/alembic/versions/175dd10fb916_add_prompt_tokens_details_to_steps.py @@ -0,0 +1,29 @@ +"""Add prompt_tokens_details to steps table + +Revision ID: 175dd10fb916 +Revises: b1c2d3e4f5a6 +Create Date: 2025-11-28 12:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "175dd10fb916" +down_revision: Union[str, None] = "b1c2d3e4f5a6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add prompt_tokens_details JSON column to steps table + # This stores detailed prompt token breakdown (cached_tokens, cache_read_tokens, cache_creation_tokens) + op.add_column("steps", sa.Column("prompt_tokens_details", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("steps", "prompt_tokens_details") diff --git a/fern/openapi.json b/fern/openapi.json index 2b834140..58df7710 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -30590,11 +30590,29 @@ ], "title": "Run Ids", "description": "The background task run IDs associated with the agent interaction" + }, + "cached_input_tokens": { + "type": "integer", + "title": "Cached Input Tokens", + "description": "The number of input tokens served from cache.", + "default": 0 + }, + "cache_write_tokens": { + "type": "integer", + "title": "Cache Write Tokens", + "description": "The number of input tokens written to cache (Anthropic only).", + "default": 0 + }, + "reasoning_tokens": { + "type": "integer", + "title": "Reasoning Tokens", + "description": "The number of reasoning/thinking tokens generated.", + "default": 0 } }, "type": "object", "title": "LettaUsageStatistics", - "description": "Usage statistics for the agent interaction.\n\nAttributes:\n completion_tokens (int): The number of tokens generated by the agent.\n prompt_tokens (int): The number of tokens in the prompt.\n total_tokens (int): The total number of tokens processed by the agent.\n step_count (int): The number of steps taken by the agent." + "description": "Usage statistics for the agent interaction.\n\nAttributes:\n completion_tokens (int): The number of tokens generated by the agent.\n prompt_tokens (int): The number of tokens in the prompt.\n total_tokens (int): The total number of tokens processed by the agent.\n step_count (int): The number of steps taken by the agent.\n cached_input_tokens (int): The number of input tokens served from cache.\n cache_write_tokens (int): The number of input tokens written to cache (Anthropic only).\n reasoning_tokens (int): The number of reasoning/thinking tokens generated." }, "ListDeploymentEntitiesResponse": { "properties": { @@ -35071,7 +35089,20 @@ } ], "title": "Completion Tokens Details", - "description": "Metadata for the agent." + "description": "Detailed completion token breakdown (e.g., reasoning_tokens)." + }, + "prompt_tokens_details": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Prompt Tokens Details", + "description": "Detailed prompt token breakdown (e.g., cached_tokens, cache_read_tokens, cache_creation_tokens)." }, "stop_reason": { "anyOf": [ @@ -38178,6 +38209,16 @@ "type": "integer", "title": "Cached Tokens", "default": 0 + }, + "cache_read_tokens": { + "type": "integer", + "title": "Cache Read Tokens", + "default": 0 + }, + "cache_creation_tokens": { + "type": "integer", + "title": "Cache Creation Tokens", + "default": 0 } }, "type": "object", diff --git a/letta/adapters/simple_llm_request_adapter.py b/letta/adapters/simple_llm_request_adapter.py index 4caa8d00..58ca1dff 100644 --- a/letta/adapters/simple_llm_request_adapter.py +++ b/letta/adapters/simple_llm_request_adapter.py @@ -4,6 +4,7 @@ from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent +from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens class SimpleLLMRequestAdapter(LettaLLMRequestAdapter): @@ -85,6 +86,11 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter): self.usage.prompt_tokens = self.chat_completions_response.usage.prompt_tokens self.usage.total_tokens = self.chat_completions_response.usage.total_tokens + # Extract cache and reasoning token details using normalized helpers + usage = self.chat_completions_response.usage + self.usage.cached_input_tokens, self.usage.cache_write_tokens = normalize_cache_tokens(usage.prompt_tokens_details) + self.usage.reasoning_tokens = normalize_reasoning_tokens(usage.completion_tokens_details) + self.log_provider_trace(step_id=step_id, actor=actor) yield None diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index 1cd2ee23..7b38c5f6 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -158,11 +158,34 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): if not output_tokens and hasattr(self.interface, "fallback_output_tokens"): output_tokens = self.interface.fallback_output_tokens + # Extract cache token data (OpenAI/Gemini use cached_tokens) + cached_input_tokens = 0 + if hasattr(self.interface, "cached_tokens") and self.interface.cached_tokens: + cached_input_tokens = self.interface.cached_tokens + # Anthropic uses cache_read_tokens for cache hits + elif hasattr(self.interface, "cache_read_tokens") and self.interface.cache_read_tokens: + cached_input_tokens = self.interface.cache_read_tokens + + # Extract cache write tokens (Anthropic only) + cache_write_tokens = 0 + if hasattr(self.interface, "cache_creation_tokens") and self.interface.cache_creation_tokens: + cache_write_tokens = self.interface.cache_creation_tokens + + # Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens) + reasoning_tokens = 0 + if hasattr(self.interface, "reasoning_tokens") and self.interface.reasoning_tokens: + reasoning_tokens = self.interface.reasoning_tokens + elif hasattr(self.interface, "thinking_tokens") and self.interface.thinking_tokens: + reasoning_tokens = self.interface.thinking_tokens + self.usage = LettaUsageStatistics( step_count=1, completion_tokens=output_tokens or 0, prompt_tokens=input_tokens or 0, total_tokens=(input_tokens or 0) + (output_tokens or 0), + cached_input_tokens=cached_input_tokens, + cache_write_tokens=cache_write_tokens, + reasoning_tokens=reasoning_tokens, ) else: # Default usage statistics if not available diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 0860a9dc..4ea2442d 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -42,7 +42,13 @@ from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message, MessageCreateBase -from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics +from letta.schemas.openai.chat_completion_response import ( + FunctionCall, + ToolCall, + UsageStatistics, + UsageStatisticsCompletionTokenDetails, + UsageStatisticsPromptTokenDetails, +) from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.step import StepProgression from letta.schemas.step_metrics import StepMetrics @@ -1077,6 +1083,15 @@ class LettaAgent(BaseAgent): usage.completion_tokens += interface.output_tokens usage.prompt_tokens += interface.input_tokens usage.total_tokens += interface.input_tokens + interface.output_tokens + # Aggregate cache and reasoning tokens if available from streaming interface + if hasattr(interface, "cached_tokens") and interface.cached_tokens: + usage.cached_input_tokens += interface.cached_tokens + if hasattr(interface, "cache_read_tokens") and interface.cache_read_tokens: + usage.cached_input_tokens += interface.cache_read_tokens + if hasattr(interface, "cache_creation_tokens") and interface.cache_creation_tokens: + usage.cache_write_tokens += interface.cache_creation_tokens + if hasattr(interface, "reasoning_tokens") and interface.reasoning_tokens: + usage.reasoning_tokens += interface.reasoning_tokens MetricRegistry().message_output_tokens.record( usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) ) @@ -1124,6 +1139,21 @@ class LettaAgent(BaseAgent): # Update step with actual usage now that we have it (if step was created) if logged_step: + # Build detailed token breakdowns from LettaUsageStatistics + prompt_details = None + if usage.cached_input_tokens or usage.cache_write_tokens: + prompt_details = UsageStatisticsPromptTokenDetails( + cached_tokens=usage.cached_input_tokens, + cache_read_tokens=usage.cached_input_tokens, + cache_creation_tokens=usage.cache_write_tokens, + ) + + completion_details = None + if usage.reasoning_tokens: + completion_details = UsageStatisticsCompletionTokenDetails( + reasoning_tokens=usage.reasoning_tokens, + ) + await self.step_manager.update_step_success_async( self.actor, step_id, @@ -1131,6 +1161,8 @@ class LettaAgent(BaseAgent): completion_tokens=usage.completion_tokens, prompt_tokens=usage.prompt_tokens, total_tokens=usage.total_tokens, + prompt_tokens_details=prompt_details, + completion_tokens_details=completion_details, ), stop_reason, ) diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 0c9f309e..0dfaa060 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -37,7 +37,13 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate, MessageUpdate -from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics +from letta.schemas.openai.chat_completion_response import ( + FunctionCall, + ToolCall, + UsageStatistics, + UsageStatisticsCompletionTokenDetails, + UsageStatisticsPromptTokenDetails, +) from letta.schemas.step import Step, StepProgression from letta.schemas.step_metrics import StepMetrics from letta.schemas.tool import Tool @@ -850,6 +856,21 @@ class LettaAgentV2(BaseAgentV2): # Update step with actual usage now that we have it (if step was created) if logged_step: + # Build detailed token breakdowns from LettaUsageStatistics + prompt_details = None + if self.usage.cached_input_tokens or self.usage.cache_write_tokens: + prompt_details = UsageStatisticsPromptTokenDetails( + cached_tokens=self.usage.cached_input_tokens, + cache_read_tokens=self.usage.cached_input_tokens, # Normalized from various providers + cache_creation_tokens=self.usage.cache_write_tokens, + ) + + completion_details = None + if self.usage.reasoning_tokens: + completion_details = UsageStatisticsCompletionTokenDetails( + reasoning_tokens=self.usage.reasoning_tokens, + ) + await self.step_manager.update_step_success_async( self.actor, step_metrics.id, @@ -857,6 +878,8 @@ class LettaAgentV2(BaseAgentV2): completion_tokens=self.usage.completion_tokens, prompt_tokens=self.usage.prompt_tokens, total_tokens=self.usage.total_tokens, + prompt_tokens_details=prompt_details, + completion_tokens_details=completion_details, ), self.stop_reason, ) @@ -867,6 +890,10 @@ class LettaAgentV2(BaseAgentV2): self.usage.completion_tokens += step_usage_stats.completion_tokens self.usage.prompt_tokens += step_usage_stats.prompt_tokens self.usage.total_tokens += step_usage_stats.total_tokens + # Aggregate cache and reasoning token fields + self.usage.cached_input_tokens += step_usage_stats.cached_input_tokens + self.usage.cache_write_tokens += step_usage_stats.cache_write_tokens + self.usage.reasoning_tokens += step_usage_stats.reasoning_tokens @trace_method async def _handle_ai_response( diff --git a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py index 6a819e42..37346cd1 100644 --- a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +++ b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py @@ -94,6 +94,10 @@ class SimpleAnthropicStreamingInterface: self.output_tokens = 0 self.model = None + # cache tracking (Anthropic-specific) + self.cache_read_tokens = 0 + self.cache_creation_tokens = 0 + # reasoning object trackers self.reasoning_messages = [] @@ -463,6 +467,13 @@ class SimpleAnthropicStreamingInterface: self.output_tokens += event.message.usage.output_tokens self.model = event.message.model + # Capture cache data if available + usage = event.message.usage + if hasattr(usage, "cache_read_input_tokens") and usage.cache_read_input_tokens: + self.cache_read_tokens += usage.cache_read_input_tokens + if hasattr(usage, "cache_creation_input_tokens") and usage.cache_creation_input_tokens: + self.cache_creation_tokens += usage.cache_creation_input_tokens + elif isinstance(event, BetaRawMessageDeltaEvent): self.output_tokens += event.usage.output_tokens diff --git a/letta/interfaces/gemini_streaming_interface.py b/letta/interfaces/gemini_streaming_interface.py index 020dbad1..6edff990 100644 --- a/letta/interfaces/gemini_streaming_interface.py +++ b/letta/interfaces/gemini_streaming_interface.py @@ -74,6 +74,16 @@ class SimpleGeminiStreamingInterface: # Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = [] + # Token counters + self.input_tokens = 0 + self.output_tokens = 0 + + # Cache token tracking (Gemini uses cached_content_token_count) + self.cached_tokens = 0 + + # Thinking/reasoning token tracking (Gemini uses thoughts_token_count) + self.thinking_tokens = 0 + def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]: """This is (unusually) in chunked format, instead of merged""" for content in self.content_parts: @@ -171,6 +181,12 @@ class SimpleGeminiStreamingInterface: # includes thinking/reasoning tokens which can be 10-100x the actual output. if usage_metadata.candidates_token_count: self.output_tokens = usage_metadata.candidates_token_count + # Capture cache token data (Gemini uses cached_content_token_count) + if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count: + self.cached_tokens = usage_metadata.cached_content_token_count + # Capture thinking/reasoning token data (Gemini uses thoughts_token_count) + if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: + self.thinking_tokens = usage_metadata.thoughts_token_count if not event.candidates or len(event.candidates) == 0: return diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 51754c8b..cb16d835 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -537,6 +537,10 @@ class SimpleOpenAIStreamingInterface: self.input_tokens = 0 self.output_tokens = 0 + # Cache and reasoning token tracking + self.cached_tokens = 0 + self.reasoning_tokens = 0 + # Fallback token counters (using tiktoken cl200k-base) self.fallback_input_tokens = 0 self.fallback_output_tokens = 0 @@ -702,6 +706,16 @@ class SimpleOpenAIStreamingInterface: if chunk.usage: self.input_tokens += chunk.usage.prompt_tokens self.output_tokens += chunk.usage.completion_tokens + # Capture cache token details (OpenAI) + if hasattr(chunk.usage, "prompt_tokens_details") and chunk.usage.prompt_tokens_details: + details = chunk.usage.prompt_tokens_details + if hasattr(details, "cached_tokens") and details.cached_tokens: + self.cached_tokens += details.cached_tokens + # Capture reasoning token details (OpenAI o1/o3) + if hasattr(chunk.usage, "completion_tokens_details") and chunk.usage.completion_tokens_details: + details = chunk.usage.completion_tokens_details + if hasattr(details, "reasoning_tokens") and details.reasoning_tokens: + self.reasoning_tokens += details.reasoning_tokens if chunk.choices: choice = chunk.choices[0] @@ -846,6 +860,14 @@ class SimpleOpenAIResponsesStreamingInterface: self.model = model self.final_response: Optional[ParsedResponse] = None + # Token counters + self.input_tokens = 0 + self.output_tokens = 0 + + # Cache and reasoning token tracking + self.cached_tokens = 0 + self.reasoning_tokens = 0 + # -------- Mapping helpers (no broad try/except) -------- def _record_tool_mapping(self, event: object, item: object) -> tuple[str | None, str | None, int | None, str | None]: """Record call_id/name mapping for this tool-call using output_index and item.id if present. @@ -1270,6 +1292,16 @@ class SimpleOpenAIResponsesStreamingInterface: self.input_tokens = event.response.usage.input_tokens self.output_tokens = event.response.usage.output_tokens self.message_id = event.response.id + # Capture cache token details (Responses API uses input_tokens_details) + if hasattr(event.response.usage, "input_tokens_details") and event.response.usage.input_tokens_details: + details = event.response.usage.input_tokens_details + if hasattr(details, "cached_tokens") and details.cached_tokens: + self.cached_tokens = details.cached_tokens + # Capture reasoning token details (Responses API uses output_tokens_details) + if hasattr(event.response.usage, "output_tokens_details") and event.response.usage.output_tokens_details: + details = event.response.usage.output_tokens_details + if hasattr(details, "reasoning_tokens") and details.reasoning_tokens: + self.reasoning_tokens = details.reasoning_tokens return else: diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 00cb7309..d7b744a3 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -848,6 +848,16 @@ class AnthropicClient(LLMClientBase): ), ) + # Build prompt tokens details with cache data if available + prompt_tokens_details = None + if hasattr(response.usage, "cache_read_input_tokens") or hasattr(response.usage, "cache_creation_input_tokens"): + from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails + + prompt_tokens_details = UsageStatisticsPromptTokenDetails( + cache_read_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0, + cache_creation_tokens=getattr(response.usage, "cache_creation_input_tokens", 0) or 0, + ) + chat_completion_response = ChatCompletionResponse( id=response.id, choices=[choice], @@ -857,6 +867,7 @@ class AnthropicClient(LLMClientBase): prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, + prompt_tokens_details=prompt_tokens_details, ), ) if llm_config.put_inner_thoughts_in_kwargs: diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index c5c7b60b..91a9cdcd 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -630,10 +630,30 @@ class GoogleVertexClient(LLMClientBase): # "totalTokenCount": 36 # } if response.usage_metadata: + # Extract cache token data if available (Gemini uses cached_content_token_count) + prompt_tokens_details = None + if hasattr(response.usage_metadata, "cached_content_token_count") and response.usage_metadata.cached_content_token_count: + from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails + + prompt_tokens_details = UsageStatisticsPromptTokenDetails( + cached_tokens=response.usage_metadata.cached_content_token_count, + ) + + # Extract thinking/reasoning token data if available (Gemini uses thoughts_token_count) + completion_tokens_details = None + if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count: + from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails + + completion_tokens_details = UsageStatisticsCompletionTokenDetails( + reasoning_tokens=response.usage_metadata.thoughts_token_count, + ) + usage = UsageStatistics( prompt_tokens=response.usage_metadata.prompt_token_count, completion_tokens=response.usage_metadata.candidates_token_count, total_tokens=response.usage_metadata.total_token_count, + prompt_tokens_details=prompt_tokens_details, + completion_tokens_details=completion_tokens_details, ) else: # Count it ourselves diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index c178f993..fa9d6e5d 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -630,6 +630,25 @@ class OpenAIClient(LLMClientBase): completion_tokens = usage.get("output_tokens") or 0 total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens) + # Extract detailed token breakdowns (Responses API uses input_tokens_details/output_tokens_details) + prompt_tokens_details = None + input_details = usage.get("input_tokens_details", {}) or {} + if input_details.get("cached_tokens"): + from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails + + prompt_tokens_details = UsageStatisticsPromptTokenDetails( + cached_tokens=input_details.get("cached_tokens") or 0, + ) + + completion_tokens_details = None + output_details = usage.get("output_tokens_details", {}) or {} + if output_details.get("reasoning_tokens"): + from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails + + completion_tokens_details = UsageStatisticsCompletionTokenDetails( + reasoning_tokens=output_details.get("reasoning_tokens") or 0, + ) + # Extract assistant message text from the outputs list outputs = response_data.get("output") or [] assistant_text_parts = [] @@ -692,6 +711,8 @@ class OpenAIClient(LLMClientBase): prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + prompt_tokens_details=prompt_tokens_details, + completion_tokens_details=completion_tokens_details, ), ) diff --git a/letta/orm/step.py b/letta/orm/step.py index 7a76649b..eca32ed5 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -50,7 +50,12 @@ class Step(SqlalchemyBase, ProjectMixin): completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent") prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt") total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent") - completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.") + completion_tokens_details: Mapped[Optional[Dict]] = mapped_column( + JSON, nullable=True, doc="Detailed completion token breakdown (e.g., reasoning_tokens)." + ) + prompt_tokens_details: Mapped[Optional[Dict]] = mapped_column( + JSON, nullable=True, doc="Detailed prompt token breakdown (e.g., cached_tokens, cache_read_tokens, cache_creation_tokens)." + ) stop_reason: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.") tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.") tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.") diff --git a/letta/schemas/openai/chat_completion_response.py b/letta/schemas/openai/chat_completion_response.py index 547f4b6c..81db7dad 100644 --- a/letta/schemas/openai/chat_completion_response.py +++ b/letta/schemas/openai/chat_completion_response.py @@ -78,13 +78,17 @@ class Choice(BaseModel): class UsageStatisticsPromptTokenDetails(BaseModel): - cached_tokens: int = 0 + cached_tokens: int = 0 # OpenAI/Gemini: tokens served from cache + cache_read_tokens: int = 0 # Anthropic: tokens read from cache + cache_creation_tokens: int = 0 # Anthropic: tokens written to cache # NOTE: OAI specific # audio_tokens: int = 0 def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails": return UsageStatisticsPromptTokenDetails( cached_tokens=self.cached_tokens + other.cached_tokens, + cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens, + cache_creation_tokens=self.cache_creation_tokens + other.cache_creation_tokens, ) diff --git a/letta/schemas/step.py b/letta/schemas/step.py index 38eb8cde..a1609b5e 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -30,7 +30,10 @@ class Step(StepBase): completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.") prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.") total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.") - completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.") + completion_tokens_details: Optional[Dict] = Field(None, description="Detailed completion token breakdown (e.g., reasoning_tokens).") + prompt_tokens_details: Optional[Dict] = Field( + None, description="Detailed prompt token breakdown (e.g., cached_tokens, cache_read_tokens, cache_creation_tokens)." + ) stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.") tags: List[str] = Field([], description="Metadata tags.") tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.") diff --git a/letta/schemas/usage.py b/letta/schemas/usage.py index 547c61da..414b4163 100644 --- a/letta/schemas/usage.py +++ b/letta/schemas/usage.py @@ -1,9 +1,94 @@ -from typing import List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from pydantic import BaseModel, Field from letta.schemas.message import Message +if TYPE_CHECKING: + from letta.schemas.openai.chat_completion_response import ( + UsageStatisticsCompletionTokenDetails, + UsageStatisticsPromptTokenDetails, + ) + + +def normalize_cache_tokens( + prompt_details: Union["UsageStatisticsPromptTokenDetails", Dict[str, Any], None], +) -> Tuple[int, int]: + """ + Extract normalized cache token counts from provider-specific prompt details. + + Handles both Pydantic model objects (from adapters) and dict objects (from database). + + Provider mappings: + - OpenAI/Gemini: cached_tokens -> cached_input_tokens + - Anthropic: cache_read_tokens -> cached_input_tokens, cache_creation_tokens -> cache_write_tokens + + Args: + prompt_details: Provider-specific prompt token details (model or dict) + + Returns: + Tuple of (cached_input_tokens, cache_write_tokens) + """ + if prompt_details is None: + return 0, 0 + + # Handle dict (from database storage) + if isinstance(prompt_details, dict): + cached_input = 0 + if prompt_details.get("cached_tokens"): + cached_input = prompt_details.get("cached_tokens", 0) + elif prompt_details.get("cache_read_tokens"): + cached_input = prompt_details.get("cache_read_tokens", 0) + + cache_write = prompt_details.get("cache_creation_tokens", 0) or 0 + return cached_input, cache_write + + # Handle Pydantic model (from adapters) + cached_input = 0 + if hasattr(prompt_details, "cached_tokens") and prompt_details.cached_tokens: + cached_input = prompt_details.cached_tokens + elif hasattr(prompt_details, "cache_read_tokens") and prompt_details.cache_read_tokens: + cached_input = prompt_details.cache_read_tokens + + cache_write = 0 + if hasattr(prompt_details, "cache_creation_tokens") and prompt_details.cache_creation_tokens: + cache_write = prompt_details.cache_creation_tokens + + return cached_input, cache_write + + +def normalize_reasoning_tokens( + completion_details: Union["UsageStatisticsCompletionTokenDetails", Dict[str, Any], None], +) -> int: + """ + Extract normalized reasoning token count from provider-specific completion details. + + Handles both Pydantic model objects (from adapters) and dict objects (from database). + + Provider mappings: + - OpenAI: completion_tokens_details.reasoning_tokens + - Gemini: thoughts_token_count (mapped to reasoning_tokens in UsageStatistics) + - Anthropic: thinking tokens are included in completion_tokens, not separately tracked + + Args: + completion_details: Provider-specific completion token details (model or dict) + + Returns: + The reasoning token count + """ + if completion_details is None: + return 0 + + # Handle dict (from database storage) + if isinstance(completion_details, dict): + return completion_details.get("reasoning_tokens", 0) or 0 + + # Handle Pydantic model (from adapters) + if hasattr(completion_details, "reasoning_tokens") and completion_details.reasoning_tokens: + return completion_details.reasoning_tokens + + return 0 + class LettaUsageStatistics(BaseModel): """ @@ -14,6 +99,9 @@ class LettaUsageStatistics(BaseModel): prompt_tokens (int): The number of tokens in the prompt. total_tokens (int): The total number of tokens processed by the agent. step_count (int): The number of steps taken by the agent. + cached_input_tokens (int): The number of input tokens served from cache. + cache_write_tokens (int): The number of input tokens written to cache (Anthropic only). + reasoning_tokens (int): The number of reasoning/thinking tokens generated. """ message_type: Literal["usage_statistics"] = "usage_statistics" @@ -23,3 +111,10 @@ class LettaUsageStatistics(BaseModel): step_count: int = Field(0, description="The number of steps taken by the agent.") # TODO: Optional for now. This field makes everyone's lives easier run_ids: Optional[List[str]] = Field(None, description="The background task run IDs associated with the agent interaction") + + # Cache tracking (common across providers) + cached_input_tokens: int = Field(0, description="The number of input tokens served from cache.") + cache_write_tokens: int = Field(0, description="The number of input tokens written to cache (Anthropic only).") + + # Reasoning token tracking + reasoning_tokens: int = Field(0, description="The number of reasoning/thinking tokens generated.") diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index da2e67f7..aef42076 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -23,7 +23,7 @@ from letta.schemas.message import Message as PydanticMessage from letta.schemas.run import Run as PydanticRun, RunUpdate from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics from letta.schemas.step import Step as PydanticStep -from letta.schemas.usage import LettaUsageStatistics +from letta.schemas.usage import LettaUsageStatistics, normalize_cache_tokens, normalize_reasoning_tokens from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.agent_manager import AgentManager @@ -469,6 +469,13 @@ class RunManager: total_usage.completion_tokens += step.completion_tokens total_usage.total_tokens += step.total_tokens total_usage.step_count += 1 + + # Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers + cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details) + total_usage.cached_input_tokens += cached_input + total_usage.cache_write_tokens += cache_write + total_usage.reasoning_tokens += normalize_reasoning_tokens(step.completion_tokens_details) + return total_usage @enforce_types diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 2b9a182c..eb6a37db 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -408,6 +408,12 @@ class StepManager: if stop_reason: step.stop_reason = stop_reason.stop_reason + # Persist detailed token breakdowns if available + if usage.prompt_tokens_details: + step.prompt_tokens_details = usage.prompt_tokens_details.model_dump() + if usage.completion_tokens_details: + step.completion_tokens_details = usage.completion_tokens_details.model_dump() + await session.commit() pydantic_step = step.to_pydantic() diff --git a/tests/integration_test_usage_tracking.py b/tests/integration_test_usage_tracking.py new file mode 100644 index 00000000..018fb312 --- /dev/null +++ b/tests/integration_test_usage_tracking.py @@ -0,0 +1,497 @@ +""" +Integration tests for advanced usage tracking (cache tokens, reasoning tokens). + +These tests verify that: +1. Cache token data (cached_input_tokens, cache_write_tokens) is captured from providers +2. Reasoning token data is captured from reasoning models +3. The data flows correctly through streaming and non-streaming paths +4. Step-level and run-level aggregation works correctly + +Provider-specific cache field mappings: +- Anthropic: cache_read_input_tokens, cache_creation_input_tokens +- OpenAI: prompt_tokens_details.cached_tokens, completion_tokens_details.reasoning_tokens +- Gemini: cached_content_token_count +""" + +import json +import logging +import os +import uuid +from typing import Any, Dict, List, Optional, Tuple + +import pytest +from dotenv import load_dotenv +from letta_client import AsyncLetta +from letta_client.types import ( + AgentState, + MessageCreateParam, +) +from letta_client.types.agents import Run +from letta_client.types.agents.letta_streaming_response import LettaUsageStatistics + +logger = logging.getLogger(__name__) + +# Load environment variables +load_dotenv() + + +# ------------------------------ +# Test Configuration +# ------------------------------ + +# Model configs for testing - these models should support caching or reasoning +CACHE_TEST_CONFIGS = [ + # Anthropic Sonnet 4.5 with prompt caching + ("anthropic/claude-sonnet-4-5-20250514", {"provider_type": "anthropic"}), + # OpenAI gpt-4o with prompt caching (Chat Completions API) + ("openai/gpt-4o", {"provider_type": "openai"}), + # Gemini 3 Pro Preview with context caching + ("google_ai/gemini-3-pro-preview", {"provider_type": "google_ai"}), +] + +REASONING_TEST_CONFIGS = [ + # Anthropic Sonnet 4.5 with thinking enabled + ( + "anthropic/claude-sonnet-4-5-20250514", + {"provider_type": "anthropic", "thinking": {"type": "enabled", "budget_tokens": 1024}}, + ), + # OpenAI gpt-5.1 reasoning model (Responses API) + ("openai/gpt-5.1", {"provider_type": "openai", "reasoning": {"reasoning_effort": "low"}}), + # Gemini 3 Pro Preview with thinking enabled + ( + "google_ai/gemini-3-pro-preview", + {"provider_type": "google_ai", "thinking_config": {"include_thoughts": True, "thinking_budget": 1024}}, + ), +] + +# Filter based on environment variable if set +requested = os.getenv("USAGE_TEST_CONFIG") +if requested: + # Filter configs to only include the requested one + CACHE_TEST_CONFIGS = [(h, s) for h, s in CACHE_TEST_CONFIGS if requested in h] + REASONING_TEST_CONFIGS = [(h, s) for h, s in REASONING_TEST_CONFIGS if requested in h] + + +def get_model_config(filename: str, model_settings_dir: str = "tests/model_settings") -> Tuple[str, dict]: + """Load a model_settings file and return the handle and settings dict.""" + filepath = os.path.join(model_settings_dir, filename) + with open(filepath, "r") as f: + config_data = json.load(f) + return config_data["handle"], config_data.get("model_settings", {}) + + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture +def base_url() -> str: + """Get the Letta server URL from environment or use default.""" + return os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + +@pytest.fixture +async def async_client(base_url: str) -> AsyncLetta: + """Create an async Letta client.""" + token = os.getenv("LETTA_SERVER_TOKEN") + return AsyncLetta(base_url=base_url, token=token) + + +# ------------------------------ +# Helper Functions +# ------------------------------ + + +async def create_test_agent( + client: AsyncLetta, + model_handle: str, + model_settings: dict, + name_suffix: str = "", +) -> AgentState: + """Create a test agent with the specified model configuration.""" + agent = await client.agents.create( + name=f"usage-test-agent-{name_suffix}-{uuid.uuid4().hex[:8]}", + model=model_handle, + model_settings=model_settings, + include_base_tools=False, # Keep it simple for usage testing + ) + return agent + + +async def cleanup_agent(client: AsyncLetta, agent_id: str) -> None: + """Delete a test agent.""" + try: + await client.agents.delete(agent_id) + except Exception as e: + logger.warning(f"Failed to cleanup agent {agent_id}: {e}") + + +def extract_usage_from_stream(messages: List[Any]) -> Optional[LettaUsageStatistics]: + """Extract LettaUsageStatistics from a stream response.""" + for msg in reversed(messages): + if isinstance(msg, LettaUsageStatistics): + return msg + return None + + +# ------------------------------ +# Cache Token Tests +# ------------------------------ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS) +async def test_cache_tokens_streaming( + async_client: AsyncLetta, + model_handle: str, + model_settings: dict, +) -> None: + """ + Test that cache token data is captured in streaming mode. + + Cache hits typically occur on the second+ request with the same context, + so we send multiple messages to trigger caching. + """ + agent = await create_test_agent(async_client, model_handle, model_settings, "cache-stream") + + try: + # First message - likely cache write (cache_creation_tokens for Anthropic) + messages1: List[Any] = [] + async for chunk in async_client.agents.messages.send_message_streaming( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Hello, this is a test message for caching.")], + ): + messages1.append(chunk) + + usage1 = extract_usage_from_stream(messages1) + assert usage1 is not None, "Should receive usage statistics in stream" + assert usage1.prompt_tokens > 0, "Should have prompt tokens" + + # Log first call usage for debugging + logger.info( + f"First call usage ({model_handle}): prompt={usage1.prompt_tokens}, " + f"cached_input={usage1.cached_input_tokens}, cache_write={usage1.cache_write_tokens}" + ) + + # Second message - same agent/context should trigger cache hits + messages2: List[Any] = [] + async for chunk in async_client.agents.messages.send_message_streaming( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="This is a follow-up message.")], + ): + messages2.append(chunk) + + usage2 = extract_usage_from_stream(messages2) + assert usage2 is not None, "Should receive usage statistics in stream" + + # Log second call usage + logger.info( + f"Second call usage ({model_handle}): prompt={usage2.prompt_tokens}, " + f"cached_input={usage2.cached_input_tokens}, cache_write={usage2.cache_write_tokens}" + ) + + # Verify cache fields exist (values may be 0 if caching not available for this model/config) + assert hasattr(usage2, "cached_input_tokens"), "Should have cached_input_tokens field" + assert hasattr(usage2, "cache_write_tokens"), "Should have cache_write_tokens field" + + # For providers with caching enabled, we expect either: + # - cache_write_tokens > 0 on first call (writing to cache) + # - cached_input_tokens > 0 on second call (reading from cache) + # Note: Not all providers always return cache data, so we just verify the fields exist + + finally: + await cleanup_agent(async_client, agent.id) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS) +async def test_cache_tokens_non_streaming( + async_client: AsyncLetta, + model_handle: str, + model_settings: dict, +) -> None: + """ + Test that cache token data is captured in non-streaming (blocking) mode. + """ + agent = await create_test_agent(async_client, model_handle, model_settings, "cache-blocking") + + try: + # First message + response1: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Hello, this is a test message for caching.")], + ) + + assert response1.usage is not None, "Should have usage in response" + logger.info( + f"First call usage ({model_handle}): prompt={response1.usage.prompt_tokens}, " + f"cached_input={response1.usage.cached_input_tokens}, cache_write={response1.usage.cache_write_tokens}" + ) + + # Second message - should trigger cache hit + response2: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="This is a follow-up message.")], + ) + + assert response2.usage is not None, "Should have usage in response" + logger.info( + f"Second call usage ({model_handle}): prompt={response2.usage.prompt_tokens}, " + f"cached_input={response2.usage.cached_input_tokens}, cache_write={response2.usage.cache_write_tokens}" + ) + + # Verify cache fields exist + assert hasattr(response2.usage, "cached_input_tokens"), "Should have cached_input_tokens field" + assert hasattr(response2.usage, "cache_write_tokens"), "Should have cache_write_tokens field" + + finally: + await cleanup_agent(async_client, agent.id) + + +# ------------------------------ +# Reasoning Token Tests +# ------------------------------ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_handle,model_settings", REASONING_TEST_CONFIGS) +async def test_reasoning_tokens_streaming( + async_client: AsyncLetta, + model_handle: str, + model_settings: dict, +) -> None: + """ + Test that reasoning token data is captured from reasoning models in streaming mode. + """ + agent = await create_test_agent(async_client, model_handle, model_settings, "reasoning-stream") + + try: + messages: List[Any] = [] + async for chunk in async_client.agents.messages.send_message_streaming( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Think step by step: what is 2 + 2? Explain your reasoning.")], + ): + messages.append(chunk) + + usage = extract_usage_from_stream(messages) + assert usage is not None, "Should receive usage statistics in stream" + + logger.info( + f"Reasoning usage ({model_handle}): prompt={usage.prompt_tokens}, " + f"completion={usage.completion_tokens}, reasoning={usage.reasoning_tokens}" + ) + + # Verify reasoning_tokens field exists + assert hasattr(usage, "reasoning_tokens"), "Should have reasoning_tokens field" + + # For reasoning models, we expect reasoning_tokens > 0 + # Note: Some providers may not always return reasoning token counts + if "gpt-5" in model_handle or "o3" in model_handle or "o1" in model_handle: + # OpenAI reasoning models should always have reasoning tokens + assert usage.reasoning_tokens > 0, f"OpenAI reasoning model {model_handle} should have reasoning_tokens > 0" + + finally: + await cleanup_agent(async_client, agent.id) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_handle,model_settings", REASONING_TEST_CONFIGS) +async def test_reasoning_tokens_non_streaming( + async_client: AsyncLetta, + model_handle: str, + model_settings: dict, +) -> None: + """ + Test that reasoning token data is captured from reasoning models in non-streaming mode. + """ + agent = await create_test_agent(async_client, model_handle, model_settings, "reasoning-blocking") + + try: + response: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Think step by step: what is 2 + 2? Explain your reasoning.")], + ) + + assert response.usage is not None, "Should have usage in response" + + logger.info( + f"Reasoning usage ({model_handle}): prompt={response.usage.prompt_tokens}, " + f"completion={response.usage.completion_tokens}, reasoning={response.usage.reasoning_tokens}" + ) + + # Verify reasoning_tokens field exists + assert hasattr(response.usage, "reasoning_tokens"), "Should have reasoning_tokens field" + + # For OpenAI reasoning models, we expect reasoning_tokens > 0 + if "gpt-5" in model_handle or "o3" in model_handle or "o1" in model_handle: + assert response.usage.reasoning_tokens > 0, f"OpenAI reasoning model {model_handle} should have reasoning_tokens > 0" + + finally: + await cleanup_agent(async_client, agent.id) + + +# ------------------------------ +# Step-Level Usage Tests +# ------------------------------ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS[:1]) # Test with one config +async def test_step_level_usage_details( + async_client: AsyncLetta, + model_handle: str, + model_settings: dict, +) -> None: + """ + Test that step-level usage details (prompt_tokens_details, completion_tokens_details) + are properly persisted and retrievable. + """ + agent = await create_test_agent(async_client, model_handle, model_settings, "step-details") + + try: + # Send a message to create a step + response: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Hello!")], + ) + + # Get the run's steps + steps = await async_client.runs.list_steps(run_id=response.id) + + assert len(steps) > 0, "Should have at least one step" + + step = steps[0] + logger.info( + f"Step usage ({model_handle}): prompt_tokens={step.prompt_tokens}, " + f"prompt_tokens_details={step.prompt_tokens_details}, " + f"completion_tokens_details={step.completion_tokens_details}" + ) + + # Verify the step has the usage fields + assert step.prompt_tokens > 0, "Step should have prompt_tokens" + assert step.completion_tokens >= 0, "Step should have completion_tokens" + assert step.total_tokens > 0, "Step should have total_tokens" + + # The details fields may be None if no cache/reasoning was involved, + # but they should be present in the schema + # Note: This test mainly verifies the field exists and can hold data + + finally: + await cleanup_agent(async_client, agent.id) + + +# ------------------------------ +# Run-Level Aggregation Tests +# ------------------------------ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_handle,model_settings", CACHE_TEST_CONFIGS[:1]) # Test with one config +async def test_run_level_usage_aggregation( + async_client: AsyncLetta, + model_handle: str, + model_settings: dict, +) -> None: + """ + Test that run-level usage correctly aggregates cache/reasoning tokens from steps. + """ + agent = await create_test_agent(async_client, model_handle, model_settings, "run-aggregation") + + try: + # Send multiple messages to create multiple steps + response1: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Message 1")], + ) + + response2: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Message 2")], + ) + + # Get run usage for the second run (which should have accumulated context) + run_usage = await async_client.runs.get_run_usage(run_id=response2.id) + + logger.info( + f"Run usage ({model_handle}): prompt={run_usage.prompt_tokens}, " + f"completion={run_usage.completion_tokens}, total={run_usage.total_tokens}, " + f"cached_input={run_usage.cached_input_tokens}, cache_write={run_usage.cache_write_tokens}, " + f"reasoning={run_usage.reasoning_tokens}" + ) + + # Verify the run usage has all the expected fields + assert run_usage.prompt_tokens >= 0, "Run should have prompt_tokens" + assert run_usage.completion_tokens >= 0, "Run should have completion_tokens" + assert run_usage.total_tokens >= 0, "Run should have total_tokens" + assert hasattr(run_usage, "cached_input_tokens"), "Run should have cached_input_tokens" + assert hasattr(run_usage, "cache_write_tokens"), "Run should have cache_write_tokens" + assert hasattr(run_usage, "reasoning_tokens"), "Run should have reasoning_tokens" + + finally: + await cleanup_agent(async_client, agent.id) + + +# ------------------------------ +# Comprehensive End-to-End Test +# ------------------------------ + + +@pytest.mark.asyncio +async def test_usage_tracking_end_to_end(async_client: AsyncLetta) -> None: + """ + End-to-end test that verifies the complete usage tracking flow: + 1. Create agent with a model that supports caching + 2. Send messages to trigger cache writes and reads + 3. Verify step-level details are persisted + 4. Verify run-level aggregation is correct + """ + # Use Anthropic Sonnet 4.5 for this test as it has the most comprehensive caching + model_handle = "anthropic/claude-sonnet-4-5-20250514" + model_settings = {"provider_type": "anthropic"} + + agent = await create_test_agent(async_client, model_handle, model_settings, "e2e") + + try: + # Send first message (should trigger cache write) + response1: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="This is a longer message to ensure there's enough content to cache. " * 5)], + ) + + logger.info(f"E2E Test - First message usage: {response1.usage}") + + # Send second message (should trigger cache read) + response2: Run = await async_client.agents.messages.send_message( + agent_id=agent.id, + messages=[MessageCreateParam(role="user", content="Short follow-up")], + ) + + logger.info(f"E2E Test - Second message usage: {response2.usage}") + + # Verify basic usage is tracked + assert response1.usage is not None + assert response2.usage is not None + assert response1.usage.prompt_tokens > 0 + assert response2.usage.prompt_tokens > 0 + + # Get steps for the second run + steps = await async_client.runs.list_steps(run_id=response2.id) + assert len(steps) > 0, "Should have steps for the run" + + # Get run-level usage + run_usage = await async_client.runs.get_run_usage(run_id=response2.id) + assert run_usage.total_tokens > 0, "Run should have total tokens" + + logger.info( + f"E2E Test - Run usage: prompt={run_usage.prompt_tokens}, " + f"completion={run_usage.completion_tokens}, " + f"cached_input={run_usage.cached_input_tokens}, " + f"cache_write={run_usage.cache_write_tokens}" + ) + + # The test passes if we get here without errors - cache data may or may not be present + # depending on whether the provider actually cached the content + + finally: + await cleanup_agent(async_client, agent.id)