feat: add new azure api maintaining backward compat (#9387)

* feat: add new azure provider type

* fix context window
Author:    Ari Webb
Date:      2026-02-09 16:04:27 -08:00
Committer: Caren Thomas
Parent:    226df8baef
Commit:    5fd5a6dd07
2 changed files with 218 additions and 18 deletions


@@ -1,19 +1,31 @@
+import json
 import os
 from typing import List, Optional, Tuple

-from openai import AsyncAzureOpenAI, AzureOpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
 from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.responses.response_stream_event import ResponseStreamEvent

 from letta.helpers.json_helpers import sanitize_unicode_surrogates
 from letta.llm_api.openai_client import OpenAIClient
+from letta.log import get_logger
 from letta.otel.tracing import trace_method
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.settings import model_settings

+logger = get_logger(__name__)

 class AzureClient(OpenAIClient):
+    @staticmethod
+    def _is_v1_endpoint(base_url: str) -> bool:
+        if not base_url:
+            return False
+        return base_url.rstrip("/").endswith("/openai/v1")
+
     def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager
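A note on the detection rule: only base URLs whose path ends in `/openai/v1` (trailing slash tolerated) are treated as v1 endpoints. A minimal sketch of the classifications, using a hypothetical resource name:

```python
# Sketch only: how _is_v1_endpoint classifies base URLs ("my-resource" is invented).
assert AzureClient._is_v1_endpoint("https://my-resource.openai.azure.com/openai/v1")
assert AzureClient._is_v1_endpoint("https://my-resource.openai.azure.com/openai/v1/")  # trailing slash stripped
assert not AzureClient._is_v1_endpoint("https://my-resource.openai.azure.com")         # legacy endpoint
assert not AzureClient._is_v1_endpoint("")                                             # empty -> False
```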
@@ -30,20 +42,36 @@ class AzureClient(OpenAIClient):
         return None, None, None

+    def _resolve_credentials(self, api_key, base_url, api_version):
+        """Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
+        if not api_key:
+            api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
+        if not base_url:
+            base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
+        if not api_version and not self._is_v1_endpoint(base_url):
+            api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        return api_key, base_url, api_version
+
     @trace_method
     def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying synchronous request to OpenAI API and returns raw response dict.
         """
         api_key, base_url, api_version = self.get_byok_overrides(llm_config)
-        if not api_key or not base_url or not api_version:
-            api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
-            base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
-            api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

-        client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
-        response: ChatCompletion = client.chat.completions.create(**request_data)
-        return response.model_dump()
+        if self._is_v1_endpoint(base_url):
+            client = OpenAI(api_key=api_key, base_url=base_url)
+        else:
+            client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+
+        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
+        if "input" in request_data and "messages" not in request_data:
+            resp = client.responses.create(**request_data)
+            return resp.model_dump()
+        else:
+            response: ChatCompletion = client.chat.completions.create(**request_data)
+            return response.model_dump()
     @trace_method
     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
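The payload-shape routing introduced above lets one entry point serve both the Responses API and Chat Completions. A minimal sketch of the rule, with hypothetical payloads (model name and content invented):

```python
# A Responses-style payload carries 'input' and no 'messages'...
responses_payload = {"model": "gpt-5.2-chat", "input": "Say hello."}
# ...while a Chat Completions payload carries 'messages'.
chat_payload = {"model": "gpt-5.2-chat", "messages": [{"role": "user", "content": "Say hello."}]}

def routes_to_responses(request_data: dict) -> bool:
    # Same condition as in the diff above.
    return "input" in request_data and "messages" not in request_data

assert routes_to_responses(responses_payload)
assert not routes_to_responses(chat_payload)
```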
@@ -53,17 +81,60 @@ class AzureClient(OpenAIClient):
         request_data = sanitize_unicode_surrogates(request_data)
         api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
-        if not api_key or not base_url or not api_version:
-            api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
-            base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
-            api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

         try:
-            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
-            response: ChatCompletion = await client.chat.completions.create(**request_data)
+            if self._is_v1_endpoint(base_url):
+                client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+            else:
+                client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+
+            # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
+            if "input" in request_data and "messages" not in request_data:
+                resp = await client.responses.create(**request_data)
+                return resp.model_dump()
+            else:
+                response: ChatCompletion = await client.chat.completions.create(**request_data)
+                return response.model_dump()
         except Exception as e:
             raise self.handle_llm_error(e)
-        return response.model_dump()

+    @trace_method
+    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
+        """
+        Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
+        """
+        request_data = sanitize_unicode_surrogates(request_data)
+        api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
+        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
+
+        if self._is_v1_endpoint(base_url):
+            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        else:
+            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+
+        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
+        if "input" in request_data and "messages" not in request_data:
+            try:
+                response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
+                    **request_data,
+                    stream=True,
+                )
+            except Exception as e:
+                logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
+                raise e
+        else:
+            try:
+                response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+                    **request_data,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                )
+            except Exception as e:
+                logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
+                raise e
+
+        return response_stream
     @trace_method
     async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
@@ -71,7 +142,12 @@ class AzureClient(OpenAIClient):
         api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
         base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
         api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")

-        client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
+        if self._is_v1_endpoint(base_url):
+            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        else:
+            client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
+
         response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
         # TODO: add total usage
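For callers, `stream_async` above returns the raw async stream rather than accumulated text. A consumption sketch for the Chat Completions shape, assuming an already-constructed `AzureClient` and `LLMConfig` (deployment name and prompt are invented; note `stream_async` injects `stream=True` itself, so the payload omits it):

```python
async def stream_demo(client: AzureClient, llm_config: LLMConfig) -> None:
    request_data = {
        "model": "gpt-5.2-chat",  # hypothetical deployment name
        "messages": [{"role": "user", "content": "Stream a short reply."}],
    }
    stream = await client.stream_async(request_data, llm_config)
    async for chunk in stream:  # ChatCompletionChunk events; the final chunk carries usage
        for choice in chunk.choices:
            if choice.delta.content:
                print(choice.delta.content, end="", flush=True)
```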


@@ -46,6 +46,12 @@ class AzureProvider(Provider):
     def replace_none_with_default(cls, v):
         return v if v is not None else cls.LATEST_API_VERSION

+    @staticmethod
+    def _is_v1_endpoint(base_url: str) -> bool:
+        if not base_url:
+            return False
+        return base_url.rstrip("/").endswith("/openai/v1")
+
     def get_azure_chat_completions_endpoint(self, model: str):
         return f"{self.base_url}/openai/deployments/{model}/chat/completions?api-version={self.api_version}"
@@ -60,10 +66,50 @@ class AzureProvider(Provider):
         # That's the only api version that works with this deployments endpoint
         return f"{self.base_url}/openai/deployments?api-version=2023-03-15-preview"

+    def _get_resource_base_url(self) -> str:
+        """Derive the Azure resource base URL (e.g. https://project.openai.azure.com) from any endpoint format."""
+        url = self.base_url.rstrip("/")
+        if url.endswith("/openai/v1"):
+            return url[: -len("/openai/v1")]
+        return url
+
+    async def _get_deployments(self, api_key: str | None) -> list[dict]:
+        """Fetch deployments using the legacy 2023-03-15-preview endpoint.
+
+        Works for both v1 and legacy endpoints since it hits the resource base URL.
+        Returns the raw deployment dicts (each has 'id' = deployment name).
+        """
+        resource_base = self._get_resource_base_url()
+        url = f"{resource_base}/openai/deployments?api-version=2023-03-15-preview"
+        headers = {"Content-Type": "application/json"}
+        if api_key is not None:
+            headers["api-key"] = f"{api_key}"
+
+        try:
+            timeout = httpx.Timeout(15.0, connect=10.0)
+            async with httpx.AsyncClient(timeout=timeout) as http_client:
+                response = await http_client.get(url, headers=headers)
+                response.raise_for_status()
+        except httpx.TimeoutException as e:
+            raise RuntimeError(f"Azure API timeout after 15s: {e}")
+        except httpx.HTTPStatusError as e:
+            raise RuntimeError(f"Failed to retrieve deployment list: {e}")
+
+        return response.json().get("data", [])
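So `_get_resource_base_url` strips the `/openai/v1` suffix when present, letting `_get_deployments` hit the resource root from either configuration. A sketch of the derivation and the response shape assumed above (field values invented; only `"id"` is consumed):

```python
# Base-URL derivation (resource name hypothetical):
#   "https://my-resource.openai.azure.com/openai/v1" -> "https://my-resource.openai.azure.com"
#   "https://my-resource.openai.azure.com"           -> unchanged
#
# Assumed shape of the deployments response; the code above reads only data[*]["id"]:
# {
#   "data": [
#     {"id": "gpt-5.2-chat", "model": "gpt-5.2-chat-2025-12-11"},
#     {"id": "text-embedding-3-small", "model": "text-embedding-3-small"}
#   ]
# }
```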
     async def azure_openai_get_deployed_model_list(self) -> list:
         """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
         api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
+
+        if self._is_v1_endpoint(self.base_url):
+            # The v1 /models endpoint returns base model names (e.g. "gpt-5.2-chat-2025-12-11")
+            # but inference calls require deployment names (e.g. "gpt-5.2-chat").
+            # Query the legacy deployments endpoint to get actual deployment names.
+            return await self._get_deployments(api_key)
+
+        # Legacy path: use Azure SDK + deployments endpoint
         client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url)
         try:
@@ -122,6 +168,37 @@ class AzureProvider(Provider):
     async def list_llm_models_async(self) -> list[LLMConfig]:
         model_list = await self.azure_openai_get_deployed_model_list()

+        if self._is_v1_endpoint(self.base_url):
+            # v1 path: follow OpenAIProvider pattern with litellm context window lookup
+            configs = []
+            for model in model_list:
+                model_name = model.get("id")
+                if not model_name:
+                    continue
+
+                # Use capabilities if present, otherwise accept all (Azure deployments are user-curated)
+                capabilities = model.get("capabilities")
+                if capabilities and capabilities.get("chat_completion") is not None:
+                    if not capabilities.get("chat_completion"):
+                        continue
+
+                context_window_size = await self.get_model_context_window_async(model_name)
+                configs.append(
+                    LLMConfig(
+                        model=model_name,
+                        model_endpoint_type="azure",
+                        model_endpoint=self.base_url,
+                        context_window=context_window_size,
+                        handle=self.get_handle(model_name),
+                        max_tokens=self.get_default_max_output_tokens(model_name),
+                        provider_name=self.name,
+                        provider_category=self.provider_category,
+                    )
+                )
+            return configs
+
+        # Legacy path
         # Extract models that support text generation
         model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
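Note the permissive gate in the v1 branch above: a deployment is skipped only when a capabilities dict is present and explicitly reports `chat_completion` as falsy. A sketch of the three cases with hypothetical model dicts:

```python
def v1_accepts(model: dict) -> bool:
    # Mirrors the v1-branch filter above: skip only on an explicit negative signal.
    capabilities = model.get("capabilities")
    if capabilities and capabilities.get("chat_completion") is not None:
        if not capabilities.get("chat_completion"):
            return False
    return True

assert v1_accepts({"id": "gpt-5.2-chat"})                                            # no capabilities: accepted
assert v1_accepts({"id": "gpt-5.2-chat", "capabilities": {"chat_completion": True}})
assert not v1_accepts({"id": "text-embedding-3-small", "capabilities": {"chat_completion": False}})
```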
@@ -145,6 +222,38 @@ class AzureProvider(Provider):
         return configs

     async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+        model_list = await self.azure_openai_get_deployed_model_list()
+
+        if self._is_v1_endpoint(self.base_url):
+            # v1 path: use base URL as endpoint, filter by capabilities or name
+            configs = []
+            for model in model_list:
+                model_name = model.get("id")
+                if not model_name:
+                    continue
+
+                # Use capabilities if present, otherwise filter by name
+                capabilities = model.get("capabilities")
+                if capabilities and capabilities.get("embeddings") is not None:
+                    if not capabilities.get("embeddings"):
+                        continue
+                elif "embedding" not in model_name:
+                    continue
+
+                configs.append(
+                    EmbeddingConfig(
+                        embedding_model=model_name,
+                        embedding_endpoint_type="azure",
+                        embedding_endpoint=self.base_url,
+                        embedding_dim=768,
+                        embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+                        handle=self.get_handle(model_name, is_embedding=True),
+                        batch_size=1024,
+                    )
+                )
+            return configs
+
+        # Legacy path
         def valid_embedding_model(m: dict, require_embedding_in_name: bool = True):
             valid_name = True
             if require_embedding_in_name:
@@ -152,9 +261,7 @@ class AzureProvider(Provider):
             return m.get("capabilities").get("embeddings") == True and valid_name

-        model_list = await self.azure_openai_get_deployed_model_list()
-
         # Extract models that support embeddings
         model_options = [m for m in model_list if valid_embedding_model(m)]

         configs = []
@@ -179,6 +286,23 @@ class AzureProvider(Provider):
         llm_default = LLM_MAX_CONTEXT_WINDOW.get(model_name, 4096)
         return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)

+    async def get_model_context_window_async(self, model_name: str) -> int | None:
+        """Get context window size, using litellm specs for v1 endpoints or hardcoded map for legacy."""
+        if self._is_v1_endpoint(self.base_url):
+            from letta.model_specs.litellm_model_specs import get_context_window
+
+            # Litellm keys Azure models with an "azure/" prefix
+            context_window = await get_context_window(f"azure/{model_name}")
+            if context_window is not None:
+                return context_window
+            # Try without prefix as fallback
+            context_window = await get_context_window(model_name)
+            if context_window is not None:
+                return context_window
+            # Fall back to hardcoded map, then default
+            return self.get_model_context_window(model_name)
+
+        return self.get_model_context_window(model_name)
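The v1 lookup therefore resolves in order: litellm spec under the `azure/` prefix, litellm spec under the bare name, then the legacy hardcoded map (itself defaulting to 4096). A usage sketch, assuming a configured provider (deployment name is hypothetical):

```python
async def resolve_window(provider: AzureProvider) -> int | None:
    # On a v1 endpoint this tries litellm's "azure/gpt-5.2-chat" spec first,
    # then "gpt-5.2-chat", then falls back to get_model_context_window().
    return await provider.get_model_context_window_async("gpt-5.2-chat")  # hypothetical name
```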
     async def check_api_key(self):
         api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
         if not api_key: