diff --git a/letta/llm_api/azure_client.py b/letta/llm_api/azure_client.py
index 3700c2a0..80926aec 100644
--- a/letta/llm_api/azure_client.py
+++ b/letta/llm_api/azure_client.py
@@ -1,19 +1,31 @@
+import json
 import os
 from typing import List, Optional, Tuple
 
-from openai import AsyncAzureOpenAI, AzureOpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
 from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.responses.response_stream_event import ResponseStreamEvent
 
 from letta.helpers.json_helpers import sanitize_unicode_surrogates
 from letta.llm_api.openai_client import OpenAIClient
+from letta.log import get_logger
 from letta.otel.tracing import trace_method
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import ProviderCategory
 from letta.schemas.llm_config import LLMConfig
 from letta.settings import model_settings
 
+logger = get_logger(__name__)
+
 
 class AzureClient(OpenAIClient):
+    @staticmethod
+    def _is_v1_endpoint(base_url: str) -> bool:
+        if not base_url:
+            return False
+        return base_url.rstrip("/").endswith("/openai/v1")
+
     def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         if llm_config.provider_category == ProviderCategory.byok:
             from letta.services.provider_manager import ProviderManager
@@ -30,20 +42,36 @@ class AzureClient(OpenAIClient):
 
         return None, None, None
 
