* feat: add provider trace backend abstraction for multi-backend telemetry Introduces a pluggable backend system for provider traces: - Base class with async/sync create and read interfaces - PostgreSQL backend (existing behavior) - ClickHouse backend (via OTEL instrumentation) - Socket backend (writes to Unix socket for crouton sidecar) - Factory for instantiating backends from config Refactors TelemetryManager to use backends with support for: - Multi-backend writes (concurrent via asyncio.gather) - Primary backend for reads (first in config list) - Graceful error handling per backend Config: LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND (comma-separated) Example: "postgres,socket" for dual-write to Postgres and crouton 🐙 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * feat: add protocol version to socket backend records Adds PROTOCOL_VERSION constant to socket backend: - Included in every telemetry record sent to crouton - Must match ProtocolVersion in apps/crouton/main.go - Enables crouton to detect and reject incompatible messages 🐙 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: remove organization_id from ProviderTraceCreate calls The organization_id is now handled via the actor parameter in the telemetry manager, not through ProviderTraceCreate schema. This fixes validation errors after changing ProviderTraceCreate to inherit from BaseProviderTrace which forbids extra fields. 
🐙 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * consolidate provider trace * add clickhouse-connect to fix bug on main lmao * auto-generated sdk changes, and deployment details, and clickhouse prefix bug and added fields to runs trace return api * auto-generated sdk changes, and deployment details, and clickhouse prefix bug and added fields to runs trace return api * consolidate provider trace * consolidate provider trace bug fix --------- Co-authored-by: Letta <noreply@letta.com>
301 lines
12 KiB
Python
301 lines
12 KiB
Python
import json
|
|
from abc import abstractmethod
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|
|
|
import httpx
|
|
from anthropic.types.beta.messages import BetaMessageBatch
|
|
from openai import AsyncStream, Stream
|
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
|
|
from letta.errors import ErrorCode, LLMConnectionError, LLMError
|
|
from letta.otel.tracing import log_event, trace_method
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import AgentType, ProviderCategory
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message
|
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
|
from letta.schemas.provider_trace import ProviderTrace
|
|
from letta.services.telemetry_manager import TelemetryManager
|
|
from letta.settings import settings
|
|
|
|
if TYPE_CHECKING:
|
|
from letta.orm import User
|
|
|
|
|
|
class LLMClientBase:
    """
    Abstract base class for LLM clients, formatting the request objects,
    handling the downstream request and parsing into chat completions response format.
    """

    def __init__(
        self,
        put_inner_thoughts_first: Optional[bool] = True,
        use_tool_naming: bool = True,
        actor: Optional["User"] = None,
    ):
        """
        Args:
            put_inner_thoughts_first: Provider hint controlling where inner thoughts
                are placed when building requests (interpreted by subclasses).
            use_tool_naming: Provider hint controlling tool-naming conventions
                when building requests (interpreted by subclasses).
            actor: Acting user; used for BYOK key lookups and telemetry attribution.
        """
        self.actor = actor
        self.put_inner_thoughts_first = put_inner_thoughts_first
        self.use_tool_naming = use_tool_naming

    @trace_method
    async def send_llm_request(
        self,
        agent_type: AgentType,
        messages: List[Message],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        force_tool_call: Optional[str] = None,
        telemetry_manager: Optional["TelemetryManager"] = None,
        step_id: Optional[str] = None,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
        """
        Issues a request to the downstream model endpoint and parses response.

        If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
        Otherwise returns a ChatCompletionResponse.

        Raises:
            Exception: whatever exception ``handle_llm_error`` maps the provider
                failure to (an ``LLMError`` subclass for known cases).
        """
        request_data = self.build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call=False,
            tool_return_truncation_chars=tool_return_truncation_chars,
        )

        try:
            log_event(name="llm_request_sent", attributes=request_data)
            response_data = await self.request_async(request_data, llm_config)
            if step_id and telemetry_manager:
                # NOTE(review): unlike send_llm_request_async, this path does not
                # consult settings.track_provider_trace and calls the synchronous
                # create method without awaiting — confirm against TelemetryManager
                # whether this divergence is intentional.
                telemetry_manager.create_provider_trace(
                    actor=self.actor,
                    provider_trace=ProviderTrace(
                        request_json=request_data,
                        response_json=response_data,
                        step_id=step_id,
                    ),
                )
            log_event(name="llm_response_received", attributes=response_data)
        except Exception as e:
            # Translate provider-specific failures into common LLMError types.
            raise self.handle_llm_error(e)

        return await self.convert_response_to_chat_completion(response_data, messages, llm_config)

    @trace_method
    async def send_llm_request_async(
        self,
        request_data: dict,
        messages: List[Message],
        llm_config: LLMConfig,
        telemetry_manager: "TelemetryManager | None" = None,
        step_id: str | None = None,
    ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
        """
        Issues a request to the downstream model endpoint.

        If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async
        iterated over. Otherwise returns a ChatCompletionResponse.

        Unlike ``send_llm_request``, the caller supplies a pre-built ``request_data``
        payload rather than having it constructed from messages/tools here.

        Raises:
            Exception: whatever exception ``handle_llm_error`` maps the provider
                failure to (an ``LLMError`` subclass for known cases).
        """
        try:
            log_event(name="llm_request_sent", attributes=request_data)
            response_data = await self.request_async(request_data, llm_config)
            # Provider-trace capture is gated on the global telemetry setting here;
            # step_id may be None, in which case the trace is stored without one.
            if settings.track_provider_trace and telemetry_manager:
                await telemetry_manager.create_provider_trace_async(
                    actor=self.actor,
                    provider_trace=ProviderTrace(
                        request_json=request_data,
                        response_json=response_data,
                        step_id=step_id,
                    ),
                )

            log_event(name="llm_response_received", attributes=response_data)
        except Exception as e:
            # Translate provider-specific failures into common LLMError types.
            raise self.handle_llm_error(e)

        return await self.convert_response_to_chat_completion(response_data, messages, llm_config)

    async def send_llm_batch_request_async(
        self,
        agent_type: AgentType,
        agent_messages_mapping: Dict[str, List[Message]],
        agent_tools_mapping: Dict[str, List[dict]],
        agent_llm_config_mapping: Dict[str, LLMConfig],
    ) -> BetaMessageBatch:
        """
        Issues a batch request to the downstream model endpoint and parses response.

        Mappings are keyed by agent id. Only providers with batch support override
        this; the base implementation is unsupported.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    @abstractmethod
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[Message],
        llm_config: LLMConfig,
        tools: List[dict],
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """
        Constructs a request object in the expected data format for this client.

        Args:
            tool_return_truncation_chars: If set, truncates tool return content to this many characters.
                Used during summarization to avoid context window issues.
        """
        raise NotImplementedError

    @abstractmethod
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying request to llm and returns raw response.
        """
        raise NotImplementedError

    @abstractmethod
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying request to llm and returns raw response.
        """
        raise NotImplementedError

    @abstractmethod
    async def request_embeddings(self, texts: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """
        Generate embeddings for a batch of texts.

        Args:
            texts (List[str]): List of texts to generate embeddings for.
            embedding_config (EmbeddingConfig): Configuration for the embedding model.

        Returns:
            embeddings (List[List[float]]): List of embeddings for the input texts.
        """
        raise NotImplementedError

    @abstractmethod
    async def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[Message],
        llm_config: LLMConfig,
    ) -> ChatCompletionResponse:
        """
        Converts custom response format from llm client into an OpenAI
        ChatCompletionsResponse object.
        """
        raise NotImplementedError

    @abstractmethod
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying streaming request to llm and returns raw response.
        """
        raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")

    @abstractmethod
    def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
        """
        Returns True if the model is a native reasoning model.
        """
        raise NotImplementedError

    @abstractmethod
    def handle_llm_error(self, e: Exception) -> Exception:
        """
        Maps provider-specific errors to common LLMError types.
        Each LLM provider should implement this to translate their specific errors.

        Args:
            e: The original provider-specific exception

        Returns:
            An LLMError subclass that represents the error in a provider-agnostic way
        """
        # Handle httpx.RemoteProtocolError which can occur during streaming
        # when the remote server closes the connection unexpectedly
        # (e.g., "peer closed connection without sending complete message body")
        if isinstance(e, httpx.RemoteProtocolError):
            from letta.log import get_logger

            logger = get_logger(__name__)
            logger.warning(f"[LLM] Remote protocol error during streaming: {e}")
            return LLMConnectionError(
                message=f"Connection error during streaming: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None},
            )

        return LLMError(f"Unhandled LLM error: {str(e)}")

    def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Returns the override key for the given llm config.
        Only fetches API key from database for BYOK providers.
        Base providers use environment variables directly.
        """
        api_key = None
        # Only fetch API key from database for BYOK providers
        # Base providers should always use environment variables
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
            # If we got an empty string from the database, treat it as None
            # so the client can fall back to environment variables or default behavior
            if api_key == "":
                api_key = None

        return api_key, None, None

    async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Returns the override key for the given llm config.
        Only fetches API key from database for BYOK providers.
        Base providers use environment variables directly.
        """
        api_key = None
        # Only fetch API key from database for BYOK providers
        # Base providers should always use environment variables
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
            # If we got an empty string from the database, treat it as None
            # so the client can fall back to environment variables or default behavior
            if api_key == "":
                api_key = None

        return api_key, None, None

    def _fix_truncated_json_response(self, response: ChatCompletionResponse) -> ChatCompletionResponse:
        """
        Fixes truncated JSON responses by ensuring the content is properly formatted.
        This is a workaround for some providers that may return incomplete JSON.

        If the first tool call's arguments fail to parse, attempts to close any
        unterminated string and unclosed braces. The repaired string is only
        committed if it parses; otherwise the response is returned unchanged.
        """
        if response.choices and response.choices[0].message and response.choices[0].message.tool_calls:
            tool_call_args_str = response.choices[0].message.tool_calls[0].function.arguments
            try:
                json.loads(tool_call_args_str)
            except json.JSONDecodeError:
                # Build the closing suffix: a quote if an odd number of '"' suggests
                # an unterminated string (NOTE(review): a plain count miscounts
                # escaped quotes — acceptable for a best-effort repair), then one
                # "}" per unclosed brace.
                json_str_end = ""
                if tool_call_args_str.count('"') % 2 != 0:
                    json_str_end += '"'
                missing_braces = tool_call_args_str.count("{") - tool_call_args_str.count("}")
                if missing_braces > 0:
                    json_str_end += "}" * missing_braces

                # Guard: with an empty suffix there is nothing to repair, and the
                # previous implementation's slice [:-0] would have erased the
                # entire string before re-parsing.
                if json_str_end:
                    # Prefer appending the suffix (preserves all content, e.g.
                    # '{"a": "text' -> '{"a": "text"}'); fall back to replacing
                    # the trailing characters, which handles truncation right
                    # after a comma or partial token (e.g. '{"a": 1,' -> '{"a": 1}').
                    candidates = (
                        tool_call_args_str + json_str_end,
                        tool_call_args_str[: -len(json_str_end)] + json_str_end,
                    )
                    for fixed_tool_call_args_str in candidates:
                        try:
                            json.loads(fixed_tool_call_args_str)
                        except json.JSONDecodeError:
                            continue
                        response.choices[0].message.tool_calls[0].function.arguments = fixed_tool_call_args_str
                        break
        return response