From fa3a4ab88b9f594f57c5b255c74e30b850715d1f Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Tue, 10 Jun 2025 10:44:26 -0700 Subject: [PATCH] fix: list ollama embeddings models (#2711) --- letta/functions/async_composio_toolset.py | 5 ++- letta/local_llm/utils.py | 23 ++-------- letta/schemas/providers.py | 55 +++++++++++++++++------ letta/server/server.py | 11 +++-- letta/utils.py | 20 +++------ 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/letta/functions/async_composio_toolset.py b/letta/functions/async_composio_toolset.py index bcea60d6..3094bf59 100644 --- a/letta/functions/async_composio_toolset.py +++ b/letta/functions/async_composio_toolset.py @@ -84,7 +84,10 @@ class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta", description_cha # Handle specific error codes from Composio API if error_code == 10401 or "API_KEY_NOT_FOUND" in error_message: raise ApiKeyNotProvidedError() - if "connected account not found" in error_message.lower(): + if ( + "connected account not found" in error_message.lower() + or "no connected account found" in error_message.lower() + ): raise ConnectedAccountNotFoundError(f"Connected account not found: {error_message}") if "enum metadata not found" in error_message.lower(): raise EnumMetadataNotFound(f"Enum metadata not found: {error_message}") diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py index c5c084ed..8d962a84 100644 --- a/letta/local_llm/utils.py +++ b/letta/local_llm/utils.py @@ -44,24 +44,6 @@ def post_json_auth_request(uri, json_payload, auth_type, auth_key): return response -# deprecated for Box -class DotDict(dict): - """Allow dot access on properties similar to OpenAI response object""" - - def __getattr__(self, attr): - return self.get(attr) - - def __setattr__(self, key, value): - self[key] = value - - # following methods necessary for pickling - def __getstate__(self): - return vars(self) - - def __setstate__(self, state): - vars(self).update(state) - - def load_grammar_file(grammar): # Set grammar grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "grammars", f"{grammar}.gbnf") @@ -79,8 +61,9 @@ def load_grammar_file(grammar): # TODO: support tokenizers/tokenizer apis available in local models def count_tokens(s: str, model: str = "gpt-4") -> int: - encoding = tiktoken.encoding_for_model(model) - return len(encoding.encode(s)) + from letta.utils import count_tokens + + return count_tokens(s, model) def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index b87fdf48..c9455b47 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -2,6 +2,8 @@ import warnings from datetime import datetime from typing import List, Literal, Optional +import aiohttp +import requests from pydantic import BaseModel, Field, model_validator from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW @@ -872,9 +874,6 @@ class OllamaProvider(OpenAIProvider): async def list_llm_models_async(self) -> List[LLMConfig]: """Async version of list_llm_models below""" endpoint = f"{self.base_url}/api/tags" - - import aiohttp - async with aiohttp.ClientSession() as session: async with session.get(endpoint) as response: if response.status != 200: @@ -903,8 +902,6 @@ class OllamaProvider(OpenAIProvider): def list_llm_models(self) -> List[LLMConfig]: # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models - import requests - response = requests.get(f"{self.base_url}/api/tags") if response.status_code != 200: raise Exception(f"Failed to list Ollama models: {response.text}") @@ -931,9 +928,6 @@ class OllamaProvider(OpenAIProvider): return configs def get_model_context_window(self, model_name: str) -> Optional[int]: - - import requests - response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) response_json = response.json() @@ -965,11 +959,19 @@ class OllamaProvider(OpenAIProvider): return value return None - def get_model_embedding_dim(self, model_name: str): - import requests - + def _get_model_embedding_dim(self, model_name: str): response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) response_json = response.json() + return self._get_model_embedding_dim_impl(response_json, model_name) + + async def _get_model_embedding_dim_async(self, model_name: str): + async with aiohttp.ClientSession() as session: + async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response: + response_json = await response.json() + return self._get_model_embedding_dim_impl(response_json, model_name) + + @staticmethod + def _get_model_embedding_dim_impl(response_json: dict, model_name: str): if "model_info" not in response_json: if "error" in response_json: print(f"Ollama fetch model info error for {model_name}: {response_json['error']}") @@ -979,10 +981,35 @@ class OllamaProvider(OpenAIProvider): return value return None + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + """Async version of list_embedding_models below""" + endpoint = f"{self.base_url}/api/tags" + async with aiohttp.ClientSession() as session: + async with session.get(endpoint) as response: + if response.status != 200: + raise Exception(f"Failed to list Ollama models: {response.text}") + response_json = await response.json() + + configs = [] + for model in response_json["models"]: + embedding_dim = await self._get_model_embedding_dim_async(model["name"]) + if not embedding_dim: + print(f"Ollama model {model['name']} has no embedding dimension") + continue + configs.append( + EmbeddingConfig( + embedding_model=model["name"], + embedding_endpoint_type="ollama", + embedding_endpoint=self.base_url, + embedding_dim=embedding_dim, + embedding_chunk_size=300, + handle=self.get_handle(model["name"], is_embedding=True), + ) + ) + return configs + def list_embedding_models(self) -> List[EmbeddingConfig]: # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models - import requests - response = requests.get(f"{self.base_url}/api/tags") if response.status_code != 200: raise Exception(f"Failed to list Ollama models: {response.text}") @@ -990,7 +1017,7 @@ class OllamaProvider(OpenAIProvider): configs = [] for model in response_json["models"]: - embedding_dim = self.get_model_embedding_dim(model["name"]) + embedding_dim = self._get_model_embedding_dim(model["name"]) if not embedding_dim: print(f"Ollama model {model['name']} has no embedding dimension") continue diff --git a/letta/server/server.py b/letta/server/server.py index 400505f1..2cdb0d33 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -2028,7 +2028,8 @@ class SyncServer(Server): ) # Composio wrappers - def get_composio_client(self, api_key: Optional[str] = None): + @staticmethod + def get_composio_client(api_key: Optional[str] = None): if api_key: return Composio(api_key=api_key) elif tool_settings.composio_api_key: @@ -2036,9 +2037,10 @@ class SyncServer(Server): else: return Composio() - def get_composio_apps(self, api_key: Optional[str] = None) -> List["AppModel"]: + @staticmethod + def get_composio_apps(api_key: Optional[str] = None) -> List["AppModel"]: """Get a list of all Composio apps with actions""" - apps = self.get_composio_client(api_key=api_key).apps.get() + apps = SyncServer.get_composio_client(api_key=api_key).apps.get() apps_with_actions = [] for app in apps: # A bit of hacky logic until composio patches this @@ -2049,7 +2051,8 @@ class SyncServer(Server): def get_composio_actions_from_app_name(self, composio_app_name: str, api_key: Optional[str] = None) -> List["ActionModel"]: actions = self.get_composio_client(api_key=api_key).actions.get(apps=[composio_app_name]) - return actions + # Filter out deprecated composio actions + return [action for action in actions if "deprecated" not in action.description.lower()] # MCP wrappers # TODO support both command + SSE servers (via config) diff --git a/letta/utils.py b/letta/utils.py index c3183f92..5c834d7e 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -468,17 +468,6 @@ NOUN_BANK = [ ] -def deduplicate(target_list: list) -> list: - seen = set() - dedup_list = [] - for i in target_list: - if i not in seen: - seen.add(i) - dedup_list.append(i) - - return dedup_list - - def smart_urljoin(base_url: str, relative_url: str) -> str: """urljoin is stupid and wants a trailing / at the end of the endpoint address, or it will chop the suffix off""" if not base_url.endswith("/"): @@ -516,8 +505,9 @@ def is_optional_type(hint): def enforce_types(func): """Enforces that values passed in match the expected types. + Technically will handle coroutines as well. - Technically will handle coroutines as well. + TODO (cliandy): use stricter pydantic fields """ @wraps(func) @@ -808,7 +798,11 @@ class OpenAIBackcompatUnpickler(pickle.Unpickler): def count_tokens(s: str, model: str = "gpt-4") -> int: - encoding = tiktoken.encoding_for_model(model) + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Falling back to cl100k base for token counting.") + encoding = tiktoken.get_encoding("cl100k_base") return len(encoding.encode(s))