feat: add new azure api maintaining backward compat (#9387)
* feat: add new azure provider type * fix context window
This commit is contained in:
@@ -1,19 +1,31 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AzureClient(OpenAIClient):
|
||||
@staticmethod
def _is_v1_endpoint(base_url: str) -> bool:
    """Return True when *base_url* targets the OpenAI-compatible v1 surface.

    A v1 endpoint ends with ``/openai/v1`` (trailing slashes ignored);
    empty/None URLs are never treated as v1.
    """
    return bool(base_url) and base_url.rstrip("/").endswith("/openai/v1")
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
@@ -30,20 +42,36 @@ class AzureClient(OpenAIClient):
|
||||
|
||||
return None, None, None
|
||||
|
||||
def _resolve_credentials(self, api_key, base_url, api_version):
    """Fill in any missing credential from settings, then environment variables.

    v1-style endpoints (``.../openai/v1``) do not require an ``api_version``,
    so it is only back-filled for legacy endpoints.
    """
    api_key = api_key or model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
    base_url = base_url or model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
    if not api_version and not self._is_v1_endpoint(base_url):
        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
    return api_key, base_url, api_version
|
||||
|
||||
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Performs underlying synchronous request to Azure and returns raw response dict.

    v1-style endpoints (``.../openai/v1``) use the plain ``OpenAI`` client;
    legacy endpoints use ``AzureOpenAI`` with an explicit ``api_version``.
    """
    # BYOK overrides win; anything still missing is back-filled from
    # settings / environment by _resolve_credentials.
    api_key, base_url, api_version = self.get_byok_overrides(llm_config)
    api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

    if self._is_v1_endpoint(base_url):
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)

    # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
    if "input" in request_data and "messages" not in request_data:
        resp = client.responses.create(**request_data)
        return resp.model_dump()
    response: ChatCompletion = client.chat.completions.create(**request_data)
    return response.model_dump()
|
||||
|
||||
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Performs underlying asynchronous request to Azure and returns raw response dict.

    Mirrors :meth:`request`: v1-style endpoints use ``AsyncOpenAI``, legacy
    endpoints use ``AsyncAzureOpenAI``; the payload shape selects the API.
    """
    # Strip invalid surrogate code points before the payload is serialized.
    request_data = sanitize_unicode_surrogates(request_data)

    # BYOK overrides win; anything still missing is back-filled from
    # settings / environment by _resolve_credentials.
    api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
    api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

    try:
        if self._is_v1_endpoint(base_url):
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        else:
            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)

        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
        if "input" in request_data and "messages" not in request_data:
            resp = await client.responses.create(**request_data)
            return resp.model_dump()
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()
    except Exception as e:
        # Normalize provider errors into the project's LLM error hierarchy.
        raise self.handle_llm_error(e)
|
||||
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
    """
    Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
    """
    request_data = sanitize_unicode_surrogates(request_data)

    # BYOK credentials first; _resolve_credentials back-fills from settings/env.
    api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
    api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)

    # v1-style endpoints use the plain async OpenAI client; legacy ones need AsyncAzureOpenAI.
    if self._is_v1_endpoint(base_url):
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)

    # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
    use_responses_api = "input" in request_data and "messages" not in request_data
    if use_responses_api:
        try:
            stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(**request_data, stream=True)
        except Exception as e:
            logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
            raise e
    else:
        try:
            stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
                **request_data,
                stream=True,
                stream_options={"include_usage": True},
            )
        except Exception as e:
            logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
            raise e
    return stream
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
@@ -71,7 +142,12 @@ class AzureClient(OpenAIClient):
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
|
||||
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
|
||||
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
# TODO: add total usage
|
||||
|
||||
@@ -46,6 +46,12 @@ class AzureProvider(Provider):
|
||||
def replace_none_with_default(cls, v):
|
||||
return v if v is not None else cls.LATEST_API_VERSION
|
||||
|
||||
@staticmethod
def _is_v1_endpoint(base_url: str) -> bool:
    """Return True when *base_url* targets the OpenAI-compatible v1 surface.

    A v1 endpoint ends with ``/openai/v1`` (trailing slashes ignored);
    empty/None URLs are never treated as v1.
    """
    return bool(base_url) and base_url.rstrip("/").endswith("/openai/v1")
|
||||
|
||||
def get_azure_chat_completions_endpoint(self, model: str):
    """Build the legacy deployment-scoped chat-completions URL for *model*."""
    deployment_path = f"{self.base_url}/openai/deployments/{model}/chat/completions"
    return f"{deployment_path}?api-version={self.api_version}"
|
||||
|
||||
@@ -60,10 +66,50 @@ class AzureProvider(Provider):
|
||||
# That's the only api version that works with this deployments endpoint
|
||||
return f"{self.base_url}/openai/deployments?api-version=2023-03-15-preview"
|
||||
|
||||
def _get_resource_base_url(self) -> str:
    """Derive the Azure resource base URL (e.g. https://project.openai.azure.com) from any endpoint format."""
    # Drop trailing slashes, then strip a v1 suffix if present; legacy URLs
    # pass through unchanged.
    trimmed = self.base_url.rstrip("/")
    return trimmed.removesuffix("/openai/v1")
