refactor: add extract_usage_statistics returning LettaUsageStatistics (#9065)
👾 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>

Committed by Caren Thomas
parent 2bccd36382
commit 221b4e6279
@@ -60,6 +60,7 @@ from letta.schemas.openai.chat_completion_response import (
 )
 from letta.schemas.openai.responses_request import ResponsesRequest
 from letta.schemas.response_format import JsonSchemaResponseFormat
+from letta.schemas.usage import LettaUsageStatistics
 from letta.settings import model_settings
 
 logger = get_logger(__name__)
@@ -591,6 +592,66 @@ class OpenAIClient(LLMClientBase):
     def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
         return is_openai_reasoning_model(llm_config.model)
 
+    def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
+        """Extract usage statistics from OpenAI response and return as LettaUsageStatistics."""
+        if not response_data:
+            return LettaUsageStatistics()
+
+        # Handle Responses API format (used by reasoning models like o1/o3)
+        if response_data.get("object") == "response":
+            usage = response_data.get("usage", {}) or {}
+            prompt_tokens = usage.get("input_tokens") or 0
+            completion_tokens = usage.get("output_tokens") or 0
+            total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)
+
+            input_details = usage.get("input_tokens_details", {}) or {}
+            cached_tokens = input_details.get("cached_tokens")
+
+            output_details = usage.get("output_tokens_details", {}) or {}
+            reasoning_tokens = output_details.get("reasoning_tokens")
+
+            return LettaUsageStatistics(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+                cached_input_tokens=cached_tokens,
+                reasoning_tokens=reasoning_tokens,
+            )
+
+        # Handle standard Chat Completions API format using pydantic models
+        from openai.types.chat import ChatCompletion
+
+        try:
+            completion = ChatCompletion.model_validate(response_data)
+        except Exception:
+            return LettaUsageStatistics()
+
+        if not completion.usage:
+            return LettaUsageStatistics()
+
+        usage = completion.usage
+        prompt_tokens = usage.prompt_tokens or 0
+        completion_tokens = usage.completion_tokens or 0
+        total_tokens = usage.total_tokens or (prompt_tokens + completion_tokens)
+
+        # Extract cached tokens from prompt_tokens_details
+        cached_tokens = None
+        if usage.prompt_tokens_details:
+            cached_tokens = usage.prompt_tokens_details.cached_tokens
+
+        # Extract reasoning tokens from completion_tokens_details
+        reasoning_tokens = None
+        if usage.completion_tokens_details:
+            reasoning_tokens = usage.completion_tokens_details.reasoning_tokens
+
+        return LettaUsageStatistics(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            cached_input_tokens=cached_tokens,
+            reasoning_tokens=reasoning_tokens,
+        )
+
     @trace_method
     async def convert_response_to_chat_completion(
         self,
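Note (not part of the diff): a rough sketch of what the new helper returns for the two response shapes it handles. `client` stands in for an already constructed `OpenAIClient` and `llm_config` for the agent's `LLMConfig`; the payloads and token counts are made up for illustration.

```python
# Responses API shape (object == "response"), as used by reasoning models
responses_payload = {
    "object": "response",
    "usage": {
        "input_tokens": 1200,
        "output_tokens": 300,
        "total_tokens": 1500,
        "input_tokens_details": {"cached_tokens": 1024},
        "output_tokens_details": {"reasoning_tokens": 256},
    },
}
stats = client.extract_usage_statistics(responses_payload, llm_config)
# -> prompt_tokens=1200, completion_tokens=300, total_tokens=1500,
#    cached_input_tokens=1024, reasoning_tokens=256

# Chat Completions shape: the dict must validate as openai.types.chat.ChatCompletion,
# otherwise the helper falls back to an empty LettaUsageStatistics()
chat_payload = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1730000000,
    "model": "gpt-4o-mini",
    "choices": [
        {"index": 0, "finish_reason": "stop", "message": {"role": "assistant", "content": "Hi!"}}
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15},
}
stats = client.extract_usage_statistics(chat_payload, llm_config)
# -> prompt_tokens=12, completion_tokens=3, total_tokens=15,
#    cached_input_tokens=None, reasoning_tokens=None (no details in the payload)
```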
@@ -607,30 +668,10 @@ class OpenAIClient(LLMClientBase):
         # See example payload in tests/integration_test_send_message_v2.py
         model = response_data.get("model")
 
-        # Extract usage
-        usage = response_data.get("usage", {}) or {}
-        prompt_tokens = usage.get("input_tokens") or 0
-        completion_tokens = usage.get("output_tokens") or 0
-        total_tokens = usage.get("total_tokens") or (prompt_tokens + completion_tokens)
+        # Extract usage via centralized method
+        from letta.schemas.enums import ProviderType
 
-        # Extract detailed token breakdowns (Responses API uses input_tokens_details/output_tokens_details)
-        prompt_tokens_details = None
-        input_details = usage.get("input_tokens_details", {}) or {}
-        if input_details.get("cached_tokens"):
-            from letta.schemas.openai.chat_completion_response import UsageStatisticsPromptTokenDetails
-
-            prompt_tokens_details = UsageStatisticsPromptTokenDetails(
-                cached_tokens=input_details.get("cached_tokens") or 0,
-            )
-
-        completion_tokens_details = None
-        output_details = usage.get("output_tokens_details", {}) or {}
-        if output_details.get("reasoning_tokens"):
-            from letta.schemas.openai.chat_completion_response import UsageStatisticsCompletionTokenDetails
-
-            completion_tokens_details = UsageStatisticsCompletionTokenDetails(
-                reasoning_tokens=output_details.get("reasoning_tokens") or 0,
-            )
+        usage_stats = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.openai)
 
         # Extract assistant message text from the outputs list
         outputs = response_data.get("output") or []
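Note (not part of the diff): the call site now reduces to a two-step flow before the result is placed on the `ChatCompletionResponse` in the next hunk. A minimal sketch; the exact field mapping performed by `LettaUsageStatistics.to_usage` is not shown in this commit and is assumed here.

```python
from letta.schemas.enums import ProviderType

# 1. Provider-agnostic counters extracted from the raw response
letta_stats = self.extract_usage_statistics(response_data, llm_config)

# 2. Converted to the OpenAI-style UsageStatistics expected by ChatCompletionResponse
#    (mapping assumed to also carry the cached/reasoning token details)
usage_stats = letta_stats.to_usage(ProviderType.openai)
```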
@@ -698,13 +739,7 @@ class OpenAIClient(LLMClientBase):
             choices=[choice],
             created=int(response_data.get("created_at") or 0),
             model=model or (llm_config.model if hasattr(llm_config, "model") else None),
-            usage=UsageStatistics(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                prompt_tokens_details=prompt_tokens_details,
-                completion_tokens_details=completion_tokens_details,
-            ),
+            usage=usage_stats,
         )
 
         return chat_completion_response
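Note (not part of the diff): a possible regression-test sketch for the helper's fallback paths. The fixture names `openai_client` and `llm_config` are assumed; build them however the existing suite does, and the whole-object comparison assumes default-constructed `LettaUsageStatistics` instances compare equal.

```python
from letta.schemas.usage import LettaUsageStatistics


def test_extract_usage_statistics_fallbacks(openai_client, llm_config):
    # No response data at all -> empty statistics
    assert openai_client.extract_usage_statistics(None, llm_config) == LettaUsageStatistics()

    # Payload that is neither a Responses API object nor a valid ChatCompletion
    # -> ChatCompletion validation fails and the helper returns empty statistics
    bogus = {"object": "not.a.completion"}
    assert openai_client.extract_usage_statistics(bogus, llm_config) == LettaUsageStatistics()
```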