diff --git a/letta/schemas/providers/ollama.py b/letta/schemas/providers/ollama.py
index 8cc8f720..d34d86d7 100644
--- a/letta/schemas/providers/ollama.py
+++ b/letta/schemas/providers/ollama.py
@@ -1,7 +1,6 @@
 from typing import Literal
 
 import aiohttp
-import requests
 from pydantic import Field
 
 from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
@@ -43,7 +42,7 @@ class OllamaProvider(OpenAIProvider):
 
         configs = []
         for model in response_json["models"]:
-            context_window = self.get_model_context_window(model["name"])
+            context_window = await self._get_model_context_window(model["name"])
             if context_window is None:
                 print(f"Ollama model {model['name']} has no context window, using default 32000")
                 context_window = 32000
@@ -75,7 +74,7 @@ class OllamaProvider(OpenAIProvider):
 
         configs = []
         for model in response_json["models"]:
-            embedding_dim = await self._get_model_embedding_dim_async(model["name"])
+            embedding_dim = await self._get_model_embedding_dim(model["name"])
             if not embedding_dim:
                 print(f"Ollama model {model['name']} has no embedding dimension, using default 1024")
                 # continue
@@ -92,63 +91,50 @@ class OllamaProvider(OpenAIProvider):
             )
         return configs
 
-    def get_model_context_window(self, model_name: str) -> int | None:
-        """Gets model context window for Ollama. As this can look different based on models,
-        we use the following for guidance:
-
-        "llama.context_length": 8192,
-        "llama.embedding_length": 4096,
-        source: https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
-
-        FROM 2024-10-08
-        Notes from vLLM around keys
-        source: https://github.com/vllm-project/vllm/blob/72ad2735823e23b4e1cc79b7c73c3a5f3c093ab0/vllm/config.py#L3488
-
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # ChatGLM2
-            "seq_length",
-            # Command-R
-            "model_max_length",
-            # Whisper
-            "max_target_positions",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-
-        max_position_embeddings
-        parse model cards: nous, dolphon, llama
-        """
+    async def _get_model_context_window(self, model_name: str) -> int | None:
         endpoint = f"{self.base_url}/api/show"
-        payload = {"name": model_name, "verbose": True}
-        response = requests.post(endpoint, json=payload)
-        if response.status_code != 200:
-            return None
+        payload = {"name": model_name}
+
         try:
-            model_info = response.json()
-            # Try to extract context window from model parameters
-            if "model_info" in model_info and "llama.context_length" in model_info["model_info"]:
-                return int(model_info["model_info"]["llama.context_length"])
-        except Exception:
-            pass
-        logger.warning(f"Failed to get model context window for {model_name}")
+            async with aiohttp.ClientSession() as session:
+                async with session.post(endpoint, json=payload) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
+                        return None
+
+                    response_json = await response.json()
+                    model_info = response_json.get("model_info", {})
+
+                    if architecture := model_info.get("general.architecture"):
+                        if context_length := model_info.get(f"{architecture}.context_length"):
+                            return int(context_length)
+
+        except Exception as e:
+            logger.warning(f"Failed to get model context window for {model_name} with error: {e}")
+
         return None
 
-    async def _get_model_embedding_dim_async(self, model_name: str):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
-                response_json = await response.json()
+    async def _get_model_embedding_dim(self, model_name: str) -> int | None:
+        endpoint = f"{self.base_url}/api/show"
+        payload = {"name": model_name}
 
-                if "model_info" not in response_json:
-                    if "error" in response_json:
-                        logger.warning("Ollama fetch model info error for %s: %s", model_name, response_json["error"])
-                    return None
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(endpoint, json=payload) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
+                        return None
 
-                return response_json["model_info"].get("embedding_length")
+                    response_json = await response.json()
+                    model_info = response_json.get("model_info", {})
+
+                    if architecture := model_info.get("general.architecture"):
+                        if embedding_length := model_info.get(f"{architecture}.embedding_length"):
+                            return int(embedding_length)
+
+        except Exception as e:
+            logger.warning(f"Failed to get model embedding dimension for {model_name} with error: {e}")
+
+        return None