feat: Add listing llm models and embedding models for Azure endpoint (#1846)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-08 15:14:55 -07:00
committed by GitHub
parent dd51b15154
commit 446b8f2154
10 changed files with 129 additions and 61 deletions

View File

@@ -1,8 +1,13 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from letta.constants import LLM_MAX_TOKENS
from letta.llm_api.azure_openai import (
get_azure_chat_completions_endpoint,
get_azure_embeddings_endpoint,
)
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
@@ -274,10 +279,64 @@ class GoogleAIProvider(Provider):
class AzureProvider(Provider):
name: str = "azure"
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
base_url: str = Field(
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
)
api_key: str = Field(..., description="API key for the Azure API.")
api_version: str = Field(latest_api_version, description="API version for the Azure API")
@model_validator(mode="before")
def set_default_api_version(cls, values):
"""
This ensures that api_version is always set to the default if None is passed in.
"""
if values.get("api_version") is None:
values["api_version"] = cls.model_fields["latest_api_version"].default
return values
def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.azure_openai import (
azure_openai_get_chat_completion_model_list,
)
model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
configs = []
for model_option in model_options:
model_name = model_option["id"]
context_window_size = self.get_model_context_window(model_name)
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
configs.append(
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
)
return configs
def list_embedding_models(self) -> List[EmbeddingConfig]:
from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
model_options = azure_openai_get_embeddings_model_list(
self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
)
configs = []
for model_option in model_options:
model_name = model_option["id"]
model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
configs.append(
EmbeddingConfig(
embedding_model=model_name,
embedding_endpoint_type="azure",
embedding_endpoint=model_endpoint,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
)
)
return configs
def get_model_context_window(self, model_name: str):
"""
This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
"""
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
class VLLMProvider(OpenAIProvider):