feat: Asyncify openai model listing (#2281)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import warnings
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -111,6 +112,62 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
|
||||
raise e
|
||||
|
||||
|
||||
async def openai_get_model_list_async(
    url: str,
    api_key: Optional[str] = None,
    fix_url: bool = False,
    extra_params: Optional[dict] = None,
    client: Optional["httpx.AsyncClient"] = None,
) -> dict:
    """Asynchronously list the models served by an OpenAI-compatible endpoint.

    See https://platform.openai.com/docs/api-reference/models/list

    Args:
        url: Base URL of the API; "models" is joined onto it before the request.
        api_key: When not None, sent as an ``Authorization: Bearer`` header.
        fix_url: When True and ``url`` does not end with "/v1", "v1" is joined
            onto ``url`` first.
        extra_params: Optional query parameters forwarded with the GET request.
        client: Optional ``httpx.AsyncClient`` to reuse. When omitted, a
            temporary client is created and closed before this function exits.

    Returns:
        The decoded JSON body of the response.
        NOTE(review): annotated as ``dict``, but callers also handle a body
        without a "data" key - confirm the shape per provider.

    Raises:
        httpx.HTTPStatusError: The server answered with a 4XX/5XX status.
        httpx.RequestError: The request itself failed (e.g. connection error).
        Exception: Any other error is logged via ``printd`` and re-raised.
    """
    from letta.utils import printd

    # In some cases we may want to double-check the URL and do basic correction
    if fix_url and not url.endswith("/v1"):
        url = smart_urljoin(url, "v1")

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    printd(f"Sending request to {url}")

    # Only a client created here is closed here; a caller-supplied client is left open.
    close_client = client is None
    if close_client:
        client = httpx.AsyncClient()

    try:
        response = await client.get(url, headers=headers, params=extra_params)
        response.raise_for_status()
        result = response.json()
        printd(f"response = {result}")
        return result
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        try:
            error_response = http_err.response.json()
        except Exception:
            # The error body is not valid JSON; fall back to status code + raw text.
            # `Exception` (not a bare except) so interpreter-level signals still propagate.
            error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
        printd(f"Got HTTPError, exception={http_err}, response={error_response}")
        raise
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise
    finally:
        if close_client:
            await client.aclose()
|
||||
|
||||
|
||||
def build_openai_chat_completions_request(
|
||||
llm_config: LLMConfig,
|
||||
messages: List[_Message],
|
||||
|
||||
@@ -47,9 +47,15 @@ class Provider(ProviderBase):
|
||||
def list_llm_models(self) -> List[LLMConfig]:
    """Return the LLM configs this provider offers.

    The base implementation offers none and returns an empty list.
    """
    no_models: List[LLMConfig] = []
    return no_models
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
    """Async counterpart of ``list_llm_models``.

    The base implementation offers no models and returns an empty list.
    """
    no_models: List[LLMConfig] = []
    return no_models
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
    """Return the embedding configs this provider offers.

    The base implementation offers none and returns an empty list.
    """
    no_models: List[EmbeddingConfig] = []
    return no_models
|
||||
|
||||
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
    """Async counterpart of ``list_embedding_models``.

    The base implementation offers no models and returns an empty list.
    """
    no_models: List[EmbeddingConfig] = []
    return no_models
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> Optional[int]:
    """Look up the context window for ``model_name``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
||||
|
||||
@@ -140,6 +146,19 @@ class LettaProvider(Provider):
|
||||
)
|
||||
]
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
    """Return a one-element list holding the "letta-free" LLM config."""
    free_model_name = "letta-free"  # NOTE: renamed
    free_config = LLMConfig(
        model=free_model_name,
        model_endpoint_type="openai",
        model_endpoint=LETTA_MODEL_ENDPOINT,
        context_window=8192,
        handle=self.get_handle(free_model_name),
        provider_name=self.name,
        provider_category=self.provider_category,
    )
    return [free_config]
|
||||
|
||||
def list_embedding_models(self):
|
||||
return [
|
||||
EmbeddingConfig(
|
||||
@@ -189,9 +208,40 @@ class OpenAIProvider(Provider):
|
||||
|
||||
return data
|
||||
|
||||
async def _get_models_async(self) -> List[dict]:
    """Fetch the raw model listing from ``self.base_url`` asynchronously.

    Provider-specific query parameters are chosen from the base URL:
    OpenRouter is asked only for models that support tool calling, and
    Nebius is asked for its verbose listing.

    Returns:
        The value under the response's "data" key when present, otherwise
        the response itself (TogetherAI omits the "data" field).
    """
    from letta.llm_api.openai import openai_get_model_list_async

    # The branches must be mutually exclusive: assigning the Nebius value
    # unconditionally after the OpenRouter one would reset extra_params to
    # None for OpenRouter and silently drop its tool-calling filter.
    if "openrouter.ai" in self.base_url:
        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"}
    elif "nebius.com" in self.base_url:
        # Similar to Nebius
        extra_params = {"verbose": True}
    else:
        extra_params = None

    response = await openai_get_model_list_async(
        self.base_url,
        api_key=self.api_key,
        extra_params=extra_params,
        # fix_url=True,  # NOTE: make sure together ends with /v1
    )

    if "data" in response:
        data = response["data"]
    else:
        # TogetherAI's response is missing the 'data' field
        data = response

    return data
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
    """Fetch the raw model listing synchronously and convert it to LLM configs."""
    return self._list_llm_models(self._get_models())
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
    """Fetch the raw model listing asynchronously and convert it to LLM configs."""
    raw_models = await self._get_models_async()
    llm_configs = self._list_llm_models(raw_models)
    return llm_configs
|
||||
|
||||
def _list_llm_models(self, data) -> List[LLMConfig]:
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
|
||||
@@ -279,7 +329,6 @@ class OpenAIProvider(Provider):
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
|
||||
if self.base_url == "https://api.openai.com/v1":
|
||||
# TODO: actually automatically list models for OpenAI
|
||||
return [
|
||||
@@ -312,55 +361,92 @@ class OpenAIProvider(Provider):
|
||||
else:
|
||||
# Actually attempt to list
|
||||
data = self._get_models()
|
||||
return self._list_embedding_models(data)
|
||||
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"Model missing 'id' field: {model}"
|
||||
model_name = model["id"]
|
||||
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
    """Return the embedding configs available from this provider (async).

    For the official OpenAI base URL a fixed list of three embedding models
    is returned. For any other base URL the model listing is fetched and
    filtered by ``_list_embedding_models``.
    """
    if self.base_url != "https://api.openai.com/v1":
        # Actually attempt to list
        raw_models = await self._get_models_async()
        return self._list_embedding_models(raw_models)

    # TODO: actually automatically list models for OpenAI
    # NOTE(review): both text-embedding-3 models use embedding_dim=2000 - confirm this is intended.
    known_models = (
        ("text-embedding-ada-002", 1536),
        ("text-embedding-3-small", 2000),
        ("text-embedding-3-large", 2000),
    )
    return [
        EmbeddingConfig(
            embedding_model=name,
            embedding_endpoint_type="openai",
            embedding_endpoint=self.base_url,
            embedding_dim=dim,
            embedding_chunk_size=300,
            handle=self.get_handle(name, is_embedding=True),
        )
        for name, dim in known_models
    ]
|
||||
|
||||
# We need the context length for embeddings too
|
||||
if not context_window_size:
|
||||
continue
|
||||
def _list_embedding_models(self, data) -> List[EmbeddingConfig]:
    """Convert raw model-listing entries into embedding configs.

    An entry is skipped when no context length can be determined for it,
    or when the provider-specific filter does not identify it as an
    embedding model. Providers other than Nebius and TogetherAI are
    skipped entirely, since embedding support cannot be assumed.

    Args:
        data: Iterable of model entries, each containing at least an "id" key.

    Returns:
        One ``EmbeddingConfig`` per entry that passed the filters.
    """
    configs = []
    for model in data:
        assert "id" in model, f"Model missing 'id' field: {model}"
        model_name = model["id"]

        if "context_length" in model:
            # Context length is returned in Nebius as "context_length"
            context_window_size = model["context_length"]
        else:
            context_window_size = self.get_model_context_window_size(model_name)

        # We need the context length for embeddings too
        if not context_window_size:
            continue

        if "nebius.com" in self.base_url:
            # Nebius includes the type, which we can use to filter for embedding models
            try:
                model_type = model["architecture"]["modality"]
                if model_type not in ["text->embedding"]:
                    # print(f"Skipping model w/ modality {model_type}:\n{model}")
                    continue
            except KeyError:
                print(f"Couldn't access architecture type field, skipping model:\n{model}")
                continue

        elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
            # TogetherAI includes the type, which we can use to filter for embedding models
            if "type" in model and model["type"] not in ["embedding"]:
                # print(f"Skipping model w/ modality {model_type}:\n{model}")
                continue

        else:
            # For other providers we should skip by default, since we don't want to assume embeddings are supported
            continue

        configs.append(
            EmbeddingConfig(
                embedding_model=model_name,
                embedding_endpoint_type=self.provider_type,
                embedding_endpoint=self.base_url,
                # NOTE(review): the context length is used as the embedding dimension - confirm this is intended.
                embedding_dim=context_window_size,
                embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
                # The handle is built from the model name, not the whole entry dict,
                # so it matches "<provider>/<embedding_model>".
                handle=self.get_handle(model_name, is_embedding=True),
            )
        )

    return configs
|
||||
|
||||
def get_model_context_window_size(self, model_name: str):
|
||||
if model_name in LLM_MAX_TOKENS:
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from letta.schemas.providers import (
|
||||
AnthropicBedrockProvider,
|
||||
AnthropicProvider,
|
||||
AzureProvider,
|
||||
DeepSeekProvider,
|
||||
GoogleAIProvider,
|
||||
GoogleVertexProvider,
|
||||
GroqProvider,
|
||||
MistralProvider,
|
||||
OllamaProvider,
|
||||
OpenAIProvider,
|
||||
TogetherProvider,
|
||||
)
|
||||
@@ -17,11 +14,9 @@ from letta.settings import model_settings
|
||||
|
||||
|
||||
def test_openai():
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=api_key,
|
||||
api_key=model_settings.openai_api_key,
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
@@ -33,24 +28,33 @@ def test_openai():
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_deepseek():
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = DeepSeekProvider(
|
||||
name="deepseek",
|
||||
api_key=api_key,
|
||||
@pytest.mark.asyncio
async def test_openai_async():
    """The async listing calls return models whose handle is "<provider name>/<model name>"."""
    provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )

    llm_configs = await provider.list_llm_models_async()
    assert len(llm_configs) > 0
    first_llm = llm_configs[0]
    assert first_llm.handle == f"{provider.name}/{first_llm.model}"

    embedding_configs = await provider.list_embedding_models_async()
    assert len(embedding_configs) > 0
    first_embedding = embedding_configs[0]
    assert first_embedding.handle == f"{provider.name}/{first_embedding.embedding_model}"
|
||||
|
||||
|
||||
def test_deepseek():
    """DeepSeek lists at least one LLM, and its handle is "<provider name>/<model name>"."""
    provider = DeepSeekProvider(
        name="deepseek",
        api_key=model_settings.deepseek_api_key,
    )
    llm_configs = provider.list_llm_models()
    assert len(llm_configs) > 0
    first_llm = llm_configs[0]
    assert first_llm.handle == f"{provider.name}/{first_llm.model}"
|
||||
|
||||
|
||||
def test_anthropic():
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = AnthropicProvider(
|
||||
name="anthropic",
|
||||
api_key=api_key,
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
@@ -60,7 +64,7 @@ def test_anthropic():
|
||||
def test_groq():
|
||||
provider = GroqProvider(
|
||||
name="groq",
|
||||
api_key=os.getenv("GROQ_API_KEY"),
|
||||
api_key=model_settings.groq_api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
@@ -70,8 +74,9 @@ def test_groq():
|
||||
def test_azure():
|
||||
provider = AzureProvider(
|
||||
name="azure",
|
||||
api_key=os.getenv("AZURE_API_KEY"),
|
||||
base_url=os.getenv("AZURE_BASE_URL"),
|
||||
api_key=model_settings.azure_api_key,
|
||||
base_url=model_settings.azure_base_url,
|
||||
api_version=model_settings.azure_api_version,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
@@ -82,26 +87,24 @@ def test_azure():
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_ollama():
|
||||
base_url = os.getenv("OLLAMA_BASE_URL")
|
||||
assert base_url is not None
|
||||
provider = OllamaProvider(
|
||||
name="ollama",
|
||||
base_url=base_url,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
api_key=None,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
# def test_ollama():
|
||||
# provider = OllamaProvider(
|
||||
# name="ollama",
|
||||
# base_url=model_settings.ollama_base_url,
|
||||
# api_key=None,
|
||||
# default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
# )
|
||||
# models = provider.list_llm_models()
|
||||
# assert len(models) > 0
|
||||
# assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
#
|
||||
# embedding_models = provider.list_embedding_models()
|
||||
# assert len(embedding_models) > 0
|
||||
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_googleai():
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
api_key = model_settings.gemini_api_key
|
||||
assert api_key is not None
|
||||
provider = GoogleAIProvider(
|
||||
name="google_ai",
|
||||
@@ -119,8 +122,8 @@ def test_googleai():
|
||||
def test_google_vertex():
|
||||
provider = GoogleVertexProvider(
|
||||
name="google_vertex",
|
||||
google_cloud_project=os.getenv("GCP_PROJECT_ID"),
|
||||
google_cloud_location=os.getenv("GCP_REGION"),
|
||||
google_cloud_project=model_settings.google_cloud_project,
|
||||
google_cloud_location=model_settings.google_cloud_location,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
@@ -131,50 +134,40 @@ def test_google_vertex():
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_mistral():
|
||||
provider = MistralProvider(
|
||||
name="mistral",
|
||||
api_key=os.getenv("MISTRAL_API_KEY"),
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
def test_together():
|
||||
provider = TogetherProvider(
|
||||
name="together",
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
default_prompt_formatter="chatml",
|
||||
api_key=model_settings.together_api_key,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
# TODO: We don't have embedding models on together for CI
|
||||
# embedding_models = provider.list_embedding_models()
|
||||
# assert len(embedding_models) > 0
|
||||
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_anthropic_bedrock():
|
||||
from letta.settings import model_settings
|
||||
|
||||
provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
|
||||
# def test_anthropic_bedrock():
|
||||
# from letta.settings import model_settings
|
||||
#
|
||||
# provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
|
||||
# models = provider.list_llm_models()
|
||||
# assert len(models) > 0
|
||||
# assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
#
|
||||
# embedding_models = provider.list_embedding_models()
|
||||
# assert len(embedding_models) > 0
|
||||
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_custom_anthropic():
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = AnthropicProvider(
|
||||
name="custom_anthropic",
|
||||
api_key=api_key,
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
|
||||
Reference in New Issue
Block a user