feat: Add listing llm models and embedding models for Azure endpoint (#1846)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.constants import LLM_MAX_TOKENS
|
||||
from letta.llm_api.azure_openai import (
|
||||
get_azure_chat_completions_endpoint,
|
||||
get_azure_embeddings_endpoint,
|
||||
)
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
@@ -274,10 +279,64 @@ class GoogleAIProvider(Provider):
|
||||
|
||||
class AzureProvider(Provider):
|
||||
name: str = "azure"
|
||||
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
|
||||
base_url: str = Field(
|
||||
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the Azure API.")
|
||||
api_version: str = Field(latest_api_version, description="API version for the Azure API")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def set_default_api_version(cls, values):
|
||||
"""
|
||||
This ensures that api_version is always set to the default if None is passed in.
|
||||
"""
|
||||
if values.get("api_version") is None:
|
||||
values["api_version"] = cls.model_fields["latest_api_version"].default
|
||||
return values
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.azure_openai import (
|
||||
azure_openai_get_chat_completion_model_list,
|
||||
)
|
||||
|
||||
model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
|
||||
configs = []
|
||||
for model_option in model_options:
|
||||
model_name = model_option["id"]
|
||||
context_window_size = self.get_model_context_window(model_name)
|
||||
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
|
||||
configs.append(
|
||||
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
|
||||
)
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
|
||||
|
||||
model_options = azure_openai_get_embeddings_model_list(
|
||||
self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
|
||||
)
|
||||
configs = []
|
||||
for model_option in model_options:
|
||||
model_name = model_option["id"]
|
||||
model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
embedding_model=model_name,
|
||||
embedding_endpoint_type="azure",
|
||||
embedding_endpoint=model_endpoint,
|
||||
embedding_dim=768,
|
||||
embedding_chunk_size=300, # NOTE: max is 2048
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
def get_model_context_window(self, model_name: str):
|
||||
"""
|
||||
This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
|
||||
"""
|
||||
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
|
||||
|
||||
|
||||
class VLLMProvider(OpenAIProvider):
|
||||
|
||||
Reference in New Issue
Block a user