feat: add new azure api maintaining backward compat (#9387)
* feat: add new azure provider type * fix context window
This commit is contained in:
@@ -1,19 +1,31 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream, AzureOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
|
||||
from letta.helpers.json_helpers import sanitize_unicode_surrogates
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AzureClient(OpenAIClient):
|
||||
@staticmethod
|
||||
def _is_v1_endpoint(base_url: str) -> bool:
|
||||
if not base_url:
|
||||
return False
|
||||
return base_url.rstrip("/").endswith("/openai/v1")
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
@@ -30,20 +42,36 @@ class AzureClient(OpenAIClient):
|
||||
|
||||
return None, None, None
|
||||
|
||||
def _resolve_credentials(self, api_key, base_url, api_version):
|
||||
"""Resolve credentials, falling back to env vars. For v1 endpoints, api_version is not required."""
|
||||
if not api_key:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
if not base_url:
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
if not api_version and not self._is_v1_endpoint(base_url):
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
return api_key, base_url, api_version
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying synchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key, base_url, api_version = self.get_byok_overrides(llm_config)
|
||||
if not api_key or not base_url or not api_version:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
|
||||
|
||||
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
@@ -53,17 +81,60 @@ class AzureClient(OpenAIClient):
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
|
||||
if not api_key or not base_url or not api_version:
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
|
||||
|
||||
try:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
resp = await client.responses.create(**request_data)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
|
||||
return response.model_dump()
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk | ResponseStreamEvent]:
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to Azure/OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
request_data = sanitize_unicode_surrogates(request_data)
|
||||
|
||||
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
|
||||
api_key, base_url, api_version = self._resolve_credentials(api_key, base_url, api_version)
|
||||
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
try:
|
||||
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Azure Responses request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
else:
|
||||
try:
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Azure Chat Completions request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
@@ -71,7 +142,12 @@ class AzureClient(OpenAIClient):
|
||||
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
|
||||
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
|
||||
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
|
||||
|
||||
if self._is_v1_endpoint(base_url):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
else:
|
||||
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
|
||||
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
# TODO: add total usage
|
||||
|
||||
Reference in New Issue
Block a user