Files
letta-server/letta/schemas/usage.py
Charles Packer 88a3743cc8 fix(core): distinguish between null and 0 for prompt caching (#6451)
* fix(core): distinguish between null and 0 for prompt caching

* fix: runtime errors

* fix: just publish just sgate
2025-12-15 12:02:19 -08:00

129 lines
5.2 KiB
Python

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field
from letta.schemas.message import Message
if TYPE_CHECKING:
from letta.schemas.openai.chat_completion_response import (
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
def normalize_cache_tokens(
    prompt_details: Union["UsageStatisticsPromptTokenDetails", Dict[str, Any], None],
) -> Tuple[int, int]:
    """
    Collapse provider-specific prompt-cache counters into one common pair.

    The input may be a plain dict (as loaded from database storage) or an
    attribute-bearing object (as produced by the provider adapters).

    Field mapping:
        - ``cached_tokens`` (OpenAI/Gemini) -> cached input tokens
        - ``cache_read_tokens`` (Anthropic) -> cached input tokens, consulted
          only when ``cached_tokens`` is missing or falsy
        - ``cache_creation_tokens`` (Anthropic) -> cache write tokens

    Args:
        prompt_details: Provider-specific prompt token details, or None.

    Returns:
        ``(cached_input_tokens, cache_write_tokens)``. A counter that is
        absent, None, or 0 is reported as 0.
    """
    if prompt_details is None:
        return 0, 0

    # Choose a single accessor up front so both input shapes share one code path.
    if isinstance(prompt_details, dict):
        read = prompt_details.get
    else:

        def read(field_name):
            return getattr(prompt_details, field_name, None)

    # First truthy counter wins; a missing/None/0 value falls through to the next.
    cache_hits = read("cached_tokens") or read("cache_read_tokens") or 0
    cache_writes = read("cache_creation_tokens") or 0
    return cache_hits, cache_writes
def normalize_reasoning_tokens(
    completion_details: Union["UsageStatisticsCompletionTokenDetails", Dict[str, Any], None],
) -> int:
    """
    Pull the reasoning-token count out of provider-specific completion details.

    The input may be a plain dict (as loaded from database storage) or an
    attribute-bearing object (as produced by the provider adapters). In both
    shapes the value is looked up under the name ``reasoning_tokens``.

    Provider notes:
        - OpenAI: completion_tokens_details.reasoning_tokens
        - Gemini: thoughts_token_count (mapped to reasoning_tokens in UsageStatistics)
        - Anthropic: thinking tokens are included in completion_tokens, not separately tracked

    Args:
        completion_details: Provider-specific completion token details, or None.

    Returns:
        The reasoning token count; 0 when the input is None or the value is
        absent, None, or 0.
    """
    if completion_details is None:
        return 0

    if isinstance(completion_details, dict):
        reported = completion_details.get("reasoning_tokens")
    else:
        reported = getattr(completion_details, "reasoning_tokens", None)

    # Any falsy value (missing, None, 0) is reported as a plain 0.
    return reported or 0
class LettaUsageStatistics(BaseModel):
    """
    Usage statistics for the agent interaction.
    Attributes:
        completion_tokens (int): The number of tokens generated by the agent.
        prompt_tokens (int): The number of tokens in the prompt.
        total_tokens (int): The total number of tokens processed by the agent.
        step_count (int): The number of steps taken by the agent.
        cached_input_tokens (Optional[int]): The number of input tokens served from cache. None if not reported.
        cache_write_tokens (Optional[int]): The number of input tokens written to cache. None if not reported.
        reasoning_tokens (Optional[int]): The number of reasoning/thinking tokens generated. None if not reported.
    """

    # Fixed discriminator value identifying this payload type.
    message_type: Literal["usage_statistics"] = "usage_statistics"

    # Core counters; each defaults to 0.
    completion_tokens: int = Field(default=0, description="The number of tokens generated by the agent.")
    prompt_tokens: int = Field(default=0, description="The number of tokens in the prompt.")
    total_tokens: int = Field(default=0, description="The total number of tokens processed by the agent.")
    step_count: int = Field(default=0, description="The number of steps taken by the agent.")

    # TODO: Optional for now. This field makes everyone's lives easier
    run_ids: Optional[List[str]] = Field(
        default=None,
        description="The background task run IDs associated with the agent interaction",
    )

    # Prompt-cache counters shared across providers.
    # None = the provider did not report the value; 0 = the provider reported zero.
    cached_input_tokens: Optional[int] = Field(
        default=None,
        description="The number of input tokens served from cache. None if not reported by provider.",
    )
    cache_write_tokens: Optional[int] = Field(
        default=None,
        description="The number of input tokens written to cache (Anthropic only). None if not reported by provider.",
    )

    # Reasoning/thinking token counter, with the same None-vs-0 distinction as above.
    reasoning_tokens: Optional[int] = Field(
        default=None,
        description="The number of reasoning/thinking tokens generated. None if not reported by provider.",
    )