From 1eb692f62a9dac95c573423f9553d42716d12917 Mon Sep 17 00:00:00 2001
From: cthomas
Date: Tue, 12 Aug 2025 14:43:03 -0700
Subject: [PATCH] feat: add azure llm client (#3882)

---
 letta/llm_api/azure_client.py | 52 +++++++++++++++++++++++++++++++++++
 letta/llm_api/llm_client.py   |  7 +++++
 2 files changed, 59 insertions(+)
 create mode 100644 letta/llm_api/azure_client.py

diff --git a/letta/llm_api/azure_client.py b/letta/llm_api/azure_client.py
new file mode 100644
index 00000000..95468896
--- /dev/null
+++ b/letta/llm_api/azure_client.py
@@ -0,0 +1,52 @@
+import os
+from typing import List
+
+from openai import AsyncAzureOpenAI, AzureOpenAI
+from openai.types.chat.chat_completion import ChatCompletion
+
+from letta.llm_api.openai_client import OpenAIClient
+from letta.otel.tracing import trace_method
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.llm_config import LLMConfig
+from letta.settings import model_settings
+
+
+class AzureClient(OpenAIClient):
+
+    @trace_method
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
+        """
+        Performs underlying synchronous request to the Azure OpenAI API and returns raw response dict.
+        """
+        api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
+        base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
+        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+
+        response: ChatCompletion = client.chat.completions.create(**request_data)
+        return response.model_dump()
+
+    @trace_method
+    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
+        """
+        Performs underlying asynchronous request to the Azure OpenAI API and returns raw response dict.
+        """
+        api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
+        base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
+        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+
+        response: ChatCompletion = await client.chat.completions.create(**request_data)
+        return response.model_dump()
+
+    @trace_method
+    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
+        """Request embeddings for the given texts using the provided embedding config."""
+        api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
+        base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
+        api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)
+        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
+
+        # TODO: add total usage
+        return [r.embedding for r in response.data]
diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py
index b047ec85..d5686eed 100644
--- a/letta/llm_api/llm_client.py
+++ b/letta/llm_api/llm_client.py
@@ -72,5 +72,12 @@ class LLMClient:
                     put_inner_thoughts_first=put_inner_thoughts_first,
                     actor=actor,
                 )
+            case ProviderType.azure:
+                from letta.llm_api.azure_client import AzureClient
+
+                return AzureClient(
+                    put_inner_thoughts_first=put_inner_thoughts_first,
+                    actor=actor,
+                )
             case _:
                 return None