fix: incorrect context_window or embedding_dim using ollama (#2743)

Sarah Wooders
2025-08-01 12:51:06 -07:00
committed by GitHub


@@ -1,7 +1,6 @@
 from typing import Literal
 
 import aiohttp
-import requests
 from pydantic import Field
 
 from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
@@ -43,7 +42,7 @@ class OllamaProvider(OpenAIProvider):
 
         configs = []
         for model in response_json["models"]:
-            context_window = self.get_model_context_window(model["name"])
+            context_window = await self._get_model_context_window(model["name"])
             if context_window is None:
                 print(f"Ollama model {model['name']} has no context window, using default 32000")
                 context_window = 32000
@@ -75,7 +74,7 @@ class OllamaProvider(OpenAIProvider):
 
         configs = []
         for model in response_json["models"]:
-            embedding_dim = await self._get_model_embedding_dim_async(model["name"])
+            embedding_dim = await self._get_model_embedding_dim(model["name"])
             if not embedding_dim:
                 print(f"Ollama model {model['name']} has no embedding dimension, using default 1024")
                 # continue
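
The root cause of both bad values is visible in the hunk below: Ollama's /api/show endpoint namespaces its "model_info" keys by the model's architecture, so the hard-coded "llama." prefix (and the bare "embedding_length" key) only resolved for llama-family models. A minimal sketch of the fixed lookup, using hypothetical payloads (the key layout follows Ollama's show-model-information docs; the numeric values are illustrative):

    # Illustrative /api/show "model_info" payloads, not captured from a live server.
    llama_info = {
        "general.architecture": "llama",
        "llama.context_length": 8192,
        "llama.embedding_length": 4096,
    }
    qwen_info = {
        "general.architecture": "qwen2",
        "qwen2.context_length": 32768,
        "qwen2.embedding_length": 3584,
    }

    def context_length(model_info: dict) -> int | None:
        # Mirrors the fix: derive the key prefix from the reported
        # architecture instead of assuming "llama.".
        if arch := model_info.get("general.architecture"):
            if length := model_info.get(f"{arch}.context_length"):
                return int(length)
        return None

    assert context_length(llama_info) == 8192
    assert context_length(qwen_info) == 32768  # the old lookup returned None here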
@@ -92,63 +91,50 @@ class OllamaProvider(OpenAIProvider):
             )
         return configs
 
-    def get_model_context_window(self, model_name: str) -> int | None:
-        """Gets model context window for Ollama. As this can look different based on models,
-        we use the following for guidance:
-
-        "llama.context_length": 8192,
-        "llama.embedding_length": 4096,
-        source: https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
-
-        FROM 2024-10-08
-        Notes from vLLM around keys
-        source: https://github.com/vllm-project/vllm/blob/72ad2735823e23b4e1cc79b7c73c3a5f3c093ab0/vllm/config.py#L3488
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # ChatGLM2
-            "seq_length",
-            # Command-R
-            "model_max_length",
-            # Whisper
-            "max_target_positions",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-        max_position_embeddings
-        parse model cards: nous, dolphon, llama
-        """
+    async def _get_model_context_window(self, model_name: str) -> int | None:
         endpoint = f"{self.base_url}/api/show"
-        payload = {"name": model_name, "verbose": True}
-        response = requests.post(endpoint, json=payload)
-        if response.status_code != 200:
-            return None
+        payload = {"name": model_name}
         try:
-            model_info = response.json()
-            # Try to extract context window from model parameters
-            if "model_info" in model_info and "llama.context_length" in model_info["model_info"]:
-                return int(model_info["model_info"]["llama.context_length"])
-        except Exception:
-            pass
-        logger.warning(f"Failed to get model context window for {model_name}")
+            async with aiohttp.ClientSession() as session:
+                async with session.post(endpoint, json=payload) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
+                        return None
+                    response_json = await response.json()
+                    model_info = response_json.get("model_info", {})
+                    if architecture := model_info.get("general.architecture"):
+                        if context_length := model_info.get(f"{architecture}.context_length"):
+                            return int(context_length)
+        except Exception as e:
+            logger.warning(f"Failed to get model context window for {model_name} with error: {e}")
         return None
 
-    async def _get_model_embedding_dim_async(self, model_name: str):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
-                response_json = await response.json()
-
-        if "model_info" not in response_json:
-            if "error" in response_json:
-                logger.warning("Ollama fetch model info error for %s: %s", model_name, response_json["error"])
-            return None
-        return response_json["model_info"].get("embedding_length")
+    async def _get_model_embedding_dim(self, model_name: str) -> int | None:
+        endpoint = f"{self.base_url}/api/show"
+        payload = {"name": model_name}
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(endpoint, json=payload) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
+                        return None
+                    response_json = await response.json()
+                    model_info = response_json.get("model_info", {})
+                    if architecture := model_info.get("general.architecture"):
+                        if embedding_length := model_info.get(f"{architecture}.embedding_length"):
+                            return int(embedding_length)
+        except Exception as e:
+            logger.warning(f"Failed to get model embedding dimension for {model_name} with error: {e}")
+        return None
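
For completeness, a standalone sketch of exercising the same lookup against a running Ollama server. It assumes Ollama's default address http://localhost:11434, and "llama3.1:8b" is only an example model name. Note the diff above also drops "verbose": True from the payload, so the architecture-prefixed fields are evidently present in the default /api/show response.

    import asyncio

    import aiohttp

    OLLAMA_BASE_URL = "http://localhost:11434"  # Ollama's default; adjust as needed

    async def show_model_limits(model_name: str) -> tuple[int | None, int | None]:
        """Fetch (context_length, embedding_length) via /api/show, mirroring the fix."""
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{OLLAMA_BASE_URL}/api/show", json={"name": model_name}) as response:
                if response.status != 200:
                    return None, None
                model_info = (await response.json()).get("model_info", {})
        arch = model_info.get("general.architecture")
        if not arch:
            return None, None
        ctx = model_info.get(f"{arch}.context_length")
        dim = model_info.get(f"{arch}.embedding_length")
        return (int(ctx) if ctx else None, int(dim) if dim else None)

    if __name__ == "__main__":
        # Example model name; substitute any model pulled locally.
        print(asyncio.run(show_model_limits("llama3.1:8b")))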