fix: incorrect context_window or embedding_dim using ollama (#2743)
@@ -1,7 +1,6 @@
 from typing import Literal

 import aiohttp
-import requests
 from pydantic import Field

 from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
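Side note on the import change above: the provider's lookups run inside async methods, where a blocking requests.post call would stall the event loop; aiohttp performs the same POST without blocking, so the requests dependency can go. A minimal sketch of the non-blocking pattern the new code relies on, assuming a local Ollama server at its default address (http://localhost:11434 is an assumption, not from the diff):

    import asyncio
    import aiohttp

    async def show_model(model_name: str) -> dict:
        # POST /api/show returns model metadata, including the "model_info" block.
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:11434/api/show", json={"name": model_name}) as resp:
                return await resp.json()

    # asyncio.run(show_model("llama3"))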
@@ -43,7 +42,7 @@ class OllamaProvider(OpenAIProvider):

         configs = []
         for model in response_json["models"]:
-            context_window = self.get_model_context_window(model["name"])
+            context_window = await self._get_model_context_window(model["name"])
             if context_window is None:
                 print(f"Ollama model {model['name']} has no context window, using default 32000")
                 context_window = 32000
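This hunk also shows the fallback behavior: when the async lookup returns None, the provider warns and substitutes 32000. The same pattern in isolation (the wrapper function and constant name are illustrative, not part of the commit):

    DEFAULT_CONTEXT_WINDOW = 32000  # the literal used in the hunk above

    async def resolve_context_window(provider, model_name: str) -> int:
        # Prefer the value Ollama reports; fall back to the default when the lookup fails.
        window = await provider._get_model_context_window(model_name)
        return window if window is not None else DEFAULT_CONTEXT_WINDOW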
@@ -75,7 +74,7 @@ class OllamaProvider(OpenAIProvider):

         configs = []
         for model in response_json["models"]:
-            embedding_dim = await self._get_model_embedding_dim_async(model["name"])
+            embedding_dim = await self._get_model_embedding_dim(model["name"])
             if not embedding_dim:
                 print(f"Ollama model {model['name']} has no embedding dimension, using default 1024")
                 # continue
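Note the truthiness test in this hunk: "if not embedding_dim" treats None and 0 alike, so either falls back to the 1024 default, and the commented-out continue suggests skipping such models was considered and rejected. The equivalent one-liner, with illustrative names:

    embedding_dim = (await provider._get_model_embedding_dim(model_name)) or 1024  # 1024 is the diff's default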
@@ -92,63 +91,50 @@ class OllamaProvider(OpenAIProvider):
             )
         return configs

-    def get_model_context_window(self, model_name: str) -> int | None:
-        """Gets model context window for Ollama. As this can look different based on models,
-        we use the following for guidance:
-
-        "llama.context_length": 8192,
-        "llama.embedding_length": 4096,
-        source: https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
-
-        FROM 2024-10-08
-        Notes from vLLM around keys
-        source: https://github.com/vllm-project/vllm/blob/72ad2735823e23b4e1cc79b7c73c3a5f3c093ab0/vllm/config.py#L3488
-
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # ChatGLM2
-            "seq_length",
-            # Command-R
-            "model_max_length",
-            # Whisper
-            "max_target_positions",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-        max_position_embeddings
-        parse model cards: nous, dolphon, llama
-        """
+    async def _get_model_context_window(self, model_name: str) -> int | None:
         endpoint = f"{self.base_url}/api/show"
-        payload = {"name": model_name, "verbose": True}
-        response = requests.post(endpoint, json=payload)
-        if response.status_code != 200:
-            return None
+        payload = {"name": model_name}

         try:
-            model_info = response.json()
-            # Try to extract context window from model parameters
-            if "model_info" in model_info and "llama.context_length" in model_info["model_info"]:
-                return int(model_info["model_info"]["llama.context_length"])
-        except Exception:
-            pass
-        logger.warning(f"Failed to get model context window for {model_name}")
+            async with aiohttp.ClientSession() as session:
+                async with session.post(endpoint, json=payload) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
+                        return None
+
+                    response_json = await response.json()
+                    model_info = response_json.get("model_info", {})
+
+                    if architecture := model_info.get("general.architecture"):
+                        if context_length := model_info.get(f"{architecture}.context_length"):
+                            return int(context_length)
+        except Exception as e:
+            logger.warning(f"Failed to get model context window for {model_name} with error: {e}")
+
         return None
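For reference, the new lookup keys off general.architecture inside the model_info block returned by POST /api/show; per the Ollama docs cited in the removed docstring, a llama-family model reports llama.context_length. A standalone sketch of the same extraction against an illustrative payload (the sample values are made up, not from the diff):

    def extract_context_length(response_json: dict) -> int | None:
        # Resolve the architecture first, then build the per-architecture key,
        # e.g. "llama.context_length" -- mirrors the added lines above.
        model_info = response_json.get("model_info", {})
        if architecture := model_info.get("general.architecture"):
            if context_length := model_info.get(f"{architecture}.context_length"):
                return int(context_length)
        return None

    sample = {"model_info": {"general.architecture": "llama", "llama.context_length": 8192}}
    assert extract_context_length(sample) == 8192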
-    async def _get_model_embedding_dim_async(self, model_name: str):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
-                response_json = await response.json()
+    async def _get_model_embedding_dim(self, model_name: str) -> int | None:
+        endpoint = f"{self.base_url}/api/show"
+        payload = {"name": model_name}

-        if "model_info" not in response_json:
-            if "error" in response_json:
-                logger.warning("Ollama fetch model info error for %s: %s", model_name, response_json["error"])
-            return None
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(endpoint, json=payload) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
+                        return None

-        return response_json["model_info"].get("embedding_length")
+                    response_json = await response.json()
+                    model_info = response_json.get("model_info", {})
+
+                    if architecture := model_info.get("general.architecture"):
+                        if embedding_length := model_info.get(f"{architecture}.embedding_length"):
+                            return int(embedding_length)
+        except Exception as e:
+            logger.warning(f"Failed to get model embedding dimension for {model_name} with error: {e}")
+
+        return None
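A hedged usage sketch of the two new helpers, assuming an OllamaProvider constructed with a base_url pointing at a local server (the constructor arguments and model name here are assumptions; the real signature may differ):

    import asyncio

    async def main():
        provider = OllamaProvider(base_url="http://localhost:11434")  # hypothetical construction
        window = await provider._get_model_context_window("llama3")
        dim = await provider._get_model_embedding_dim("llama3")
        # Fall back to the commit's defaults when the lookups return None.
        print(window or 32000, dim or 1024)

    # asyncio.run(main())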