feat: add new azure api maintaining backward compat (#9387)

* feat: add new azure provider type

* fix context window
This commit is contained in:
Ari Webb
2026-02-09 16:04:27 -08:00
committed by Caren Thomas
parent 226df8baef
commit 5fd5a6dd07
2 changed files with 218 additions and 18 deletions

View File

@@ -1,19 +1,31 @@
import json
import os
from typing import List, Optional, Tuple
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.responses.response_stream_event import ResponseStreamEvent
from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings
logger = get_logger(__name__)
class AzureClient(OpenAIClient):
@staticmethod
def _is_v1_endpoint(base_url: str) -> bool:
if not base_url:
return False
return base_url.rstrip("/").endswith("/openai/v1")
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
@@ -30,20 +42,36 @@ class AzureClient(OpenAIClient):
return None, None, None
def _resolve_credentials(self, api_key, base_url, api_version):
"""Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
if not api_key:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
if not base_url:
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
if not api_version and not self._is_v1_endpoint(base_url):
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
return api_key, base_url, api_version
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Performs underlying synchronous request to Azure and returns raw response dict.

    BYOK credentials take precedence; any missing pieces fall back to settings /
    environment via _resolve_credentials. v1 endpoints use a plain OpenAI client
    (no api_version); legacy endpoints use AzureOpenAI.
    """
    # NOTE(review): removed stale pre-change lines that duplicated the credential
    # fallback and returned early, leaving the routing logic unreachable.
    api_key, base_url, api_version = self.get_byok_overrides(llm_config)
    api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
    if self._is_v1_endpoint(base_url):
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
    # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
    if "input" in request_data and "messages" not in request_data:
        resp = client.responses.create(**request_data)
        return resp.model_dump()
    response: ChatCompletion = client.chat.completions.create(**request_data)
    return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Performs underlying asynchronous request to Azure and returns raw response dict.

    Mirrors ``request`` (BYOK-first credential resolution, v1 vs legacy client,
    payload-shape routing) but with async clients; provider errors are normalized
    through handle_llm_error.
    """
    # Strip unpaired surrogates so the JSON payload is always encodable.
    request_data = sanitize_unicode_surrogates(request_data)
    api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
    api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
    try:
        if self._is_v1_endpoint(base_url):
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
        if "input" in request_data and "messages" not in request_data:
            resp = await client.responses.create(**request_data)
            return resp.model_dump()
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()
    except Exception as e:
        # NOTE(review): dropped an unreachable `return response.model_dump()`
        # that followed this raise in the pre-change version.
        raise self.handle_llm_error(e)
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
    """
    Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.

    Routes to the Responses API when the payload carries 'input' (and no
    'messages'); otherwise streams Chat Completions with usage reporting enabled.
    """
    request_data = sanitize_unicode_surrogates(request_data)
    api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
    api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
    if self._is_v1_endpoint(base_url):
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
    # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
    if "input" in request_data and "messages" not in request_data:
        try:
            response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
                **request_data,
                stream=True,
            )
        except Exception as e:
            # default=str keeps this log line from raising TypeError on
            # non-JSON-serializable payload values, which would mask the
            # original streaming error.
            logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data, default=str)}")
            raise  # bare raise preserves the original traceback
    else:
        try:
            response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
                **request_data,
                stream=True,
                stream_options={"include_usage": True},  # emit token usage in the final chunk
            )
        except Exception as e:
            logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data, default=str)}")
            raise
    return response_stream
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
@@ -71,7 +142,12 @@ class AzureClient(OpenAIClient):
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
if self._is_v1_endpoint(base_url):
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
# TODO: add total usage