From e15dea623d191203be0a1157dd2c432d05d74a7f Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 10 Oct 2024 14:07:45 -0700 Subject: [PATCH] chore: fix branch (#1865) --- letta/cli/cli.py | 27 +++++++++++++++++++++++++-- letta/llm_api/openai.py | 7 +++++-- letta/providers.py | 21 ++++++++++++++++----- letta/server/server.py | 5 +++-- letta/settings.py | 2 +- 5 files changed, 50 insertions(+), 12 deletions(-) diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 31a567e1..160615b7 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -14,7 +14,9 @@ from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL from letta.log import get_logger from letta.metadata import MetadataStore +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import OptionState +from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory, Memory from letta.server.server import logger as server_logger @@ -233,25 +235,46 @@ def run( # choose from list of llm_configs llm_configs = client.list_llm_configs() llm_options = [llm_config.model for llm_config in llm_configs] + + # TODO move into LLMConfig as a class method? + def prettify_llm_config(llm_config: LLMConfig) -> str: + return f"{llm_config.model}" + (f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else "") + + llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs] + # select model if len(llm_options) == 0: raise ValueError("No LLM models found. 
Please enable a provider.") elif len(llm_options) == 1: llm_model_name = llm_options[0] else: - llm_model_name = questionary.select("Select LLM model:", choices=llm_options).ask() + llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0] # choose form list of embedding configs embedding_configs = client.list_embedding_configs() embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] + + # TODO move into EmbeddingConfig as a class method? + def prettify_embed_config(embedding_config: EmbeddingConfig) -> str: + return f"{embedding_config.embedding_model}" + ( + f" ({embedding_config.embedding_endpoint})" + if embedding_config.embedding_endpoint + else "" + ) + + embedding_choices = [ + questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config) + for embedding_config in embedding_configs + ] + # select model if len(embedding_options) == 0: raise ValueError("No embedding models found. 
Please enable a provider.") elif len(embedding_options) == 1: embedding_model_name = embedding_options[0] else: - embedding_model_name = questionary.select("Select embedding model:", choices=embedding_options).ask() + embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model embedding_config = [ embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name ][0] diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 69a50fc2..f60150ee 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -41,7 +41,9 @@ from letta.utils import smart_urljoin OPENAI_SSE_DONE = "[DONE]" -def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict: +def openai_get_model_list( + url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None +) -> dict: """https://platform.openai.com/docs/api-reference/models/list""" from letta.utils import printd @@ -60,7 +62,8 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional printd(f"Sending request to {url}") try: - response = requests.get(url, headers=headers) + # TODO add query param "tool" to be true + response = requests.get(url, headers=headers, params=extra_params) response.raise_for_status() # Raises HTTPError for 4XX/5XX status response = response.json() # convert to dict from string printd(f"response = {response}") diff --git a/letta/providers.py b/letta/providers.py index 361d1728..761fcd7e 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -53,17 +53,28 @@ class LettaProvider(Provider): class OpenAIProvider(Provider): name: str = "openai" api_key: str = Field(..., description="API key for the OpenAI API.") - base_url: str = "https://api.openai.com/v1" + base_url: str = Field(..., description="Base URL for the OpenAI API.") def list_llm_models(self) -> List[LLMConfig]: 
from letta.llm_api.openai import openai_get_model_list - response = openai_get_model_list(self.base_url, api_key=self.api_key) - model_options = [obj["id"] for obj in response["data"]] + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... + # See: https://openrouter.ai/docs/requests + extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None + response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params) + + assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" configs = [] - for model_name in model_options: - context_window_size = self.get_model_context_window_size(model_name) + for model in response["data"]: + assert "id" in model, f"OpenAI model missing 'id' field: {model}" + model_name = model["id"] + + if "context_length" in model: + # Context length is returned in OpenRouter as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) if not context_window_size: continue diff --git a/letta/server/server.py b/letta/server/server.py index 5088e9ef..b37ec867 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -50,6 +50,7 @@ from letta.providers import ( LettaProvider, OllamaProvider, OpenAIProvider, + Provider, VLLMProvider, ) from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState @@ -261,9 +262,9 @@ class SyncServer(Server): self.add_default_tools(module_name="base") # collect providers (always has Letta as a default) - self._enabled_providers = [LettaProvider()] + self._enabled_providers: List[Provider] = [LettaProvider()] if model_settings.openai_api_key: - self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key)) + self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base)) if 
model_settings.anthropic_api_key: self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key)) if model_settings.ollama_base_url: diff --git a/letta/settings.py b/letta/settings.py index 8b7fee27..12d42567 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -11,7 +11,7 @@ class ModelSettings(BaseSettings): # openai openai_api_key: Optional[str] = None - # TODO: provide overriding BASE_URL? + openai_api_base: Optional[str] = "https://api.openai.com/v1" # groq groq_api_key: Optional[str] = None