Files
letta-server/letta/model_specs/litellm_model_specs.py
Christina Tong 0333ff0614 fix: max tokens and context window size [LET-6481] (#8298)
* fix: max tokens [LET-6481]

* remove print statements

* update

* simplify fallback

* address comments async

* update other helpers

* update pyproject.toml

* update pyproject w async lru

* OpenAI internal async methods

* update

* update uv lock
2026-01-29 12:43:53 -08:00

121 lines
3.2 KiB
Python

"""
Utility functions for working with litellm model specifications.
This module provides access to model specifications from the litellm model_prices_and_context_window.json file.
The data is synced from: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
"""
import json
import os
from typing import Optional
import aiofiles
from async_lru import alru_cache
from letta.log import get_logger
# Module-level logger obtained from letta's logging helper, named after this module.
logger = get_logger(__name__)
# Path to the litellm model specs JSON file; it is resolved relative to the
# directory containing this module, so the JSON is expected to sit next to it.
MODEL_SPECS_PATH = os.path.join(os.path.dirname(__file__), "model_prices_and_context_window.json")
@alru_cache(maxsize=1)
async def load_model_specs() -> dict:
    """Load the litellm model specifications from the JSON file.

    The coroutine is wrapped in ``alru_cache(maxsize=1)``, so a returned value
    (including the empty-dict fallbacks below) is reused by later calls rather
    than re-reading the file.

    Failures are logged and reported as an empty dict instead of being raised,
    so callers can treat "no specs available" and "model not listed" alike.

    Returns:
        dict: The parsed model specifications, or an empty dict when the file
        is missing, is not valid JSON, or its top-level value is not an object.
    """
    try:
        # Open directly instead of checking os.path.exists first: that avoids a
        # blocking stat call and the race between the check and the open.
        # JSON is read as UTF-8 explicitly so parsing does not depend on the
        # platform's default locale encoding.
        async with aiofiles.open(MODEL_SPECS_PATH, "r", encoding="utf-8") as f:
            content = await f.read()
    except FileNotFoundError:
        logger.warning(f"Model specs file not found at {MODEL_SPECS_PATH}")
        return {}

    try:
        specs = json.loads(content)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse model specs JSON: {e}")
        return {}

    # Valid JSON need not be an object; the lookup helpers rely on dict
    # semantics, so anything else is treated as an unusable specs file.
    if not isinstance(specs, dict):
        logger.error(f"Model specs JSON must be an object, got {type(specs).__name__}")
        return {}
    return specs
async def get_model_spec(model_name: str) -> Optional[dict]:
    """Look up a single model's entry in the loaded specifications.

    Args:
        model_name: Key to look up in the specs mapping (e.g., "gpt-4o", "gpt-4o-mini")

    Returns:
        Optional[dict]: The entry stored under ``model_name``, or None when the
        loaded specs have no such key
    """
    # Awaiting load_model_specs() yields the specs mapping; .get() gives None
    # for unknown names instead of raising.
    return (await load_model_specs()).get(model_name)
async def get_max_input_tokens(model_name: str) -> Optional[int]:
    """Return the "max_input_tokens" value recorded for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The value under "max_input_tokens", or None when the
        model has no (or an empty) spec or the key is absent
    """
    model_entry = await get_model_spec(model_name)
    # A missing or empty spec short-circuits to None, as does a missing key.
    return model_entry.get("max_input_tokens") if model_entry else None
async def get_max_output_tokens(model_name: str) -> Optional[int]:
    """Return the output-token limit recorded for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The "max_output_tokens" value when it is truthy,
        otherwise whatever is stored under "max_tokens" (None if absent);
        None when the model has no (or an empty) spec
    """
    model_entry = await get_model_spec(model_name)
    if model_entry:
        preferred = model_entry.get("max_output_tokens")
        if preferred:
            return preferred
        # "max_output_tokens" missing or falsy: fall back to "max_tokens".
        return model_entry.get("max_tokens")
    return None
async def get_context_window(model_name: str) -> Optional[int]:
    """Return the context window size for a model.

    This helper reports the same value as ``get_max_input_tokens``.

    Args:
        model_name: The name of the model

    Returns:
        Optional[int]: The result of ``get_max_input_tokens`` for this model,
        which may be None
    """
    window = await get_max_input_tokens(model_name)
    return window
async def get_litellm_provider(model_name: str) -> Optional[str]:
    """Return the "litellm_provider" value recorded for a model.

    Args:
        model_name: The name of the model

    Returns:
        Optional[str]: The value under "litellm_provider", or None when the
        model has no (or an empty) spec or the key is absent
    """
    model_entry = await get_model_spec(model_name)
    # A missing or empty spec short-circuits to None, as does a missing key.
    return model_entry.get("litellm_provider") if model_entry else None