+    def _resolve_credentials(self, api_key, base_url, api_version):
+        """Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
+        if not api_key:
+            api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
+        if not base_url:
+            base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
+        if not api_version and not self._is_v1_endpoint(base_url):
+            api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        return api_key, base_url, api_version
+
     @trace_method
     def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying synchronous request to OpenAI API and returns raw response dict.
""" api_key, base_url, api_version = self.get_byok_overrides(llm_config) - if not api_key or not base_url or not api_version: - api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY") - base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL") - api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION") + api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version) - client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) - response: ChatCompletion = client.chat.completions.create(**request_data) - return response.model_dump() + if self._is_v1_endpoint(base_url): + client = OpenAI(api_key=api_key, base_url=base_url) + else: + client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) + + # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages' + if "input" in request_data and "messages" not in request_data: + resp = client.responses.create(**request_data) + return resp.model_dump() + else: + response: ChatCompletion = client.chat.completions.create(**request_data) + return response.model_dump() @trace_method async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: @@ -53,17 +81,60 @@ class AzureClient(OpenAIClient): request_data = sanitize_unicode_surrogates(request_data) api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config) - if not api_key or not base_url or not api_version: - api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY") - base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL") - api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION") + api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version) + try: - client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) - response: ChatCompletion = await client.chat.completions.create(**request_data) + if self._is_v1_endpoint(base_url): + client = AsyncOpenAI(api_key=api_key, base_url=base_url) + else: + client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) + + # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages' + if "input" in request_data and "messages" not in request_data: + resp = await client.responses.create(**request_data) + return resp.model_dump() + else: + response: ChatCompletion = await client.chat.completions.create(**request_data) + return response.model_dump() except Exception as e: raise self.handle_llm_error(e) - return response.model_dump() + @trace_method + async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]: + """ + Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator. 
+        """
+        request_data = sanitize_unicode_surrogates(request_data)
+
+        api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
+        api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
+
+        if self._is_v1_endpoint(base_url):
+            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        else:
+            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+
+        # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
+        if "input" in request_data and "messages" not in request_data:
+            try:
+                response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
+                    **request_data,
+                    stream=True,
+                )
+            except Exception as e:
+                logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
+                raise e
+        else:
+            try:
+                response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+                    **request_data,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                )
+            except Exception as e:
+                logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
+                raise e
+        return response_stream
 
     @trace_method
     async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
@@ -71,7 +142,12 @@ class AzureClient(OpenAIClient):
         api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
         base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
         api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
-        client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
+
+        if self._is_v1_endpoint(base_url):
+            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        else:
+            client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
+
         response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
 
         # TODO: add total usage
diff --git a/letta/schemas/providers/azure.py b/letta/schemas/providers/azure.py
index cc552dea..11fa2452 100644
--- a/letta/schemas/providers/azure.py
+++ b/letta/schemas/providers/azure.py
@@ -46,6 +46,12 @@ class AzureProvider(Provider):
     def replace_none_with_default(cls, v):
         return v if v is not None else cls.LATEST_API_VERSION
 
+    @staticmethod
+    def _is_v1_endpoint(base_url: str) -> bool:
+        if not base_url:
+            return False
+        return base_url.rstrip("/").endswith("/openai/v1")
+
     def get_azure_chat_completions_endpoint(self, model: str):
         return f"{self.base_url}/openai/deployments/{model}/chat/completions?api-version={self.api_version}"
 
@@ -60,10 +66,50 @@ class AzureProvider(Provider):
         # That's the only api version that works with this deployments endpoint
         return f"{self.base_url}/openai/deployments?api-version=2023-03-15-preview"
 
+    def _get_resource_base_url(self) -> str:
+        """Derive the Azure resource base URL (e.g. https://project.openai.azure.com) from any endpoint format."""
+        url = self.base_url.rstrip("/")
+        if url.endswith("/openai/v1"):
+            return url[: -len("/openai/v1")]
+        return url
+
+    async def _get_deployments(self, api_key: str | None) -> list[dict]:
+        """Fetch deployments using the legacy 2023-03-15-preview endpoint.
+
+        Works for both v1 and legacy endpoints since it hits the resource base URL.
+        Returns the raw deployment dicts (each has 'id' = deployment name).
+        """
+        resource_base = self._get_resource_base_url()
+        url = f"{resource_base}/openai/deployments?api-version=2023-03-15-preview"
+
+        headers = {"Content-Type": "application/json"}
+        if api_key is not None:
+            headers["api-key"] = f"{api_key}"
+
+        try:
+            timeout = httpx.Timeout(15.0, connect=10.0)
+            async with httpx.AsyncClient(timeout=timeout) as http_client:
+                response = await http_client.get(url, headers=headers)
+                response.raise_for_status()
+        except httpx.TimeoutException as e:
+            raise RuntimeError(f"Azure API timeout after 15s: {e}")
+        except httpx.HTTPStatusError as e:
+            raise RuntimeError(f"Failed to retrieve deployment list: {e}")
+
+        return response.json().get("data", [])
+
     async def azure_openai_get_deployed_model_list(self) -> list:
         """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
 
         api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
+
+        if self._is_v1_endpoint(self.base_url):
+            # The v1 /models endpoint returns base model names (e.g. "gpt-5.2-chat-2025-12-11")
+            # but inference calls require deployment names (e.g. "gpt-5.2-chat").
+            # Query the legacy deployments endpoint to get actual deployment names.
+            return await self._get_deployments(api_key)
+
+        # Legacy path: use Azure SDK + deployments endpoint
         client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url)
 
         try:
@@ -122,6 +168,37 @@ class AzureProvider(Provider):
 
     async def list_llm_models_async(self) -> list[LLMConfig]:
         model_list = await self.azure_openai_get_deployed_model_list()
+
+        if self._is_v1_endpoint(self.base_url):
+            # v1 path: follow OpenAIProvider pattern with litellm context window lookup
+            configs = []
+            for model in model_list:
+                model_name = model.get("id")
+                if not model_name:
+                    continue
+
+                # Use capabilities if present, otherwise accept all (Azure deployments are user-curated)
+                capabilities = model.get("capabilities")
+                if capabilities and capabilities.get("chat_completion") is not None:
+                    if not capabilities.get("chat_completion"):
+                        continue
+
+                context_window_size = await self.get_model_context_window_async(model_name)
+                configs.append(
+                    LLMConfig(
+                        model=model_name,
+                        model_endpoint_type="azure",
+                        model_endpoint=self.base_url,
+                        context_window=context_window_size,
+                        handle=self.get_handle(model_name),
+                        max_tokens=self.get_default_max_output_tokens(model_name),
+                        provider_name=self.name,
+                        provider_category=self.provider_category,
+                    )
+                )
+            return configs
+
+        # Legacy path
         # Extract models that support text generation
         model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
 
@@ -145,6 +222,38 @@ class AzureProvider(Provider):
         return configs
 
     async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
+        model_list = await self.azure_openai_get_deployed_model_list()
+
+        if self._is_v1_endpoint(self.base_url):
+            # v1 path: use base URL as endpoint, filter by capabilities or name
+            configs = []
+            for model in model_list:
+                model_name = model.get("id")
+                if not model_name:
+                    continue
+
+                # Use capabilities if present, otherwise filter by name
+                capabilities = model.get("capabilities")
+                if capabilities and capabilities.get("embeddings") is not None:
+                    if not capabilities.get("embeddings"):
+                        continue
+                elif "embedding" not in model_name:
+                    continue
+
+                configs.append(
+                    EmbeddingConfig(
+                        embedding_model=model_name,
+                        embedding_endpoint_type="azure",
+                        embedding_endpoint=self.base_url,
+                        embedding_dim=768,
+                        embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
+                        handle=self.get_handle(model_name, is_embedding=True),
+                        batch_size=1024,
+                    )
+                )
+            return configs
+
+        # Legacy path
         def valid_embedding_model(m: dict, require_embedding_in_name: bool = True):
             valid_name = True
             if require_embedding_in_name:
@@ -152,9 +261,7 @@ class AzureProvider(Provider):
 
             return m.get("capabilities").get("embeddings") == True and valid_name
 
-        model_list = await self.azure_openai_get_deployed_model_list()
         # Extract models that support embeddings
-
         model_options = [m for m in model_list if valid_embedding_model(m)]
 
         configs = []
@@ -179,6 +286,23 @@ class AzureProvider(Provider):
         llm_default = LLM_MAX_CONTEXT_WINDOW.get(model_name, 4096)
         return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default)
 
+    async def get_model_context_window_async(self, model_name: str) -> int | None:
+        """Get context window size, using litellm specs for v1 endpoints or hardcoded map for legacy."""
+        if self._is_v1_endpoint(self.base_url):
+            from letta.model_specs.litellm_model_specs import get_context_window
+
+            # Litellm keys Azure models with an "azure/" prefix
+            context_window = await get_context_window(f"azure/{model_name}")
+            if context_window is not None:
+                return context_window
+            # Try without prefix as fallback
+            context_window = await get_context_window(model_name)
+            if context_window is not None:
+                return context_window
+            # Fall back to hardcoded map, then default
+            return self.get_model_context_window(model_name)
+        return self.get_model_context_window(model_name)
+
     async def check_api_key(self):
         api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
         if not api_key: