letta-server/letta/llm_api/azure_client.py

import json
import os
from typing import List, Optional, Tuple

from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.responses.response_stream_event import ResponseStreamEvent

from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings

logger = get_logger(__name__)


class AzureClient(OpenAIClient):
    @staticmethod
    def _is_v1_endpoint(base_url: str) -> bool:
        """Return True when the base URL targets Azure's OpenAI-compatible /openai/v1 endpoint."""
        if not base_url:
            return False
        return base_url.rstrip("/").endswith("/openai/v1")
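
    # Illustrative checks for the helper above (resource URLs are hypothetical):
    #   _is_v1_endpoint("https://my-resource.openai.azure.com/openai/v1")  -> True
    #   _is_v1_endpoint("https://my-resource.openai.azure.com/openai/v1/") -> True  (trailing slash stripped)
    #   _is_v1_endpoint("https://my-resource.openai.azure.com")            -> False (legacy endpoint, needs api_version)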

    def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Return (api_key, base_url, api_version) overrides for BYOK providers, else (None, None, None)."""
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            return ProviderManager().get_azure_credentials(llm_config.provider_name, actor=self.actor)
        return None, None, None

    async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Async variant of get_byok_overrides."""
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            return await ProviderManager().get_azure_credentials_async(llm_config.provider_name, actor=self.actor)
        return None, None, None

    def _resolve_credentials(self, api_key, base_url, api_version):
        """Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
        if not api_key:
            api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
        if not base_url:
            base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
        if not api_version and not self._is_v1_endpoint(base_url):
            api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
        return api_key, base_url, api_version
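
    # Resolution-order sketch (values are hypothetical): an explicit BYOK override wins, then
    # model_settings, then the environment. With no override and
    # AZURE_BASE_URL="https://my-resource.openai.azure.com/openai/v1" set, calling
    # _resolve_credentials(None, None, None) returns that URL and leaves api_version as None,
    # since v1 endpoints do not require one.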

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying synchronous request to the Azure/OpenAI API and returns the raw response dict.
        """
        api_key, base_url, api_version = self.get_byok_overrides(llm_config)
        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
        if self._is_v1_endpoint(base_url):
            client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
        # Route based on payload shape: the Responses API uses 'input', Chat Completions uses 'messages'
        if "input" in request_data and "messages" not in request_data:
            resp = client.responses.create(**request_data)
            return resp.model_dump()
        else:
            response: ChatCompletion = client.chat.completions.create(**request_data)
            return response.model_dump()
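
    # Illustrative payloads for the routing above (model names are hypothetical):
    #   Responses API    -> {"model": "gpt-4o", "input": [{"role": "user", "content": "hi"}]}
    #   Chat Completions -> {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}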

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying asynchronous request to the Azure/OpenAI API and returns the raw response dict.
        """
        request_data = sanitize_unicode_surrogates(request_data)
        api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
        try:
            if self._is_v1_endpoint(base_url):
                client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            else:
                client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
            # Route based on payload shape: the Responses API uses 'input', Chat Completions uses 'messages'
            if "input" in request_data and "messages" not in request_data:
                resp = await client.responses.create(**request_data)
                return resp.model_dump()
            else:
                response: ChatCompletion = await client.chat.completions.create(**request_data)
                return response.model_dump()
        except Exception as e:
            raise self.handle_llm_error(e)

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
        """
        Performs the underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
        """
        request_data = sanitize_unicode_surrogates(request_data)
        api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
        if self._is_v1_endpoint(base_url):
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
        # Route based on payload shape: the Responses API uses 'input', Chat Completions uses 'messages'
        if "input" in request_data and "messages" not in request_data:
            try:
                response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
                    **request_data,
                    stream=True,
                )
            except Exception as e:
                logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
                raise e
        else:
            try:
                response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
                    **request_data,
                    stream=True,
                    stream_options={"include_usage": True},
                )
            except Exception as e:
                logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
                raise e
        return response_stream
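
    # A minimal caller-side consumption sketch for the stream above ('client_instance' is hypothetical):
    #   stream = await client_instance.stream_async(request_data, llm_config)
    #   async for event in stream:
    #       ...  # each item is a ChatCompletionChunk or ResponseStreamEvent, depending on routing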

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config"""
        api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
        base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
        if self._is_v1_endpoint(base_url):
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
        # TODO: add total usage
        return [r.embedding for r in response.data]
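

# A minimal usage sketch, assuming Azure credentials come from the environment
# (AZURE_API_KEY, AZURE_BASE_URL, plus AZURE_API_VERSION for non-v1 endpoints);
# the model name and constructor arguments below are hypothetical:
#
#   config = EmbeddingConfig(embedding_model="text-embedding-3-small", ...)
#   vectors = await AzureClient(...).request_embeddings(["hello", "world"], config)
#   assert len(vectors) == 2 and isinstance(vectors[0][0], float)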