fix: max tokens and context window size [LET-6481] (#8298)

* fix: max tokens [LET-6481]

* remove print statements

* update

* simplify fallback

* address comments async

* update other helpers

* update pyproject.toml

* update pyproject w async lru

* openai internal async methods

* update

* update uv lock
This commit is contained in:
Christina Tong
2026-01-20 13:35:07 -08:00
committed by Caren Thomas
parent 238894eebd
commit 0333ff0614
6 changed files with 32115 additions and 22 deletions

View File

@@ -0,0 +1 @@
"""Model specification utilities for Letta."""

View File

@@ -0,0 +1,120 @@
"""
Utility functions for working with litellm model specifications.
This module provides access to model specifications from the litellm model_prices_and_context_window.json file.
The data is synced from: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
"""
import json
import os
from typing import Optional
import aiofiles
from async_lru import alru_cache
from letta.log import get_logger
logger = get_logger(__name__)
# Path to the litellm model specs JSON file
MODEL_SPECS_PATH = os.path.join(os.path.dirname(__file__), "model_prices_and_context_window.json")
@alru_cache(maxsize=1)
async def load_model_specs() -> dict:
    """Load the litellm model specifications from the JSON file.

    The result is cached for the lifetime of the process (``alru_cache``
    with ``maxsize=1``), so the file is read at most once.

    Returns:
        dict: The model specifications data, or an empty dict if the file
            is missing or contains invalid JSON (a warning/error is logged
            instead of raising, so callers never have to handle exceptions).
    """
    # EAFP: open the file directly and handle a missing file, instead of
    # racing an os.path.exists() check against the subsequent open().
    try:
        async with aiofiles.open(MODEL_SPECS_PATH, "r") as f:
            content = await f.read()
        return json.loads(content)
    except FileNotFoundError:
        logger.warning(f"Model specs file not found at {MODEL_SPECS_PATH}")
        return {}
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse model specs JSON: {e}")
        return {}
async def get_model_spec(model_name: str) -> Optional[dict]:
    """Look up the litellm specification entry for a single model.

    Args:
        model_name: The name of the model (e.g., "gpt-4o", "gpt-4o-mini")

    Returns:
        Optional[dict]: The model specification if found, None otherwise
    """
    all_specs = await load_model_specs()
    return all_specs.get(model_name)
async def get_max_input_tokens(model_name: str) -> Optional[int]:
    """Return the maximum number of input tokens a model accepts.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The max input tokens if found, None otherwise
    """
    model_spec = await get_model_spec(model_name)
    # Missing model or missing field both collapse to None for the caller.
    return model_spec.get("max_input_tokens") if model_spec else None
async def get_max_output_tokens(model_name: str) -> Optional[int]:
    """Return the maximum number of output tokens a model can generate.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The max output tokens if found, None otherwise
    """
    model_spec = await get_model_spec(model_name)
    if not model_spec:
        return None
    # Prefer the explicit max_output_tokens field; some spec entries only
    # carry max_tokens, so fall back to that when the first is falsy.
    primary = model_spec.get("max_output_tokens")
    if primary:
        return primary
    return model_spec.get("max_tokens")
async def get_context_window(model_name: str) -> Optional[int]:
    """Return the context window size for a model.

    In the litellm specs the context window is reported as the model's
    max input tokens, so this simply delegates to that lookup.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The context window size if found, None otherwise
    """
    window = await get_max_input_tokens(model_name)
    return window
async def get_litellm_provider(model_name: str) -> Optional[str]:
    """Return the litellm provider name recorded for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[str]: The provider name if found, None otherwise
    """
    model_spec = await get_model_spec(model_name)
    # An absent model (or an entry without the field) yields None.
    return model_spec.get("litellm_provider") if model_spec else None

File diff suppressed because it is too large Load Diff

View File

@@ -42,12 +42,24 @@ class OpenAIProvider(Provider):
raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
def get_default_max_output_tokens(self, model_name: str) -> int:
"""Get the default max output tokens for OpenAI models."""
if model_name.startswith("gpt-5"):
"""Get the default max output tokens for OpenAI models (sync fallback)."""
# Simple default for openai
return 16384
async def get_default_max_output_tokens_async(self, model_name: str) -> int:
"""Get the default max output tokens for OpenAI models.
Uses litellm model specifications with a simple fallback.
"""
from letta.model_specs.litellm_model_specs import get_max_output_tokens
# Try litellm specs
max_output = await get_max_output_tokens(model_name)
if max_output is not None:
return max_output
# Simple default for openai
return 16384
elif model_name.startswith("o1") or model_name.startswith("o3"):
return 100000
return 16384 # default for openai
async def _get_models_async(self) -> list[dict]:
from letta.llm_api.openai import openai_get_model_list_async
@@ -76,7 +88,7 @@ class OpenAIProvider(Provider):
async def list_llm_models_async(self) -> list[LLMConfig]:
data = await self._get_models_async()
return self._list_llm_models(data)
return await self._list_llm_models(data)
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
"""Return known OpenAI embedding models.
@@ -116,13 +128,13 @@ class OpenAIProvider(Provider):
),
]
def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
async def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
"""
This handles filtering out LLM Models by provider that meet Letta's requirements.
"""
configs = []
for model in data:
check = self._do_model_checks_for_name_and_context_size(model)
check = await self._do_model_checks_for_name_and_context_size_async(model)
if check is None:
continue
model_name, context_window_size = check
@@ -174,7 +186,7 @@ class OpenAIProvider(Provider):
model_endpoint=self.base_url,
context_window=context_window_size,
handle=handle,
max_tokens=self.get_default_max_output_tokens(model_name),
max_tokens=await self.get_default_max_output_tokens_async(model_name),
provider_name=self.name,
provider_category=self.provider_category,
)
@@ -188,12 +200,30 @@ class OpenAIProvider(Provider):
return configs
def _do_model_checks_for_name_and_context_size(self, model: dict, length_key: str = "context_length") -> tuple[str, int] | None:
"""Sync version - uses sync get_model_context_window_size (for subclasses with hardcoded values)."""
if "id" not in model:
logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
return None
model_name = model["id"]
context_window_size = model.get(length_key) or self.get_model_context_window_size(model_name)
context_window_size = self.get_model_context_window_size(model_name)
if not context_window_size:
logger.info("No context window size found for model: %s", model_name)
return None
return model_name, context_window_size
async def _do_model_checks_for_name_and_context_size_async(
self, model: dict, length_key: str = "context_length"
) -> tuple[str, int] | None:
"""Async version - uses async get_model_context_window_size_async (for litellm lookup)."""
if "id" not in model:
logger.warning("Model missing 'id' field for provider: %s and model: %s", self.provider_type, model)
return None
model_name = model["id"]
context_window_size = await self.get_model_context_window_size_async(model_name)
if not context_window_size:
logger.info("No context window size found for model: %s", model_name)
@@ -211,14 +241,25 @@ class OpenAIProvider(Provider):
return llm_config
def get_model_context_window_size(self, model_name: str) -> int | None:
if model_name in LLM_MAX_CONTEXT_WINDOW:
return LLM_MAX_CONTEXT_WINDOW[model_name]
else:
"""Get the context window size for a model (sync fallback)."""
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
async def get_model_context_window_size_async(self, model_name: str) -> int | None:
"""Get the context window size for a model.
Uses litellm model specifications which covers all OpenAI models.
"""
from letta.model_specs.litellm_model_specs import get_context_window
context_window = await get_context_window(model_name)
if context_window is not None:
return context_window
# Simple fallback
logger.debug(
"Model %s on %s for provider %s not found in LLM_MAX_CONTEXT_WINDOW. Using default of {LLM_MAX_CONTEXT_WINDOW['DEFAULT']}",
"Model %s not found in litellm specs. Using default of %s",
model_name,
self.base_url,
self.__class__.__name__,
LLM_MAX_CONTEXT_WINDOW["DEFAULT"],
)
return LLM_MAX_CONTEXT_WINDOW["DEFAULT"]
@@ -226,4 +267,4 @@ class OpenAIProvider(Provider):
return self.get_model_context_window_size(model_name)
async def get_model_context_window_async(self, model_name: str) -> int | None:
return self.get_model_context_window_size(model_name)
return await self.get_model_context_window_size_async(model_name)

View File

@@ -75,6 +75,8 @@ dependencies = [
"fastmcp>=2.12.5",
"ddtrace>=4.2.1",
"clickhouse-connect>=0.10.0",
"aiofiles>=24.1.0",
"async-lru>=2.0.5",
]
[project.scripts]

4
uv.lock generated
View File

@@ -2513,10 +2513,12 @@ name = "letta"
version = "0.16.2"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },
{ name = "aiomultiprocess" },
{ name = "alembic" },
{ name = "anthropic" },
{ name = "apscheduler" },
{ name = "async-lru" },
{ name = "black", extra = ["jupyter"] },
{ name = "brotli" },
{ name = "certifi" },
@@ -2661,12 +2663,14 @@ sqlite = [
[package.metadata]
requires-dist = [
{ name = "aioboto3", marker = "extra == 'bedrock'", specifier = ">=14.3.0" },
{ name = "aiofiles", specifier = ">=24.1.0" },
{ name = "aiomultiprocess", specifier = ">=0.9.1" },
{ name = "aiosqlite", marker = "extra == 'desktop'", specifier = ">=0.21.0" },
{ name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.21.0" },
{ name = "alembic", specifier = ">=1.13.3" },
{ name = "anthropic", specifier = ">=0.75.0" },
{ name = "apscheduler", specifier = ">=3.11.0" },
{ name = "async-lru", specifier = ">=2.0.5" },
{ name = "async-lru", marker = "extra == 'desktop'", specifier = ">=2.0.5" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.30.0" },
{ name = "black", extras = ["jupyter"], specifier = ">=24.2.0" },