|
||||
|
||||
async def _get_deployments(self, api_key: str | None) -> list[dict]:
    """Fetch deployments using the legacy 2023-03-15-preview endpoint.

    Hits the resource base URL, so it works for both v1 and legacy endpoint
    configurations. Returns the raw deployment dicts (each has 'id' =
    deployment name).
    """
    endpoint = f"{self._get_resource_base_url()}/openai/deployments?api-version=2023-03-15-preview"

    request_headers = {"Content-Type": "application/json"}
    if api_key is not None:
        request_headers["api-key"] = f"{api_key}"

    try:
        # 15s total budget, 10s of which may be spent on connect.
        async with httpx.AsyncClient(timeout=httpx.Timeout(15.0, connect=10.0)) as http_client:
            response = await http_client.get(endpoint, headers=request_headers)
            response.raise_for_status()
    except httpx.TimeoutException as e:
        raise RuntimeError(f"Azure API timeout after 15s: {e}")
    except httpx.HTTPStatusError as e:
        raise RuntimeError(f"Failed to retrieve deployment list: {e}")

    return response.json().get("data", [])
|
||||
|
||||
async def azure_openai_get_deployed_model_list(self) -> list:
|
||||
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
|
||||
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
if self._is_v1_endpoint(self.base_url):
|
||||
# The v1 /models endpoint returns base model names (e.g. "gpt-5.2-chat-2025-12-11")
|
||||
# but inference calls require deployment names (e.g. "gpt-5.2-chat").
|
||||
# Query the legacy deployments endpoint to get actual deployment names.
|
||||
return await self._get_deployments(api_key)
|
||||
|
||||
# Legacy path: use Azure SDK + deployments endpoint
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url)
|
||||
|
||||
try:
|
||||
@@ -122,6 +168,37 @@ class AzureProvider(Provider):
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
model_list = await self.azure_openai_get_deployed_model_list()
|
||||
|
||||
if self._is_v1_endpoint(self.base_url):
|
||||
# v1 path: follow OpenAIProvider pattern with litellm context window lookup
|
||||
configs = []
|
||||
for model in model_list:
|
||||
model_name = model.get("id")
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
# Use capabilities if present, otherwise accept all (Azure deployments are user-curated)
|
||||
capabilities = model.get("capabilities")
|
||||
if capabilities and capabilities.get("chat_completion") is not None:
|
||||
if not capabilities.get("chat_completion"):
|
||||
continue
|
||||
|
||||
context_window_size = await self.get_model_context_window_async(model_name)
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_name,
|
||||
model_endpoint_type="azure",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
# Legacy path
|
||||
# Extract models that support text generation
|
||||
model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
|
||||
|
||||
@@ -145,6 +222,38 @@ class AzureProvider(Provider):
|
||||
return configs
|
||||
|
||||
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
|
||||
model_list = await self.azure_openai_get_deployed_model_list()
|
||||
|
||||
if self._is_v1_endpoint(self.base_url):
|
||||
# v1 path: use base URL as endpoint, filter by capabilities or name
|
||||
configs = []
|
||||
for model in model_list:
|
||||
model_name = model.get("id")
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
# Use capabilities if present, otherwise filter by name
|
||||
capabilities = model.get("capabilities")
|
||||
if capabilities and capabilities.get("embeddings") is not None:
|
||||
if not capabilities.get("embeddings"):
|
||||
continue
|
||||
elif "embedding" not in model_name:
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
embedding_model=model_name,
|
||||
embedding_endpoint_type="azure",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=768,
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=self.get_handle(model_name, is_embedding=True),
|
||||
batch_size=1024,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
# Legacy path
|
||||
def valid_embedding_model(m: dict, require_embedding_in_name: bool = True):
|
||||
valid_name = True
|
||||
if require_embedding_in_name:
|
||||
@@ -152,9 +261,7 @@ class AzureProvider(Provider):
|
||||
|
||||
return m.get("capabilities").get("embeddings") == True and valid_name
|
||||
|
||||
model_list = await self.azure_openai_get_deployed_model_list()
|
||||
# Extract models that support embeddings
|
||||
|
||||
model_options = [m for m in model_list if valid_embedding_model(m)]
|
||||
|
||||
configs = []
|
||||
@@ -179,6 +286,23 @@ class AzureProvider(Provider):
|
||||
llm_default = LLM_MAX_CONTEXT_WINDOW.get(model_name, 4096)
|
||||
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
|
||||
|
||||
async def get_model_context_window_async(self, model_name: str) -> int | None:
    """Get context window size, using litellm specs for v1 endpoints or hardcoded map for legacy."""
    if self._is_v1_endpoint(self.base_url):
        from letta.model_specs.litellm_model_specs import get_context_window

        # Litellm keys Azure models with an "azure/" prefix, so try that
        # first and fall back to the bare model name.
        for candidate in (f"azure/{model_name}", model_name):
            window = await get_context_window(candidate)
            if window is not None:
                return window
    # Legacy endpoints — and v1 lookups that missed — use the hardcoded
    # map with its built-in default.
    return self.get_model_context_window(model_name)
|
||||
|
||||
async def check_api_key(self):
|
||||
api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
if not api_key:
|
||||
|
||||
Reference in New Issue
Block a user