From 1099809e49434ef7e1188dd71bc040bae0c03435 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Sat, 12 Oct 2024 20:19:20 -0700
Subject: [PATCH] fix: patch errors with `OllamaProvider` (#1875)

---
 letta/local_llm/chat_completion_proxy.py |  4 +--
 letta/providers.py                       | 23 +++++++++++++++-
 letta/server/server.py                   | 35 ++++++++++++++++++++----
 tests/test_providers.py                  | 20 +++++++++-----
 4 files changed, 65 insertions(+), 17 deletions(-)

diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py
index 25b91420..c6dbd4a1 100644
--- a/letta/local_llm/chat_completion_proxy.py
+++ b/letta/local_llm/chat_completion_proxy.py
@@ -85,9 +85,7 @@ def get_chat_completion(
     elif wrapper is None:
         # Warn the user that we're using the fallback
         if not has_shown_warning:
-            print(
-                f"{CLI_WARNING_PREFIX}no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model-wrapper)"
-            )
+            print(f"{CLI_WARNING_PREFIX}no prompt formatter specified for local LLM, using the default formatter")
             has_shown_warning = True
 
         llm_wrapper = DEFAULT_WRAPPER()
diff --git a/letta/providers.py b/letta/providers.py
index fa545708..9ea298c2 100644
--- a/letta/providers.py
+++ b/letta/providers.py
@@ -140,9 +140,17 @@ class AnthropicProvider(Provider):
 
 
 class OllamaProvider(OpenAIProvider):
+    """Ollama provider that uses the native /api/generate endpoint
+
+    See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+    """
+
     name: str = "ollama"
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
+    default_prompt_formatter: str = Field(
+        ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
+    )
 
     def list_llm_models(self) -> List[LLMConfig]:
         # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
@@ -156,11 +164,15 @@ class OllamaProvider(OpenAIProvider):
         configs = []
         for model in response_json["models"]:
             context_window = self.get_model_context_window(model["name"])
+            if context_window is None:
+                print(f"Ollama model {model['name']} has no context window")
+                continue
             configs.append(
                 LLMConfig(
                     model=model["name"],
                     model_endpoint_type="ollama",
                     model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
                 )
             )
@@ -192,6 +204,10 @@ class OllamaProvider(OpenAIProvider):
         # ]
         # max_position_embeddings
         # parse model cards: nous, dolphon, llama
+        if "model_info" not in response_json:
+            if "error" in response_json:
+                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
+            return None
         for key, value in response_json["model_info"].items():
             if "context_length" in key:
                 return value
@@ -202,6 +218,10 @@ class OllamaProvider(OpenAIProvider):
         response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
         response_json = response.json()
 
+        if "model_info" not in response_json:
+            if "error" in response_json:
+                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
+            return None
         for key, value in response_json["model_info"].items():
             if "embedding_length" in key:
                 return value
@@ -220,6 +240,7 @@ class OllamaProvider(OpenAIProvider):
         for model in response_json["models"]:
             embedding_dim = self.get_model_embedding_dim(model["name"])
             if not embedding_dim:
+                print(f"Ollama model {model['name']} has no embedding dimension")
                 continue
             configs.append(
                 EmbeddingConfig(
@@ -420,7 +441,7 @@ class VLLMCompletionsProvider(Provider):
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     name: str = "vllm"
     base_url: str = Field(..., description="Base URL for the vLLM API.")
-    default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper)to use on vLLM /completions API.")
+    default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
 
     def list_llm_models(self) -> List[LLMConfig]:
         # not supported with vLLM
diff --git a/letta/server/server.py b/letta/server/server.py
index 08050ac0..fcb00962 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -200,7 +200,7 @@ class SyncServer(Server):
     def __init__(
         self,
         chaining: bool = True,
-        max_chaining_steps: bool = None,
+        max_chaining_steps: Optional[int] = None,
         default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
         # default_interface: AgentInterface = CLIInterface(),
         # default_persistence_manager_cls: PersistenceManager = LocalStateManager,
@@ -241,13 +241,32 @@ class SyncServer(Server):
         # collect providers (always has Letta as a default)
         self._enabled_providers: List[Provider] = [LettaProvider()]
         if model_settings.openai_api_key:
-            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
+            self._enabled_providers.append(
+                OpenAIProvider(
+                    api_key=model_settings.openai_api_key,
+                    base_url=model_settings.openai_api_base,
+                )
+            )
         if model_settings.anthropic_api_key:
-            self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
+            self._enabled_providers.append(
+                AnthropicProvider(
+                    api_key=model_settings.anthropic_api_key,
+                )
+            )
         if model_settings.ollama_base_url:
-            self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
+            self._enabled_providers.append(
+                OllamaProvider(
+                    base_url=model_settings.ollama_base_url,
+                    api_key=None,
+                    default_prompt_formatter=model_settings.default_prompt_formatter,
+                )
+            )
         if model_settings.gemini_api_key:
-            self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
+            self._enabled_providers.append(
+                GoogleAIProvider(
+                    api_key=model_settings.gemini_api_key,
+                )
+            )
         if model_settings.azure_api_key and model_settings.azure_base_url:
             assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
             self._enabled_providers.append(
@@ -268,7 +287,11 @@ class SyncServer(Server):
             # NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
             # see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
             # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
-            self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))
+            self._enabled_providers.append(
+                VLLMChatCompletionsProvider(
+                    base_url=model_settings.vllm_api_base,
+                )
+            )
 
     def save_agents(self):
         """Saves all the agents that are in the in-memory object store"""
diff --git a/tests/test_providers.py b/tests/test_providers.py
index 01bb8d41..f2d4f95b 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -6,19 +6,21 @@ from letta.providers import (
     OllamaProvider,
     OpenAIProvider,
 )
+from letta.settings import model_settings
 
 
 def test_openai():
-
-    provider = OpenAIProvider(api_key=os.getenv("OPENAI_API_KEY"))
+    api_key = os.getenv("OPENAI_API_KEY")
+    assert api_key is not None
+    provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
     models = provider.list_llm_models()
     print(models)
 
 
 def test_anthropic():
-    if os.getenv("ANTHROPIC_API_KEY") is None:
-        return
-    provider = AnthropicProvider(api_key=os.getenv("ANTHROPIC_API_KEY"))
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    assert api_key is not None
+    provider = AnthropicProvider(api_key=api_key)
     models = provider.list_llm_models()
     print(models)
 
@@ -38,7 +40,9 @@ def test_azure():
 
 
 def test_ollama():
-    provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
+    base_url = os.getenv("OLLAMA_BASE_URL")
+    assert base_url is not None
+    provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
     models = provider.list_llm_models()
     print(models)
 
@@ -47,7 +51,9 @@
 
 
 def test_googleai():
-    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
+    api_key = os.getenv("GEMINI_API_KEY")
+    assert api_key is not None
+    provider = GoogleAIProvider(api_key=api_key)
     models = provider.list_llm_models()
     print(models)
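
Reviewer note (not part of the patch): a minimal usage sketch of the changed
`OllamaProvider` constructor surface. It assumes a local Ollama server at the
default http://localhost:11434 address, and the "chatml" formatter name is
illustrative only; any value you would put in
`model_settings.default_prompt_formatter` works the same way.

    from letta.providers import OllamaProvider

    # `default_prompt_formatter` is now a required field; it is forwarded into
    # each LLMConfig as `model_wrapper`, so the completions-style /api/generate
    # endpoint knows how to format prompts.
    provider = OllamaProvider(
        base_url="http://localhost:11434",  # assumption: default local Ollama address
        api_key=None,  # Ollama does not require an API key
        default_prompt_formatter="chatml",  # illustrative formatter name
    )

    # Models whose context window cannot be determined are now skipped (with a
    # printed warning) rather than yielding configs with context_window=None.
    for config in provider.list_llm_models():
        print(config.model, config.context_window, config.model_wrapper)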