fix: max tokens and context window size [LET-6481] (#8298)
* fix: max tokens [LET-6481] * remove print statements * update * simplify fallback * address comments async * update other helpers * update pyproject.toml * update pyproject w async lru * OpenAI internal async methods * update * update uv lock
This commit is contained in:
committed by
Caren Thomas
parent
238894eebd
commit
0333ff0614
1
letta/model_specs/__init__.py
Normal file
1
letta/model_specs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Model specification utilities for Letta."""
|
||||
120
letta/model_specs/litellm_model_specs.py
Normal file
120
letta/model_specs/litellm_model_specs.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Utility functions for working with litellm model specifications.
|
||||
|
||||
This module provides access to model specifications from the litellm model_prices_and_context_window.json file.
|
||||
The data is synced from: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from async_lru import alru_cache
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Path to the litellm model specs JSON file
|
||||
MODEL_SPECS_PATH = os.path.join(os.path.dirname(__file__), "model_prices_and_context_window.json")
|
||||
|
||||
|
||||
@alru_cache(maxsize=1)
async def load_model_specs() -> dict:
    """Load the litellm model specifications from the JSON file.

    The result is cached (``maxsize=1``), so the file is read at most once
    per process and subsequent calls return the cached dict.

    Returns:
        dict: The model specifications data. An empty dict is returned when
            the file is missing, cannot be read, or contains invalid JSON;
            this function never raises for those conditions.
    """
    if not os.path.exists(MODEL_SPECS_PATH):
        logger.warning(f"Model specs file not found at {MODEL_SPECS_PATH}")
        return {}

    try:
        async with aiofiles.open(MODEL_SPECS_PATH, "r") as f:
            content = await f.read()
            return json.loads(content)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse model specs JSON: {e}")
        return {}
    except OSError as e:
        # The file can disappear (or become unreadable) between the
        # existence check above and the open; degrade to "no specs" just
        # like the missing-file case instead of propagating the error.
        logger.error(f"Failed to read model specs file: {e}")
        return {}
|
||||
|
||||
|
||||
async def get_model_spec(model_name: str) -> Optional[dict]:
    """Look up the specification entry for a single model.

    Args:
        model_name: The name of the model (e.g., "gpt-4o", "gpt-4o-mini")

    Returns:
        Optional[dict]: The model specification if found, None otherwise
    """
    all_specs = await load_model_specs()
    return all_specs.get(model_name)
|
||||
|
||||
|
||||
async def get_max_input_tokens(model_name: str) -> Optional[int]:
    """Return the max input tokens for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The max input tokens if found, None otherwise
    """
    spec = await get_model_spec(model_name)
    return spec.get("max_input_tokens") if spec else None
|
||||
|
||||
|
||||
async def get_max_output_tokens(model_name: str) -> Optional[int]:
    """Return the max output tokens for a model.

    Prefers the "max_output_tokens" field; a missing or falsy value falls
    back to "max_tokens".

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The max output tokens if found, None otherwise
    """
    if not (spec := await get_model_spec(model_name)):
        return None
    # `or` intentionally treats a falsy max_output_tokens as absent.
    return spec.get("max_output_tokens") or spec.get("max_tokens")
|
||||
|
||||
|
||||
async def get_context_window(model_name: str) -> Optional[int]:
    """Return the context window size for a model.

    For most models this is the max_input_tokens value, so the lookup is
    delegated entirely to get_max_input_tokens.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The context window size if found, None otherwise
    """
    window = await get_max_input_tokens(model_name)
    return window
|
||||
|
||||
|
||||
async def get_litellm_provider(model_name: str) -> Optional[str]:
    """Return the litellm provider for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[str]: The provider name if found, None otherwise
    """
    spec = await get_model_spec(model_name)
    return spec.get("litellm_provider") if spec else None
|
||||
31925
letta/model_specs/model_prices_and_context_window.json
Normal file
31925
letta/model_specs/model_prices_and_context_window.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,12 +42,24 @@ class OpenAIProvider(Provider):
|
||||
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
|
||||
|
||||
def get_default_max_output_tokens(self, model_name: str) -> int:
|
||||
"""Get the default max output tokens for OpenAI models."""
|
||||
if model_name.startswith("gpt-5"):
|
||||
"""Get the default max output tokens for OpenAI models (sync fallback)."""
|
||||
# Simple default for openai
|
||||
return 16384
|
||||
|
||||
async def get_default_max_output_tokens_async(self, model_name: str) -> int:
|
||||
"""Get the default max output tokens for OpenAI models.
|
||||
|
||||
Uses litellm model specifications with a simple fallback.
|
||||
"""
|
||||
from letta.model_specs.litellm_model_specs import get_max_output_tokens
|
||||
|
||||
# Try litellm specs
|
||||
max_output = await get_max_output_tokens(model_name)
|
||||
if max_output is not None:
|
||||
return max_output
|
||||
|
||||
# Simple default for openai
|
||||
return 16384
|
||||
elif model_name.startswith("o1") or model_name.startswith("o3"):
|
||||
return 100000
|
||||
return 16384 # default for openai
|
||||
|
||||
async def _get_models_async(self) -> list[dict]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
@@ -76,7 +88,7 @@ class OpenAIProvider(Provider):
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
data = await self._get_models_async()
|
||||
return self._list_llm_models(data)
|
||||
return await self._list_llm_models(data)
|
||||
|
||||
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
|
||||
"""Return known OpenAI embedding models.
|
||||
@@ -116,13 +128,13 @@ class OpenAIProvider(Provider):
|
||||
),
|
||||
]
|
||||
|
||||
def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
|
||||
async def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
|
||||
"""
|
||||
This handles filtering out LLM Models by provider that meet Letta's requirements.
|
||||
"""
|
||||
configs = []
|
||||
for model in data:
|
||||
check = self._do_model_checks_for_name_and_context_size(model)
|
||||
check = await self._do_model_checks_for_name_and_context_size_async(model)
|
||||
if check is None:
|
||||
continue
|
||||
model_name, context_window_size = check
|
||||
@@ -174,7 +186,7 @@ class OpenAIProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=handle,
|
||||
max_tokens=self.get_default_max_output_tokens(model_name),
|
||||
max_tokens=await self.get_default_max_output_tokens_async(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
@@ -188,12 +200,30 @@ class OpenAIProvider(Provider):
|
||||
return configs
|
||||
|
||||
def _do_model_checks_for_name_and_context_size(self, model: dict, length_key: str = "context_length") -> tuple[str, int] | None:
|
||||
"""Sync version - uses sync get_model_context_window_size (for subclasses with hardcoded values)."""
|
||||
if "id" not in model:
|
||||
logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
|
||||
return None
|
||||
|
||||
model_name = model["id"]
|
||||
context_window_size = model.get(length_key) or self.get_model_context_window_size(model_name)
|
||||
context_window_size = self.get_model_context_window_size(model_name)
|
||||
|
||||
if not context_window_size:
|
||||
logger.info("No context window size found for model: %s", model_name)
|
||||
return None
|
||||
|
||||
return model_name, context_window_size
|
||||
|
||||
async def _do_model_checks_for_name_and_context_size_async(
|
||||
self, model: dict, length_key: str = "context_length"
|
||||
) -> tuple[str, int] | None:
|
||||
"""Async version - uses async get_model_context_window_size_async (for litellm lookup)."""
|
||||
if "id" not in model:
|
||||
logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
|
||||
return None
|
||||
|
||||
model_name = model["id"]
|
||||
context_window_size = await self.get_model_context_window_size_async(model_name)
|
||||
|
||||
if not context_window_size:
|
||||
logger.info("No context window size found for model: %s", model_name)
|
||||
@@ -211,14 +241,25 @@ class OpenAIProvider(Provider):
|
||||
return llm_config
|
||||
|
||||
def get_model_context_window_size(self, model_name: str) -> int | None:
|
||||
if model_name in LLM_MAX_CONTEXT_WINDOW:
|
||||
return LLM_MAX_CONTEXT_WINDOW[model_name]
|
||||
else:
|
||||
"""Get the context window size for a model (sync fallback)."""
|
||||
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
|
||||
|
||||
async def get_model_context_window_size_async(self, model_name: str) -> int | None:
|
||||
"""Get the context window size for a model.
|
||||
|
||||
Uses litellm model specifications which covers all OpenAI models.
|
||||
"""
|
||||
from letta.model_specs.litellm_model_specs import get_context_window
|
||||
|
||||
context_window = await get_context_window(model_name)
|
||||
if context_window is not None:
|
||||
return context_window
|
||||
|
||||
# Simple fallback
|
||||
logger.debug(
|
||||
"Model %s on %s for provider %s not found in LLM_MAX_CONTEXT_WINDOW. Using default of {LLM_MAX_CONTEXT_WINDOW['DEFAULT']}",
|
||||
"Model %s not found in litellm specs. Using default of %s",
|
||||
model_name,
|
||||
self.base_url,
|
||||
self.__class__.__name__,
|
||||
LLM_MAX_CONTEXT_WINDOW["DEFAULT"],
|
||||
)
|
||||
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
|
||||
|
||||
@@ -226,4 +267,4 @@ class OpenAIProvider(Provider):
|
||||
return self.get_model_context_window_size(model_name)
|
||||
|
||||
async def get_model_context_window_async(self, model_name: str) -> int | None:
|
||||
return self.get_model_context_window_size(model_name)
|
||||
return await self.get_model_context_window_size_async(model_name)
|
||||
|
||||
@@ -75,6 +75,8 @@ dependencies = [
|
||||
"fastmcp>=2.12.5",
|
||||
"ddtrace>=4.2.1",
|
||||
"clickhouse-connect>=0.10.0",
|
||||
"aiofiles>=24.1.0",
|
||||
"async-lru>=2.0.5",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
4
uv.lock
generated
4
uv.lock
generated
@@ -2513,10 +2513,12 @@ name = "letta"
|
||||
version = "0.16.2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiomultiprocess" },
|
||||
{ name = "alembic" },
|
||||
{ name = "anthropic" },
|
||||
{ name = "apscheduler" },
|
||||
{ name = "async-lru" },
|
||||
{ name = "black", extra = ["jupyter"] },
|
||||
{ name = "brotli" },
|
||||
{ name = "certifi" },
|
||||
@@ -2661,12 +2663,14 @@ sqlite = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aioboto3", marker = "extra == 'bedrock'", specifier = ">=14.3.0" },
|
||||
{ name = "aiofiles", specifier = ">=24.1.0" },
|
||||
{ name = "aiomultiprocess", specifier = ">=0.9.1" },
|
||||
{ name = "aiosqlite", marker = "extra == 'desktop'", specifier = ">=0.21.0" },
|
||||
{ name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.21.0" },
|
||||
{ name = "alembic", specifier = ">=1.13.3" },
|
||||
{ name = "anthropic", specifier = ">=0.75.0" },
|
||||
{ name = "apscheduler", specifier = ">=3.11.0" },
|
||||
{ name = "async-lru", specifier = ">=2.0.5" },
|
||||
{ name = "async-lru", marker = "extra == 'desktop'", specifier = ">=2.0.5" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.30.0" },
|
||||
{ name = "black", extras = ["jupyter"], specifier = ">=24.2.0" },
|
||||
|
||||
Reference in New Issue
Block a user