feat: Asyncify openai model listing (#2281)

Matthew Zhou
2025-05-20 15:24:00 -07:00
committed by GitHub
parent 50b12a24a8
commit 87dae5d6e7
3 changed files with 243 additions and 107 deletions

letta/llm_api/openai.py

@@ -1,6 +1,7 @@
import warnings
from typing import Generator, List, Optional, Union
import httpx
import requests
from openai import OpenAI
@@ -111,6 +112,62 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
        raise e


async def openai_get_model_list_async(
    url: str,
    api_key: Optional[str] = None,
    fix_url: bool = False,
    extra_params: Optional[dict] = None,
    client: Optional["httpx.AsyncClient"] = None,
) -> dict:
    """https://platform.openai.com/docs/api-reference/models/list"""
    from letta.utils import printd

    # In some cases we may want to double-check the URL and do basic correction
    if fix_url and not url.endswith("/v1"):
        url = smart_urljoin(url, "v1")

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    printd(f"Sending request to {url}")

    # Use provided client or create a new one
    close_client = False
    if client is None:
        client = httpx.AsyncClient()
        close_client = True

    try:
        response = await client.get(url, headers=headers, params=extra_params)
        response.raise_for_status()
        result = response.json()
        printd(f"response = {result}")
        return result
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        error_response = None
        try:
            error_response = http_err.response.json()
        except Exception:  # a bare `except:` would also swallow KeyboardInterrupt/SystemExit
            error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
        printd(f"Got HTTPError, exception={http_err}, response={error_response}")
        raise http_err
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
    finally:
        if close_client:
            await client.aclose()


def build_openai_chat_completions_request(
    llm_config: LLMConfig,
    messages: List[_Message],

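A quick usage sketch of the new helper (hypothetical caller code, not part of this commit; the key and endpoints are placeholders). It exercises both call styles the function supports: a one-off call where the helper owns a throwaway httpx.AsyncClient, and repeated calls against a shared client, in which case close_client stays False and the caller owns the client's lifecycle:

import asyncio

import httpx

from letta.llm_api.openai import openai_get_model_list_async


async def main() -> None:
    # One-off call: the helper creates its own AsyncClient and closes it in `finally`
    try:
        result = await openai_get_model_list_async(
            "https://api.openai.com/v1",
            api_key="sk-placeholder",  # placeholder; supply a real key via env/config
        )
        print(len(result.get("data", [])), "models")
    except httpx.HTTPStatusError as err:
        # 4XX/5XX is logged via printd inside the helper, then re-raised
        print("server error:", err.response.status_code)
    except httpx.RequestError as err:
        # DNS failure, refused connection, timeout, etc.
        print("transport error:", err)

    # Repeated calls: reuse one client to amortize connection setup
    async with httpx.AsyncClient() as client:
        for base in ("https://api.openai.com/v1", "https://openrouter.ai/api/v1"):
            result = await openai_get_model_list_async(base, client=client)
            print(base, "->", len(result.get("data", [])), "models")


if __name__ == "__main__":
    asyncio.run(main())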
letta/schemas/providers.py

@@ -47,9 +47,15 @@ class Provider(ProviderBase):
    def list_llm_models(self) -> List[LLMConfig]:
        return []

    async def list_llm_models_async(self) -> List[LLMConfig]:
        return []

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        return []

    async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
        return []

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        raise NotImplementedError
@@ -140,6 +146,19 @@ class LettaProvider(Provider):
            )
        ]

    async def list_llm_models_async(self) -> List[LLMConfig]:
        return [
            LLMConfig(
                model="letta-free",  # NOTE: renamed
                model_endpoint_type="openai",
                model_endpoint=LETTA_MODEL_ENDPOINT,
                context_window=8192,
                handle=self.get_handle("letta-free"),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
        ]

    def list_embedding_models(self):
        return [
            EmbeddingConfig(
@@ -189,9 +208,40 @@ class OpenAIProvider(Provider):
        return data

    async def _get_models_async(self) -> List[dict]:
        from letta.llm_api.openai import openai_get_model_list_async

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None

        # Similar hardcoded support for Nebius (request verbose metadata such as context_length);
        # guarded with `if` so it does not clobber the OpenRouter params set above
        if "nebius.com" in self.base_url:
            extra_params = {"verbose": True}

        response = await openai_get_model_list_async(
            self.base_url,
            api_key=self.api_key,
            extra_params=extra_params,
            # fix_url=True,  # NOTE: make sure together ends with /v1
        )

        if "data" in response:
            data = response["data"]
        else:
            # TogetherAI's response is missing the 'data' field
            data = response

        return data

    def list_llm_models(self) -> List[LLMConfig]:
        data = self._get_models()
        return self._list_llm_models(data)

    async def list_llm_models_async(self) -> List[LLMConfig]:
        data = await self._get_models_async()
        return self._list_llm_models(data)

    def _list_llm_models(self, data) -> List[LLMConfig]:
        configs = []
        for model in data:
            assert "id" in model, f"OpenAI model missing 'id' field: {model}"
@@ -279,7 +329,6 @@ class OpenAIProvider(Provider):
        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        if self.base_url == "https://api.openai.com/v1":
            # TODO: actually automatically list models for OpenAI
            return [
@@ -312,55 +361,92 @@ class OpenAIProvider(Provider):
        else:
            # Actually attempt to list
            data = self._get_models()
            return self._list_embedding_models(data)

    async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
        if self.base_url == "https://api.openai.com/v1":
            # TODO: actually automatically list models for OpenAI
            return [
                EmbeddingConfig(
                    embedding_model="text-embedding-ada-002",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=1536,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
                ),
                EmbeddingConfig(
                    embedding_model="text-embedding-3-small",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=2000,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-3-small", is_embedding=True),
                ),
                EmbeddingConfig(
                    embedding_model="text-embedding-3-large",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=2000,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-3-large", is_embedding=True),
                ),
            ]
        else:
            # Actually attempt to list
            data = await self._get_models_async()
            return self._list_embedding_models(data)

    def _list_embedding_models(self, data) -> List[EmbeddingConfig]:
        configs = []
        for model in data:
            assert "id" in model, f"Model missing 'id' field: {model}"
            model_name = model["id"]

            if "nebius.com" in self.base_url:
                # Nebius includes the type, which we can use to filter for embedding models
                try:
                    model_type = model["architecture"]["modality"]
                    if model_type not in ["text->embedding"]:
                        # print(f"Skipping model w/ modality {model_type}:\n{model}")
                        continue
                except KeyError:
                    print(f"Couldn't access architecture type field, skipping model:\n{model}")
                    continue
            elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
                # TogetherAI includes the type, which we can use to filter for embedding models
                if "type" in model and model["type"] not in ["embedding"]:
                    continue
            else:
                # For other providers we should skip by default, since we don't want to assume embeddings are supported
                continue

            if "context_length" in model:
                # Context length is returned in Nebius as "context_length"
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            # We need the context length for embeddings too
            if not context_window_size:
                continue

            configs.append(
                EmbeddingConfig(
                    embedding_model=model_name,
                    embedding_endpoint_type=self.provider_type,
                    embedding_endpoint=self.base_url,
                    embedding_dim=context_window_size,
                    embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
                    handle=self.get_handle(model_name, is_embedding=True),
                )
            )
        return configs
    def get_model_context_window_size(self, model_name: str):
        if model_name in LLM_MAX_TOKENS:

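The payoff of the new *_async variants is that provider listings can be awaited concurrently rather than blocking one request at a time. A minimal sketch (hypothetical caller code, assuming OpenAIProvider and model_settings are configured as in the tests below; the list can hold any Provider subclass):

import asyncio

from letta.schemas.providers import OpenAIProvider
from letta.settings import model_settings


async def list_models_concurrently() -> None:
    providers = [
        OpenAIProvider(
            name="openai",
            api_key=model_settings.openai_api_key,
            base_url=model_settings.openai_api_base,
        ),
        # ...more Provider subclasses could be appended here
    ]
    # gather() runs the HTTP round-trips concurrently instead of serially
    results = await asyncio.gather(*(p.list_llm_models_async() for p in providers))
    for provider, models in zip(providers, results):
        print(provider.name, [m.handle for m in models])


if __name__ == "__main__":
    asyncio.run(list_models_concurrently())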
tests/test_providers.py

@@ -1,15 +1,12 @@
import os

import pytest

from letta.schemas.providers import (
    AnthropicBedrockProvider,
    AnthropicProvider,
    AzureProvider,
    DeepSeekProvider,
    GoogleAIProvider,
    GoogleVertexProvider,
    GroqProvider,
    MistralProvider,
    OllamaProvider,
    OpenAIProvider,
    TogetherProvider,
)
@@ -17,11 +14,9 @@ from letta.settings import model_settings
def test_openai():
    provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )
    models = provider.list_llm_models()
@@ -33,24 +28,33 @@ def test_openai():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


@pytest.mark.asyncio
async def test_openai_async():
    provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = await provider.list_embedding_models_async()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_deepseek():
    provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_anthropic():
    provider = AnthropicProvider(
        name="anthropic",
        api_key=model_settings.anthropic_api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@@ -60,7 +64,7 @@ def test_anthropic():
def test_groq():
    provider = GroqProvider(
        name="groq",
        api_key=model_settings.groq_api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@@ -70,8 +74,9 @@ def test_groq():
def test_azure():
    provider = AzureProvider(
        name="azure",
        api_key=model_settings.azure_api_key,
        base_url=model_settings.azure_base_url,
        api_version=model_settings.azure_api_version,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@@ -82,26 +87,24 @@ def test_azure():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


# def test_ollama():
#     provider = OllamaProvider(
#         name="ollama",
#         base_url=model_settings.ollama_base_url,
#         api_key=None,
#         default_prompt_formatter=model_settings.default_prompt_formatter,
#     )
#     models = provider.list_llm_models()
#     assert len(models) > 0
#     assert models[0].handle == f"{provider.name}/{models[0].model}"
#
#     embedding_models = provider.list_embedding_models()
#     assert len(embedding_models) > 0
#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_googleai():
    api_key = model_settings.gemini_api_key
    assert api_key is not None
    provider = GoogleAIProvider(
        name="google_ai",
@@ -119,8 +122,8 @@ def test_googleai():
def test_google_vertex():
    provider = GoogleVertexProvider(
        name="google_vertex",
        google_cloud_project=model_settings.google_cloud_project,
        google_cloud_location=model_settings.google_cloud_location,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@@ -131,50 +134,40 @@ def test_google_vertex():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_mistral():
    provider = MistralProvider(
        name="mistral",
        api_key=os.getenv("MISTRAL_API_KEY"),
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_together():
    provider = TogetherProvider(
        name="together",
        api_key=model_settings.together_api_key,
        default_prompt_formatter=model_settings.default_prompt_formatter,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    # TODO: We don't have embedding models on together for CI
    # embedding_models = provider.list_embedding_models()
    # assert len(embedding_models) > 0
    # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
# def test_anthropic_bedrock():
#     from letta.settings import model_settings
#
#     provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
#     models = provider.list_llm_models()
#     assert len(models) > 0
#     assert models[0].handle == f"{provider.name}/{models[0].model}"
#
#     embedding_models = provider.list_embedding_models()
#     assert len(embedding_models) > 0
#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_custom_anthropic():
    provider = AnthropicProvider(
        name="custom_anthropic",
        api_key=model_settings.anthropic_api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
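One note on running the new test: @pytest.mark.asyncio is provided by the pytest-asyncio plugin, so the suite needs that plugin installed and enabled (an assumption about the project's test tooling, since no dependency change appears in this diff). For a quick local check without the plugin, the same assertions can be driven with asyncio.run, as in this sketch:

import asyncio

from letta.schemas.providers import OpenAIProvider
from letta.settings import model_settings


def test_openai_async_no_plugin():
    # Same checks as test_openai_async, but driving the coroutine directly
    # instead of relying on the pytest-asyncio event loop.
    provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )
    models = asyncio.run(provider.list_llm_models_async())
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"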