Add is_byok flag to every LLMError's details dict returned from
handle_llm_error across all providers (OpenAI, Anthropic, Google,
ChatGPT OAuth). This enables observability into whether errors
originate from Letta's production keys or user-provided BYOK keys.
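
A minimal sketch of the shape of this change, assuming an LLMError type whose details dict the handler populates (the handler signature and the error-mapping helper below are illustrative, not the actual provider code):

```python
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig


def handle_llm_error(self, e: Exception, llm_config: LLMConfig) -> Exception:
    # Tag the mapped error with whether the failing request used a
    # user-provided (BYOK) key or one of Letta's base provider keys.
    is_byok = llm_config.provider_category == ProviderCategory.byok
    llm_error = self._map_provider_error(e)  # hypothetical mapping helper
    llm_error.details["is_byok"] = is_byok
    return llm_error
```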
The rate limit handler in app.py now returns a more helpful message
for BYOK users ("check your provider's rate limits and billing")
versus the generic message for base provider rate limits.
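
And a hedged sketch of the rate-limit messaging branch that reads the flag above (the function name, the generic message text, and the surrounding app.py plumbing are assumptions):

```python
def rate_limit_message(details: dict) -> str:
    # Illustrative only: BYOK users get a hint pointing at their own provider account.
    if details.get("is_byok"):
        return "Rate limit exceeded: check your provider's rate limits and billing."
    return "Rate limit exceeded. Please try again later."
```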
Datadog issues:
- https://us5.datadoghq.com/error-tracking/issue/b711c824-f490-11f0-96e4-da7ad0900000
- https://us5.datadoghq.com/error-tracking/issue/76623036-f4de-11f0-8697-da7ad0900000
- https://us5.datadoghq.com/error-tracking/issue/43e9888a-dfcf-11f0-a645-da7ad0900000
🤖 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
155 lines · 7.2 KiB · Python
import json
import os
from typing import List, Optional, Tuple

from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.responses.response_stream_event import ResponseStreamEvent

from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings

logger = get_logger(__name__)


class AzureClient(OpenAIClient):
    @staticmethod
    def _is_v1_endpoint(base_url: str) -> bool:
        """Return True if the base URL targets the OpenAI-compatible /openai/v1 Azure endpoint."""
        if not base_url:
            return False
        return base_url.rstrip("/").endswith("/openai/v1")

    def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Return (api_key, base_url, api_version) overrides for BYOK providers, else (None, None, None)."""
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            return ProviderManager().get_azure_credentials(llm_config.provider_name, actor=self.actor)

        return None, None, None

    async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Async variant of get_byok_overrides."""
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            return await ProviderManager().get_azure_credentials_async(llm_config.provider_name, actor=self.actor)

        return None, None, None

    def _resolve_credentials(self, api_key, base_url, api_version):
        """Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
        if not api_key:
            api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
        if not base_url:
            base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
        if not api_version and not self._is_v1_endpoint(base_url):
            api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
        return api_key, base_url, api_version

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to OpenAI API and returns raw response dict.
        """
        api_key, base_url, api_version = self.get_byok_overrides(llm_config)
        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

        if self._is_v1_endpoint(base_url):
            client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)

        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
        if "input" in request_data and "messages" not in request_data:
            resp = client.responses.create(**request_data)
            return resp.model_dump()
        else:
            response: ChatCompletion = client.chat.completions.create(**request_data)
            return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
        """
        request_data = sanitize_unicode_surrogates(request_data)

        api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

        try:
            if self._is_v1_endpoint(base_url):
                client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            else:
                client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)

            # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
            if "input" in request_data and "messages" not in request_data:
                resp = await client.responses.create(**request_data)
                return resp.model_dump()
            else:
                response: ChatCompletion = await client.chat.completions.create(**request_data)
                return response.model_dump()
        except Exception as e:
            raise self.handle_llm_error(e, llm_config=llm_config)

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
        """
        Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
        """
        request_data = sanitize_unicode_surrogates(request_data)

        api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

        if self._is_v1_endpoint(base_url):
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)

        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
        if "input" in request_data and "messages" not in request_data:
            try:
                response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
                    **request_data,
                    stream=True,
                )
            except Exception as e:
                logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
                raise e
        else:
            try:
                response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
                    **request_data,
                    stream=True,
                    stream_options={"include_usage": True},
                )
            except Exception as e:
                logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
                raise e
        return response_stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config"""
        api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
        base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")

        if self._is_v1_endpoint(base_url):
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)

        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [r.embedding for r in response.data]