From 87dae5d6e7b8204d4a71edcf51f9672bc23d80c5 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 20 May 2025 15:24:00 -0700 Subject: [PATCH] feat: Asyncify openai model listing (#2281) --- letta/llm_api/openai.py | 57 +++++++++++++ letta/schemas/providers.py | 166 ++++++++++++++++++++++++++++--------- tests/test_providers.py | 127 ++++++++++++++-------------- 3 files changed, 243 insertions(+), 107 deletions(-) diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index f9bfa21b..045ab1f4 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -1,6 +1,7 @@ import warnings from typing import Generator, List, Optional, Union +import httpx import requests from openai import OpenAI @@ -111,6 +112,62 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool raise e +async def openai_get_model_list_async( + url: str, + api_key: Optional[str] = None, + fix_url: bool = False, + extra_params: Optional[dict] = None, + client: Optional["httpx.AsyncClient"] = None, +) -> dict: + """https://platform.openai.com/docs/api-reference/models/list""" + from letta.utils import printd + + # In some cases we may want to double-check the URL and do basic correction + if fix_url and not url.endswith("/v1"): + url = smart_urljoin(url, "v1") + + url = smart_urljoin(url, "models") + + headers = {"Content-Type": "application/json"} + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + printd(f"Sending request to {url}") + + # Use provided client or create a new one + close_client = False + if client is None: + client = httpx.AsyncClient() + close_client = True + + try: + response = await client.get(url, headers=headers, params=extra_params) + response.raise_for_status() + result = response.json() + printd(f"response = {result}") + return result + except httpx.HTTPStatusError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + error_response = None + try: + error_response = http_err.response.json() + except: + error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text} + printd(f"Got HTTPError, exception={http_err}, response={error_response}") + raise http_err + except httpx.RequestError as req_err: + # Handle other httpx-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + finally: + if close_client: + await client.aclose() + + def build_openai_chat_completions_request( llm_config: LLMConfig, messages: List[_Message], diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index b55d9267..1c0072b4 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -47,9 +47,15 @@ class Provider(ProviderBase): def list_llm_models(self) -> List[LLMConfig]: return [] + async def list_llm_models_async(self) -> List[LLMConfig]: + return [] + def list_embedding_models(self) -> List[EmbeddingConfig]: return [] + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + return [] + def get_model_context_window(self, model_name: str) -> Optional[int]: raise NotImplementedError @@ -140,6 +146,19 @@ class LettaProvider(Provider): ) ] + async def list_llm_models_async(self) -> List[LLMConfig]: + return [ + LLMConfig( + model="letta-free", # NOTE: renamed + model_endpoint_type="openai", + model_endpoint=LETTA_MODEL_ENDPOINT, + context_window=8192, + handle=self.get_handle("letta-free"), + provider_name=self.name, + provider_category=self.provider_category, + ) + ] + def list_embedding_models(self): return [ EmbeddingConfig( @@ -189,9 +208,40 @@ class OpenAIProvider(Provider): return data + async def _get_models_async(self) -> List[dict]: + from letta.llm_api.openai import openai_get_model_list_async + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... + # See: https://openrouter.ai/docs/requests + extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None + + # Similar to Nebius + extra_params = {"verbose": True} if "nebius.com" in self.base_url else None + + response = await openai_get_model_list_async( + self.base_url, + api_key=self.api_key, + extra_params=extra_params, + # fix_url=True, # NOTE: make sure together ends with /v1 + ) + + if "data" in response: + data = response["data"] + else: + # TogetherAI's response is missing the 'data' field + data = response + + return data + def list_llm_models(self) -> List[LLMConfig]: data = self._get_models() + return self._list_llm_models(data) + async def list_llm_models_async(self) -> List[LLMConfig]: + data = await self._get_models_async() + return self._list_llm_models(data) + + def _list_llm_models(self, data) -> List[LLMConfig]: configs = [] for model in data: assert "id" in model, f"OpenAI model missing 'id' field: {model}" @@ -279,7 +329,6 @@ class OpenAIProvider(Provider): return configs def list_embedding_models(self) -> List[EmbeddingConfig]: - if self.base_url == "https://api.openai.com/v1": # TODO: actually automatically list models for OpenAI return [ @@ -312,55 +361,92 @@ class OpenAIProvider(Provider): else: # Actually attempt to list data = self._get_models() + return self._list_embedding_models(data) - configs = [] - for model in data: - assert "id" in model, f"Model missing 'id' field: {model}" - model_name = model["id"] + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + if self.base_url == "https://api.openai.com/v1": + # TODO: actually automatically list models for OpenAI + return [ + EmbeddingConfig( + embedding_model="text-embedding-ada-002", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=1536, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-ada-002", is_embedding=True), + ), + EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-small", is_embedding=True), + ), + EmbeddingConfig( + embedding_model="text-embedding-3-large", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-large", is_embedding=True), + ), + ] - if "context_length" in model: - # Context length is returned in Nebius as "context_length" - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) + else: + # Actually attempt to list + data = await self._get_models_async() + return self._list_embedding_models(data) - # We need the context length for embeddings too - if not context_window_size: - continue + def _list_embedding_models(self, data) -> List[EmbeddingConfig]: + configs = [] + for model in data: + assert "id" in model, f"Model missing 'id' field: {model}" + model_name = model["id"] - if "nebius.com" in self.base_url: - # Nebius includes the type, which we can use to filter for embedidng models - try: - model_type = model["architecture"]["modality"] - if model_type not in ["text->embedding"]: - # print(f"Skipping model w/ modality {model_type}:\n{model}") - continue - except KeyError: - print(f"Couldn't access architecture type field, skipping model:\n{model}") - continue + if "context_length" in model: + # Context length is returned in Nebius as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) - elif "together.ai" in self.base_url or "together.xyz" in self.base_url: - # TogetherAI includes the type, which we can use to filter for embedding models - if "type" in model and model["type"] not in ["embedding"]: + # We need the context length for embeddings too + if not context_window_size: + continue + + if "nebius.com" in self.base_url: + # Nebius includes the type, which we can use to filter for embedidng models + try: + model_type = model["architecture"]["modality"] + if model_type not in ["text->embedding"]: # print(f"Skipping model w/ modality {model_type}:\n{model}") continue - - else: - # For other providers we should skip by default, since we don't want to assume embeddings are supported + except KeyError: + print(f"Couldn't access architecture type field, skipping model:\n{model}") continue - configs.append( - EmbeddingConfig( - embedding_model=model_name, - embedding_endpoint_type=self.provider_type, - embedding_endpoint=self.base_url, - embedding_dim=context_window_size, - embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, - handle=self.get_handle(model, is_embedding=True), - ) - ) + elif "together.ai" in self.base_url or "together.xyz" in self.base_url: + # TogetherAI includes the type, which we can use to filter for embedding models + if "type" in model and model["type"] not in ["embedding"]: + # print(f"Skipping model w/ modality {model_type}:\n{model}") + continue - return configs + else: + # For other providers we should skip by default, since we don't want to assume embeddings are supported + continue + + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type=self.provider_type, + embedding_endpoint=self.base_url, + embedding_dim=context_window_size, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle(model, is_embedding=True), + ) + ) + + return configs def get_model_context_window_size(self, model_name: str): if model_name in LLM_MAX_TOKENS: diff --git a/tests/test_providers.py b/tests/test_providers.py index 2ab6606d..67ae5a83 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,15 +1,12 @@ -import os +import pytest from letta.schemas.providers import ( - AnthropicBedrockProvider, AnthropicProvider, AzureProvider, DeepSeekProvider, GoogleAIProvider, GoogleVertexProvider, GroqProvider, - MistralProvider, - OllamaProvider, OpenAIProvider, TogetherProvider, ) @@ -17,11 +14,9 @@ from letta.settings import model_settings def test_openai(): - api_key = os.getenv("OPENAI_API_KEY") - assert api_key is not None provider = OpenAIProvider( name="openai", - api_key=api_key, + api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base, ) models = provider.list_llm_models() @@ -33,24 +28,33 @@ def test_openai(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_deepseek(): - api_key = os.getenv("DEEPSEEK_API_KEY") - assert api_key is not None - provider = DeepSeekProvider( - name="deepseek", - api_key=api_key, +@pytest.mark.asyncio +async def test_openai_async(): + provider = OpenAIProvider( + name="openai", + api_key=model_settings.openai_api_key, + base_url=model_settings.openai_api_base, ) + models = await provider.list_llm_models_async() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + embedding_models = await provider.list_embedding_models_async() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + +def test_deepseek(): + provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key) models = provider.list_llm_models() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" def test_anthropic(): - api_key = os.getenv("ANTHROPIC_API_KEY") - assert api_key is not None provider = AnthropicProvider( name="anthropic", - api_key=api_key, + api_key=model_settings.anthropic_api_key, ) models = provider.list_llm_models() assert len(models) > 0 @@ -60,7 +64,7 @@ def test_anthropic(): def test_groq(): provider = GroqProvider( name="groq", - api_key=os.getenv("GROQ_API_KEY"), + api_key=model_settings.groq_api_key, ) models = provider.list_llm_models() assert len(models) > 0 @@ -70,8 +74,9 @@ def test_groq(): def test_azure(): provider = AzureProvider( name="azure", - api_key=os.getenv("AZURE_API_KEY"), - base_url=os.getenv("AZURE_BASE_URL"), + api_key=model_settings.azure_api_key, + base_url=model_settings.azure_base_url, + api_version=model_settings.azure_api_version, ) models = provider.list_llm_models() assert len(models) > 0 @@ -82,26 +87,24 @@ def test_azure(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_ollama(): - base_url = os.getenv("OLLAMA_BASE_URL") - assert base_url is not None - provider = OllamaProvider( - name="ollama", - base_url=base_url, - default_prompt_formatter=model_settings.default_prompt_formatter, - api_key=None, - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" +# def test_ollama(): +# provider = OllamaProvider( +# name="ollama", +# base_url=model_settings.ollama_base_url, +# api_key=None, +# default_prompt_formatter=model_settings.default_prompt_formatter, +# ) +# models = provider.list_llm_models() +# assert len(models) > 0 +# assert models[0].handle == f"{provider.name}/{models[0].model}" +# +# embedding_models = provider.list_embedding_models() +# assert len(embedding_models) > 0 +# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_googleai(): - api_key = os.getenv("GEMINI_API_KEY") + api_key = model_settings.gemini_api_key assert api_key is not None provider = GoogleAIProvider( name="google_ai", @@ -119,8 +122,8 @@ def test_googleai(): def test_google_vertex(): provider = GoogleVertexProvider( name="google_vertex", - google_cloud_project=os.getenv("GCP_PROJECT_ID"), - google_cloud_location=os.getenv("GCP_REGION"), + google_cloud_project=model_settings.google_cloud_project, + google_cloud_location=model_settings.google_cloud_location, ) models = provider.list_llm_models() assert len(models) > 0 @@ -131,50 +134,40 @@ def test_google_vertex(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_mistral(): - provider = MistralProvider( - name="mistral", - api_key=os.getenv("MISTRAL_API_KEY"), - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - def test_together(): provider = TogetherProvider( name="together", - api_key=os.getenv("TOGETHER_API_KEY"), - default_prompt_formatter="chatml", + api_key=model_settings.together_api_key, + default_prompt_formatter=model_settings.default_prompt_formatter, ) models = provider.list_llm_models() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + # TODO: We don't have embedding models on together for CI + # embedding_models = provider.list_embedding_models() + # assert len(embedding_models) > 0 + # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_anthropic_bedrock(): - from letta.settings import model_settings - - provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" +# TODO: Add back in, difficulty adding this to CI properly, need boto credentials +# def test_anthropic_bedrock(): +# from letta.settings import model_settings +# +# provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region) +# models = provider.list_llm_models() +# assert len(models) > 0 +# assert models[0].handle == f"{provider.name}/{models[0].model}" +# +# embedding_models = provider.list_embedding_models() +# assert len(embedding_models) > 0 +# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_custom_anthropic(): - api_key = os.getenv("ANTHROPIC_API_KEY") - assert api_key is not None provider = AnthropicProvider( name="custom_anthropic", - api_key=api_key, + api_key=model_settings.anthropic_api_key, ) models = provider.list_llm_models() assert len(models) > 0