json.dumps on request data. Also remove step and actor since they are already present in the span
278 lines
13 KiB
Python
278 lines
13 KiB
Python
import json
|
|
from typing import AsyncGenerator, List
|
|
|
|
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
|
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
|
from letta.interfaces.anthropic_parallel_tool_call_streaming_interface import SimpleAnthropicStreamingInterface
|
|
from letta.interfaces.gemini_streaming_interface import SimpleGeminiStreamingInterface
|
|
from letta.interfaces.openai_streaming_interface import SimpleOpenAIResponsesStreamingInterface, SimpleOpenAIStreamingInterface
|
|
from letta.otel.tracing import log_attributes, trace_method
|
|
from letta.schemas.enums import ProviderType
|
|
from letta.schemas.letta_message import LettaMessage
|
|
from letta.schemas.letta_message_content import LettaMessageContentUnion
|
|
from letta.schemas.provider_trace import ProviderTraceCreate
|
|
from letta.schemas.usage import LettaUsageStatistics
|
|
from letta.schemas.user import User
|
|
from letta.settings import settings
|
|
from letta.utils import safe_create_task
|
|
|
|
|
|
class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
    """
    Adapter for handling streaming LLM requests with immediate token yielding.

    This adapter supports real-time streaming of tokens from the LLM, providing
    minimal time-to-first-token (TTFT) latency. It uses specialized streaming
    interfaces for different providers (Anthropic/Bedrock, OpenAI chat
    completions and Responses payloads, Google AI/Vertex) to handle their
    specific streaming formats.
    """

    def _extract_tool_calls(self) -> list:
        """
        Extract tool calls from the interface, trying the parallel API first and
        then the single-call API.

        Returns:
            The list returned by ``interface.get_tool_call_objects()`` when that
            method exists and yields a non-empty result; otherwise a one-element
            list wrapping ``interface.get_tool_call_object()``, or an empty list
            when no tool call could be obtained.
        """
        # try multi-call api if available
        if hasattr(self.interface, "get_tool_call_objects"):
            try:
                calls = self.interface.get_tool_call_objects()
                if calls:
                    return calls
            except Exception:
                # Best-effort: any failure here falls through to the single-call API.
                pass

        # fallback to single-call api
        try:
            single = self.interface.get_tool_call_object()
            return [single] if single else []
        except Exception:
            # Best-effort: a failure is treated the same as "no tool call".
            return []

    def _build_streaming_interface(
        self,
        request_data: dict,
        messages: list,
        tools: list,
        requires_approval_tools: list[str],
        step_id: str | None,
    ):
        """
        Instantiate the streaming interface matching ``self.llm_config.model_endpoint_type``.

        Args:
            request_data: The provider request payload; for OpenAI its shape
                ("input" without "messages") selects the Responses interface.
            messages: Forwarded to the OpenAI interfaces only.
            tools: Forwarded to the OpenAI interfaces only.
            requires_approval_tools: Forwarded to every interface.
            step_id: Forwarded to every interface.

        Returns:
            A newly constructed streaming interface instance.

        Raises:
            ValueError: If the configured endpoint type has no streaming interface here.
        """
        endpoint_type = self.llm_config.model_endpoint_type

        if endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
            # NOTE: different
            return SimpleAnthropicStreamingInterface(
                requires_approval_tools=requires_approval_tools,
                run_id=self.run_id,
                step_id=step_id,
            )

        if endpoint_type == ProviderType.openai:
            # Decide interface based on payload shape
            use_responses = "input" in request_data and "messages" not in request_data
            # No support for Responses API proxy
            is_proxy = self.llm_config.provider_name == "lmstudio_openai"

            if use_responses and not is_proxy:
                return SimpleOpenAIResponsesStreamingInterface(
                    is_openai_proxy=False,
                    messages=messages,
                    tools=tools,
                    requires_approval_tools=requires_approval_tools,
                    run_id=self.run_id,
                    step_id=step_id,
                )

            return SimpleOpenAIStreamingInterface(
                is_openai_proxy=is_proxy,
                messages=messages,
                tools=tools,
                requires_approval_tools=requires_approval_tools,
                model=self.llm_config.model,
                run_id=self.run_id,
                step_id=step_id,
            )

        if endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
            return SimpleGeminiStreamingInterface(
                requires_approval_tools=requires_approval_tools,
                run_id=self.run_id,
                step_id=step_id,
            )

        raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")

    def _collect_usage_statistics(self) -> LettaUsageStatistics:
        """
        Build usage statistics from the token counters accumulated on ``self.interface``.

        Returns:
            A ``LettaUsageStatistics`` with ``step_count=1``. When the interface
            lacks ``input_tokens`` or ``output_tokens`` attributes, all token
            counts are zero.
        """
        interface = self.interface

        # Some providers don't provide usage in streaming, use fallback if needed
        if not (hasattr(interface, "input_tokens") and hasattr(interface, "output_tokens")):
            # Default usage statistics if not available
            return LettaUsageStatistics(step_count=1, completion_tokens=0, prompt_tokens=0, total_tokens=0)

        # Handle cases where tokens might not be set (e.g., LMStudio)
        input_tokens = interface.input_tokens
        output_tokens = interface.output_tokens

        # Fallback to estimated values if not provided
        if not input_tokens and hasattr(interface, "fallback_input_tokens"):
            input_tokens = interface.fallback_input_tokens
        if not output_tokens and hasattr(interface, "fallback_output_tokens"):
            output_tokens = interface.fallback_output_tokens

        # Extract cache token data (OpenAI/Gemini use cached_tokens)
        # None means provider didn't report, 0 means provider reported 0
        cached_input_tokens = getattr(interface, "cached_tokens", None)
        if cached_input_tokens is None:
            # Anthropic uses cache_read_tokens for cache hits
            cached_input_tokens = getattr(interface, "cache_read_tokens", None)

        # Extract cache write tokens (Anthropic only)
        # None means provider didn't report, 0 means provider reported 0
        cache_write_tokens = getattr(interface, "cache_creation_tokens", None)

        # Extract reasoning tokens (OpenAI o1/o3 models use reasoning_tokens, Gemini uses thinking_tokens)
        # None means provider didn't report, 0 means provider reported 0
        reasoning_tokens = getattr(interface, "reasoning_tokens", None)
        if reasoning_tokens is None:
            reasoning_tokens = getattr(interface, "thinking_tokens", None)

        # Calculate actual total input tokens for context window limit checks (summarization trigger).
        #
        # ANTHROPIC: input_tokens is NON-cached only, must add cache tokens
        #   Total = input_tokens + cache_read_input_tokens + cache_creation_input_tokens
        #
        # OPENAI/GEMINI: input_tokens (prompt_tokens/prompt_token_count) is already TOTAL
        #   cached_tokens is a subset, NOT additive
        #   Total = input_tokens (don't add cached_tokens or it double-counts!)
        is_anthropic = hasattr(interface, "cache_read_tokens") or hasattr(interface, "cache_creation_tokens")
        if is_anthropic:
            actual_input_tokens = (input_tokens or 0) + (cached_input_tokens or 0) + (cache_write_tokens or 0)
        else:
            actual_input_tokens = input_tokens or 0

        return LettaUsageStatistics(
            step_count=1,
            completion_tokens=output_tokens or 0,
            prompt_tokens=input_tokens or 0,
            total_tokens=actual_input_tokens + (output_tokens or 0),
            cached_input_tokens=cached_input_tokens,
            cache_write_tokens=cache_write_tokens,
            reasoning_tokens=reasoning_tokens,
        )

    async def invoke_llm(
        self,
        request_data: dict,
        messages: list,
        tools: list,
        use_assistant_message: bool,  # NOTE: not used
        requires_approval_tools: list[str] | None = None,
        step_id: str | None = None,
        actor: User | None = None,
    ) -> AsyncGenerator[LettaMessage, None]:
        """
        Execute a streaming LLM request and yield tokens/chunks as they arrive.

        This adapter:
        1. Makes a streaming request to the LLM
        2. Yields chunks immediately for minimal TTFT
        3. Accumulates response data through the streaming interface
        4. Updates all instance variables after streaming completes

        Args:
            request_data: Provider request payload, stored on ``self.request_data``
                and passed to ``self.llm_client.stream_async``.
            messages: Forwarded to the OpenAI streaming interfaces.
            tools: Forwarded to the OpenAI streaming interfaces.
            use_assistant_message: Accepted for interface compatibility; not used.
            requires_approval_tools: Tool names forwarded to the streaming
                interface. ``None`` (the default) is treated as an empty list.
            step_id: Forwarded to the streaming interface and to provider-trace logging.
            actor: Forwarded to provider-trace logging.

        Yields:
            Each chunk produced by ``self.interface.process(stream)``.

        Raises:
            ValueError: If streaming is not supported for the configured provider.
            Exception: Whatever ``self.llm_client.handle_llm_error`` returns for an
                error raised while starting or consuming the stream.
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if requires_approval_tools is None:
            requires_approval_tools = []

        # Store request data
        self.request_data = request_data

        # Instantiate streaming interface
        self.interface = self._build_streaming_interface(
            request_data=request_data,
            messages=messages,
            tools=tools,
            requires_approval_tools=requires_approval_tools,
            step_id=step_id,
        )

        is_google = self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]

        # Start the streaming request (map provider errors to common LLMError types)
        try:
            # Gemini uses async generator pattern (no await) to maintain connection lifecycle
            # Other providers return awaitables that resolve to iterators
            if is_google:
                stream = self.llm_client.stream_async(request_data, self.llm_config)
            else:
                stream = await self.llm_client.stream_async(request_data, self.llm_config)
        except Exception as e:
            raise self.llm_client.handle_llm_error(e)

        # Process the stream and yield chunks immediately for TTFT
        try:
            async for chunk in self.interface.process(stream):  # TODO: add ttft span
                # Yield each chunk immediately as it arrives
                yield chunk
        except Exception as e:
            # Map provider-specific errors during streaming to common LLMError types
            raise self.llm_client.handle_llm_error(e)

        # After streaming completes, extract the accumulated data
        self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()

        # extract tool calls from interface (supports both single and parallel calls)
        self.tool_calls = self._extract_tool_calls()
        # preserve legacy single-call field for existing consumers
        self.tool_call = self.tool_calls[-1] if self.tool_calls else None

        # Extract reasoning content from the interface
        # TODO this should probably just be called "content"?
        # self.reasoning_content = self.interface.get_reasoning_content()

        # Extract all content parts
        self.content: List[LettaMessageContentUnion] = self.interface.get_content()

        # Extract usage statistics
        self.usage = self._collect_usage_statistics()

        # Store any additional data from the interface
        self.message_id = self.interface.letta_message_id

        # Log request and response data
        self.log_provider_trace(step_id=step_id, actor=actor)

    @trace_method
    def log_provider_trace(self, step_id: str | None, actor: User | None) -> None:
        """
        Log provider trace data for telemetry purposes in a fire-and-forget manner.

        Attaches the JSON-encoded request and response to the current trace via
        ``log_attributes`` and, when ``settings.track_provider_trace`` is set,
        schedules a task that persists a provider trace without blocking the
        main execution flow. For streaming adapters, the response includes the
        final tool call and content collected during streaming.

        Args:
            step_id: The step ID associated with this request for logging purposes
            actor: The user associated with this request for logging purposes

        Returns:
            None. Does nothing when ``step_id`` or ``actor`` is ``None``.
        """
        if step_id is None or actor is None:
            return

        # Use raw_usage if available for transparent provider trace logging, else fallback
        usage = getattr(self.interface, "raw_usage", None) or {
            "input_tokens": self.usage.prompt_tokens,
            "output_tokens": self.usage.completion_tokens,
        }

        response_json = {
            "content": {
                "tool_call": self.tool_call.model_dump_json() if self.tool_call else None,
                # "reasoning": [content.model_dump_json() for content in self.reasoning_content],
                # NOTE: different
                # TODO potentially split this into both content and reasoning?
                "content": [content.model_dump_json() for content in self.content],
            },
            "id": self.interface.message_id,
            "model": self.interface.model,
            "role": "assistant",
            # "stop_reason": "",
            # "stop_sequence": None,
            "type": "message",
            "usage": usage,
        }

        log_attributes(
            {
                "request_data": json.dumps(self.request_data),
                "response_data": json.dumps(response_json),
            }
        )

        if settings.track_provider_trace:
            safe_create_task(
                self.telemetry_manager.create_provider_trace_async(
                    actor=actor,
                    provider_trace_create=ProviderTraceCreate(
                        request_json=self.request_data,
                        response_json=response_json,
                        step_id=step_id,  # Use original step_id for telemetry
                        organization_id=actor.organization_id,
                    ),
                ),
                label="create_provider_trace",
            )