Files
letta-server/letta/adapters/simple_llm_stream_adapter.py
cthomas 416ffc7cd7 Add billing context to LLM telemetry traces (#9745)
* feat: add billing context to LLM telemetry traces

Add billing metadata (plan type, cost source, customer ID) to LLM traces in ClickHouse for cost analytics and attribution.

**Data Flow:**
- Cloud-API: Extract billing info from the subscription during rate limiting and set x-billing-* headers
- Core: Parse the headers into a BillingContext object via dependencies (see the sketch after this list)
- Adapters: Flow billing_context through all LLM adapters (blocking & streaming)
- Agent: Pass billing_context to step() and stream() methods
- ClickHouse: Store in billing_plan_type, billing_cost_source, billing_customer_id columns
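A minimal sketch of the core-side half of this flow, for orientation only. The header names, field names, and dependency signature below are illustrative assumptions; what the change actually specifies is that x-billing-* headers set by cloud-api are parsed into a `BillingContext` object by a dependency and kept out of the public OpenAPI schema.

```python
from fastapi import Header
from pydantic import BaseModel


class BillingContext(BaseModel):
    # Field names guessed from the ClickHouse columns listed above
    plan_type: str | None = None
    cost_source: str | None = None
    customer_id: str | None = None


async def get_billing_context(
    # include_in_schema=False keeps these internal headers out of the public OpenAPI spec
    x_billing_plan_type: str | None = Header(None, include_in_schema=False),
    x_billing_cost_source: str | None = Header(None, include_in_schema=False),
    x_billing_customer_id: str | None = Header(None, include_in_schema=False),
) -> BillingContext | None:
    if not any([x_billing_plan_type, x_billing_cost_source, x_billing_customer_id]):
        return None
    return BillingContext(
        plan_type=x_billing_plan_type,
        cost_source=x_billing_cost_source,
        customer_id=x_billing_customer_id,
    )
```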

**Changes:**
- Add BillingContext schema to provider_trace.py
- Add billing columns to llm_traces ClickHouse table DDL
- Update getCustomerSubscription to fetch stripeCustomerId from organization_billing_details
- Propagate billing_context through agent step flow, adapters, and streaming service
- Update ProviderTrace and LLMTrace to include billing metadata
- Regenerate SDK with autogen

**Production Deployment:**
Requires env vars: LETTA_PROVIDER_TRACE_BACKEND=clickhouse, LETTA_STORE_LLM_TRACES=true, CLICKHOUSE_*
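As a deployment aid, a hypothetical startup check like the one below illustrates the expected values; the actual settings plumbing inside core may differ.

```python
import os

# Hypothetical pre-flight check (illustration only): verify the env vars named above.
required = {
    "LETTA_PROVIDER_TRACE_BACKEND": "clickhouse",
    "LETTA_STORE_LLM_TRACES": "true",
}
missing = {k: v for k, v in required.items() if os.environ.get(k) != v}
if missing:
    raise RuntimeError(f"Provider trace storage misconfigured, expected: {missing}")
# The CLICKHOUSE_* connection settings (host, credentials, database) must also be present.
```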

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix: add billing_context parameter to agent step methods

- Add billing_context to BaseAgent and BaseAgentV2 abstract methods
- Update LettaAgent, LettaAgentV2, LettaAgentV3 step methods
- Update multi-agent groups: SleeptimeMultiAgentV2, V3, V4
- Fix test_utils.py to include billing header parameters
- Import BillingContext in all affected files

* fix: add billing_context to stream methods

- Add billing_context parameter to BaseAgentV2.stream()
- Add billing_context parameter to LettaAgentV2.stream()
- LettaAgentV3.stream() already has it from the previous commit

* fix: exclude billing headers from OpenAPI spec

Mark billing headers as internal (include_in_schema=False) so they don't appear in the public API.
These are internal headers between cloud-api and core, not part of the public SDK.

Regenerated the SDK with stage-api; this removes 10,650 lines of bloat that were causing OOM during the Next.js build.

* refactor: return billing context from handleUnifiedRateLimiting instead of mutating req

Instead of passing req into handleUnifiedRateLimiting and mutating headers inside it:
- Return billing context fields (billingPlanType, billingCostSource, billingCustomerId) from handleUnifiedRateLimiting
- Set headers in handleMessageRateLimiting (middleware layer) after getting the result
- This fixes step-orchestrator compatibility, since the step-orchestrator doesn't have a real Express req object (a sketch of the pattern follows below)
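The pattern is illustrated below in Python for consistency with the rest of this page; the real change is in the TypeScript cloud-api, and the function and subscription field names are simplified assumptions. The point is that the rate-limiting helper stays a pure function returning billing fields, and only the middleware layer, which actually owns an HTTP request, writes them into headers.

```python
from dataclasses import dataclass


@dataclass
class BillingFields:
    billing_plan_type: str | None = None
    billing_cost_source: str | None = None
    billing_customer_id: str | None = None


def handle_unified_rate_limiting(subscription) -> BillingFields:
    # No request object is passed in or mutated, so callers without a real
    # Express req (e.g. the step-orchestrator) can call this directly.
    return BillingFields(
        billing_plan_type=getattr(subscription, "plan_type", None),
        billing_cost_source=getattr(subscription, "cost_source", None),
        billing_customer_id=getattr(subscription, "stripe_customer_id", None),
    )


def handle_message_rate_limiting(req, subscription):
    # Middleware layer: the only place that touches headers.
    fields = handle_unified_rate_limiting(subscription)
    if fields.billing_plan_type:
        req.headers["x-billing-plan-type"] = fields.billing_plan_type
    if fields.billing_cost_source:
        req.headers["x-billing-cost-source"] = fields.billing_cost_source
    if fields.billing_customer_id:
        req.headers["x-billing-customer-id"] = fields.billing_customer_id
```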

* chore: remove extra gencode

* p

---------

Co-authored-by: Letta <noreply@letta.com>
2026-03-03 18:34:13 -08:00

286 lines
13 KiB
Python

from typing import AsyncGenerator, List

from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.errors import LLMError
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.interfaces.anthropic_parallel_tool_call_streaming_interface import SimpleAnthropicStreamingInterface
from letta.interfaces.gemini_streaming_interface import SimpleGeminiStreamingInterface
from letta.interfaces.openai_streaming_interface import SimpleOpenAIResponsesStreamingInterface, SimpleOpenAIStreamingInterface
from letta.log import get_logger
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
from letta.schemas.enums import ProviderType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import LettaMessageContentUnion
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.user import User
from letta.server.rest_api.streaming_response import get_cancellation_event_for_run
from letta.settings import settings
from letta.utils import safe_create_task

logger = get_logger(__name__)


class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
    """
    Adapter for handling streaming LLM requests with immediate token yielding.

    This adapter supports real-time streaming of tokens from the LLM, providing
    minimal time-to-first-token (TTFT) latency. It uses specialized streaming
    interfaces for different providers (OpenAI, Anthropic) to handle their
    specific streaming formats.
    """

    def _extract_tool_calls(self) -> list:
        """Extract tool calls from the interface, trying the parallel API first and then the single-call API."""
        # Try the multi-call API if available
        if hasattr(self.interface, "get_tool_call_objects"):
            try:
                calls = self.interface.get_tool_call_objects()
                if calls:
                    return calls
            except Exception:
                pass
        # Fall back to the single-call API
        try:
            single = self.interface.get_tool_call_object()
            return [single] if single else []
        except Exception:
            return []

    async def invoke_llm(
        self,
        request_data: dict,
        messages: list,
        tools: list,
        use_assistant_message: bool,  # NOTE: not used
        requires_approval_tools: list[str] = [],
        step_id: str | None = None,
        actor: User | None = None,
    ) -> AsyncGenerator[LettaMessage, None]:
        """
        Execute a streaming LLM request and yield tokens/chunks as they arrive.

        This adapter:
        1. Makes a streaming request to the LLM
        2. Yields chunks immediately for minimal TTFT
        3. Accumulates response data through the streaming interface
        4. Updates all instance variables after streaming completes
        """
        # Store request data
        self.request_data = request_data

        # Track request start time for latency calculation
        request_start_ns = get_utc_timestamp_ns()

        # Get cancellation event for this run to enable graceful cancellation (before branching)
        cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None

        # Instantiate streaming interface
        if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock, ProviderType.minimax]:
            # NOTE: different
            self.interface = SimpleAnthropicStreamingInterface(
                requires_approval_tools=requires_approval_tools,
                run_id=self.run_id,
                step_id=step_id,
            )
        elif self.llm_config.model_endpoint_type in [
            ProviderType.openai,
            ProviderType.deepseek,
            ProviderType.openrouter,
            ProviderType.zai,
            ProviderType.chatgpt_oauth,
        ]:
            # Decide interface based on payload shape
            use_responses = "input" in request_data and "messages" not in request_data
            # No support for Responses API proxy
            is_proxy = self.llm_config.provider_name == "lmstudio_openai"
            # ChatGPT OAuth always uses Responses API format
            if self.llm_config.model_endpoint_type == ProviderType.chatgpt_oauth:
                use_responses = True
                is_proxy = False
            if use_responses and not is_proxy:
                self.interface = SimpleOpenAIResponsesStreamingInterface(
                    is_openai_proxy=False,
                    messages=messages,
                    tools=tools,
                    requires_approval_tools=requires_approval_tools,
                    run_id=self.run_id,
                    step_id=step_id,
                    cancellation_event=cancellation_event,
                )
            else:
                self.interface = SimpleOpenAIStreamingInterface(
                    is_openai_proxy=self.llm_config.provider_name == "lmstudio_openai",
                    messages=messages,
                    tools=tools,
                    requires_approval_tools=requires_approval_tools,
                    model=self.llm_config.model,
                    run_id=self.run_id,
                    step_id=step_id,
                    cancellation_event=cancellation_event,
                )
        elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
            self.interface = SimpleGeminiStreamingInterface(
                requires_approval_tools=requires_approval_tools,
                run_id=self.run_id,
                step_id=step_id,
                cancellation_event=cancellation_event,
            )
        else:
            raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")

        # Start the streaming request (map provider errors to common LLMError types)
        try:
            # Gemini uses async generator pattern (no await) to maintain connection lifecycle
            # Other providers return awaitables that resolve to iterators
            if self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
                stream = self.llm_client.stream_async(request_data, self.llm_config)
            else:
                stream = await self.llm_client.stream_async(request_data, self.llm_config)
        except Exception as e:
            self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
            latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
            await self.llm_client.log_provider_trace_async(
                request_data=request_data,
                response_json=None,
                llm_config=self.llm_config,
                latency_ms=latency_ms,
                error_msg=str(e),
                error_type=type(e).__name__,
            )
            if isinstance(e, LLMError):
                raise
            raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)

        # Process the stream and yield chunks immediately for TTFT
        try:
            async for chunk in self.interface.process(stream):  # TODO: add ttft span
                # Yield each chunk immediately as it arrives
                yield chunk
        except Exception as e:
            self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
            latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000)
            await self.llm_client.log_provider_trace_async(
                request_data=request_data,
                response_json=None,
                llm_config=self.llm_config,
                latency_ms=latency_ms,
                error_msg=str(e),
                error_type=type(e).__name__,
            )
            if isinstance(e, LLMError):
                raise
            raise self.llm_client.handle_llm_error(e, llm_config=self.llm_config)

        # After streaming completes, extract the accumulated data
        self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()

        # Extract tool calls from the interface (supports both single and parallel calls)
        self.tool_calls = self._extract_tool_calls()
        # Preserve legacy single-call field for existing consumers
        self.tool_call = self.tool_calls[-1] if self.tool_calls else None

        # Extract reasoning content from the interface
        # TODO this should probably just be called "content"?
        # self.reasoning_content = self.interface.get_reasoning_content()
        # Extract all content parts
        self.content: List[LettaMessageContentUnion] = self.interface.get_content()

        # Extract usage statistics from the interface
        # Each interface implements get_usage_statistics() with provider-specific logic
        self.usage = self.interface.get_usage_statistics()
        self.usage.step_count = 1

        # Store any additional data from the interface
        self.message_id = self.interface.letta_message_id

        # Populate finish_reason for downstream continuation logic.
        # In Responses streaming, max_output_tokens is expressed via incomplete_details.reason.
        if hasattr(self.interface, "final_response") and self.interface.final_response is not None:
            resp = self.interface.final_response
            incomplete_details = getattr(resp, "incomplete_details", None)
            incomplete_reason = getattr(incomplete_details, "reason", None) if incomplete_details else None
            if incomplete_reason == "max_output_tokens":
                self._finish_reason = "length"
            elif incomplete_reason == "content_filter":
                self._finish_reason = "content_filter"
            elif incomplete_reason is not None:
                # Unknown incomplete reason; preserve it as-is for diagnostics
                self._finish_reason = incomplete_reason
            elif getattr(resp, "status", None) == "completed":
                self._finish_reason = "stop"

        # Log request and response data
        self.log_provider_trace(step_id=step_id, actor=actor)

    @trace_method
    def log_provider_trace(self, step_id: str | None, actor: User | None) -> None:
        """
        Log provider trace data for telemetry purposes in a fire-and-forget manner.

        Creates an async task to log the request/response data without blocking
        the main execution flow. For streaming adapters, this includes the final
        tool call and reasoning content collected during streaming.

        Args:
            step_id: The step ID associated with this request for logging purposes
            actor: The user associated with this request for logging purposes
        """
        if step_id is None or actor is None:
            return

        response_json = {
            "content": {
                "tool_call": self.tool_call.model_dump_json() if self.tool_call else None,
                # "reasoning": [content.model_dump_json() for content in self.reasoning_content],
                # NOTE: different
                # TODO potentially split this into both content and reasoning?
                "content": [content.model_dump_json() for content in self.content],
            },
            "id": self.interface.message_id,
            "model": self.interface.model,
            "role": "assistant",
            # "stop_reason": "",
            # "stop_sequence": None,
            "type": "message",
            # Use raw_usage if available for transparent provider trace logging, else fallback
            "usage": self.interface.raw_usage
            if hasattr(self.interface, "raw_usage") and self.interface.raw_usage
            else {
                "input_tokens": self.usage.prompt_tokens,
                "output_tokens": self.usage.completion_tokens,
            },
        }

        log_attributes(
            {
                "request_data": safe_json_dumps(self.request_data),
                "response_data": safe_json_dumps(response_json),
            }
        )

        if settings.track_provider_trace:
            safe_create_task(
                self.telemetry_manager.create_provider_trace_async(
                    actor=actor,
                    provider_trace=ProviderTrace(
                        request_json=self.request_data,
                        response_json=response_json,
                        step_id=step_id,
                        agent_id=self.agent_id,
                        agent_tags=self.agent_tags,
                        run_id=self.run_id,
                        call_type=self.call_type,
                        org_id=self.org_id,
                        user_id=self.user_id,
                        llm_config=self.llm_config.model_dump() if self.llm_config else None,
                        billing_context=self.billing_context,
                    ),
                ),
                label="create_provider_trace",
            )