Files
letta-server/letta/llm_api/azure_client.py
Kian Jones fecf6decfb chore: migrate to ruff (#4305)
* base requirements

* autofix

* Configure ruff for Python linting and formatting

- Set up minimal ruff configuration with basic checks (E, W, F, I)
- Add temporary ignores for common issues during migration
- Configure pre-commit hooks to use ruff with pass_filenames
- This enables gradual migration from black to ruff

* Delete sdj

* autofixed only

* migrate lint action

* more autofixed

* more fixes

* change precommit

* try changing the hook

* try this stuff
2025-08-29 11:11:19 -07:00

73 lines
3.7 KiB
Python

import os
from typing import List, Optional, Tuple
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings
class AzureClient(OpenAIClient):
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
return ProviderManager().get_azure_credentials(llm_config.provider_name, actor=self.actor)
return None, None, None
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
return await ProviderManager().get_azure_credentials_async(llm_config.provider_name, actor=self.actor)
return None, None, None
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
"""
api_key, base_url, api_version = self.get_byok_overrides(llm_config)
if not api_key or not base_url or not api_version:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
api_key, base_url, api_version = await self.get_byok_overrides_async(llm_config)
if not api_key or not base_url or not api_version:
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
"""Request embeddings given texts and embedding config"""
api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
# TODO: add total usage
return [r.embedding for r in response.data]