refactor: add extract_usage_statistics returning LettaUsageStatistics (#9065)
👾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
committed by
Caren Thomas
parent
2bccd36382
commit
221b4e6279
@@ -146,6 +146,26 @@ class SimpleAnthropicStreamingInterface:
|
||||
return tool_calls[0]
|
||||
return None
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Build a LettaUsageStatistics from the token counters accumulated
    while consuming the stream.

    Returns:
        LettaUsageStatistics with prompt, completion, total, and cache
        token counts for this stream.
    """
    from letta.schemas.usage import LettaUsageStatistics

    # Anthropic reports input_tokens as the NON-cached portion only, so
    # the true prompt total must also include cache read + write tokens.
    cache_read = self.cache_read_tokens or 0
    cache_write = self.cache_creation_tokens or 0
    prompt_total = (self.input_tokens or 0) + cache_read + cache_write
    completion_total = self.output_tokens or 0

    return LettaUsageStatistics(
        prompt_tokens=prompt_total,
        completion_tokens=completion_total,
        total_tokens=prompt_total + completion_total,
        # Cache counters are reported as None (not 0) when absent.
        cached_input_tokens=self.cache_read_tokens or None,
        cache_write_tokens=self.cache_creation_tokens or None,
        reasoning_tokens=None,  # Anthropic doesn't report reasoning tokens separately
    )
|
||||
|
||||
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
|
||||
def _process_group(
|
||||
group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage],
|
||||
|
||||
@@ -128,6 +128,25 @@ class AnthropicStreamingInterface:
|
||||
arguments = str(json.dumps(tool_input, indent=2))
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Summarize the token usage accumulated from the stream.

    Returns:
        LettaUsageStatistics with the raw prompt/completion token counts.
    """
    from letta.schemas.usage import LettaUsageStatistics

    # Anthropic streaming reports input_tokens as non-cached tokens only,
    # but this interface does not track cache tokens, so the raw counters
    # are used as-is.
    prompt = self.input_tokens or 0
    completion = self.output_tokens or 0

    return LettaUsageStatistics(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=prompt + completion,
        cached_input_tokens=None,  # cache tokens not tracked here
        cache_write_tokens=None,
        reasoning_tokens=None,
    )
|
||||
|
||||
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
|
||||
"""
|
||||
Check if inner thoughts are complete in the current tool call arguments
|
||||
@@ -637,6 +656,25 @@ class SimpleAnthropicStreamingInterface:
|
||||
arguments = str(json.dumps(tool_input, indent=2))
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Return the usage statistics gathered while streaming.

    Returns:
        LettaUsageStatistics built from the accumulated token counters.
    """
    from letta.schemas.usage import LettaUsageStatistics

    # NOTE: Anthropic's streamed input_tokens excludes cached tokens;
    # cache counters are not tracked by this interface, so raw values
    # are reported directly.
    in_tok, out_tok = self.input_tokens or 0, self.output_tokens or 0

    return LettaUsageStatistics(
        prompt_tokens=in_tok,
        completion_tokens=out_tok,
        total_tokens=in_tok + out_tok,
        cached_input_tokens=None,  # not tracked by this interface
        cache_write_tokens=None,
        reasoning_tokens=None,
    )
|
||||
|
||||
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
|
||||
def _process_group(
|
||||
group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage],
|
||||
|
||||
@@ -122,6 +122,27 @@ class SimpleGeminiStreamingInterface:
|
||||
"""Return all finalized tool calls collected during this message (parallel supported)."""
|
||||
return list(self.collected_tool_calls)
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Convert the accumulated Gemini stream counters into usage stats.

    Returns:
        LettaUsageStatistics with token counts from the stream.

    Note:
        Gemini reports `thinking_tokens` where OpenAI o1/o3 would report
        `reasoning_tokens`; the value is mapped accordingly.
    """
    from letta.schemas.usage import LettaUsageStatistics

    prompt = self.input_tokens or 0
    completion = self.output_tokens or 0

    return LettaUsageStatistics(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=prompt + completion,
        # Gemini's input_tokens already includes cached tokens;
        # cached_tokens is a subset, not an additional amount.
        cached_input_tokens=self.cached_tokens,
        cache_write_tokens=None,  # Gemini doesn't report cache write tokens
        reasoning_tokens=self.thinking_tokens,  # Gemini uses thinking_tokens
    )
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncIterator[GenerateContentResponse],
|
||||
|
||||
@@ -194,6 +194,28 @@ class OpenAIStreamingInterface:
|
||||
function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
|
||||
)
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Assemble usage statistics from the streamed token counters.

    Returns:
        LettaUsageStatistics with prompt/completion/total token counts;
        falls back to the estimated counters when the stream did not
        report actual usage.
    """
    from letta.schemas.usage import LettaUsageStatistics

    # Prefer the provider-reported counts; fall back to estimates when
    # they are missing (or zero).
    prompt = (self.input_tokens or self.fallback_input_tokens) or 0
    completion = (self.output_tokens or self.fallback_output_tokens) or 0

    return LettaUsageStatistics(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=prompt + completion,
        cached_input_tokens=None,  # cache tokens not tracked by this interface
        cache_write_tokens=None,
        reasoning_tokens=None,  # reasoning tokens not tracked by this interface
    )
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[ChatCompletionChunk],
|
||||
@@ -672,6 +694,28 @@ class SimpleOpenAIStreamingInterface:
|
||||
raise ValueError("No tool calls available")
|
||||
return calls[0]
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Produce usage statistics from the counters gathered during streaming.

    Returns:
        LettaUsageStatistics with token counts; estimated fallback values
        are used when the stream did not report actual usage.
    """
    from letta.schemas.usage import LettaUsageStatistics

    # Use actual token counts when present, otherwise the estimates.
    in_tok = (self.input_tokens or self.fallback_input_tokens) or 0
    out_tok = (self.output_tokens or self.fallback_output_tokens) or 0

    return LettaUsageStatistics(
        prompt_tokens=in_tok,
        completion_tokens=out_tok,
        total_tokens=in_tok + out_tok,
        # OpenAI's input_tokens is already the total; cached_tokens is a
        # subset of it, not an additional amount.
        cached_input_tokens=self.cached_tokens,
        cache_write_tokens=None,  # OpenAI doesn't have cache write tokens
        reasoning_tokens=self.reasoning_tokens,
    )
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[ChatCompletionChunk],
|
||||
@@ -1080,6 +1124,24 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
raise ValueError("No tool calls available")
|
||||
return calls[0]
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
    """Return usage statistics accumulated from the Responses API stream.

    Returns:
        LettaUsageStatistics populated with the streamed token counts.
    """
    from letta.schemas.usage import LettaUsageStatistics

    prompt_tokens = self.input_tokens or 0
    completion_tokens = self.output_tokens or 0

    return LettaUsageStatistics(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        # OpenAI Responses API: input_tokens already includes cached tokens.
        cached_input_tokens=self.cached_tokens,
        cache_write_tokens=None,  # OpenAI doesn't have cache write tokens
        reasoning_tokens=self.reasoning_tokens,
    )
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[ResponseStreamEvent],
|
||||
|
||||
Reference in New Issue
Block a user