Files
letta-server/letta/schemas/providers/google_gemini.py
2025-07-22 16:09:50 -07:00

103 lines
4.5 KiB
Python

import asyncio
from typing import Literal
from pydantic import Field
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LLM_MAX_TOKENS
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.base import Provider
class GoogleAIProvider(Provider):
    """Provider that lists Google AI (Gemini) LLM and embedding models.

    Model metadata is fetched from ``base_url`` with ``api_key`` through the
    helpers in ``letta.llm_api.google_ai_client``; those helpers are imported
    lazily inside each method.
    """

    provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str = Field(..., description="API key for the Google AI API.")
    base_url: str = "https://generativelanguage.googleapis.com"

    @staticmethod
    def _model_names_supporting(model_options, method: str) -> list[str]:
        """Return the names of models advertising ``method``, without the ``models/`` prefix.

        Args:
            model_options: Iterable of model description mappings; each is
                expected to carry ``"name"`` and ``"supportedGenerationMethods"``.
            method: Generation method to filter on (e.g. ``"generateContent"``).

        Entries that have no ``"supportedGenerationMethods"`` key are skipped
        instead of raising ``KeyError``.
        """
        return [
            str(entry["name"]).removeprefix("models/")
            for entry in model_options
            if method in entry.get("supportedGenerationMethods", ())
        ]

    async def check_api_key(self):
        """Validate ``self.api_key`` with the Google AI client helper.

        Any exception raised by the helper propagates to the caller.
        """
        from letta.llm_api.google_ai_client import google_ai_check_valid_api_key

        # The helper is a plain synchronous call (presumably a network
        # round-trip); run it in a worker thread so the event loop is not
        # blocked while it completes.
        await asyncio.to_thread(google_ai_check_valid_api_key, self.api_key)

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """Return an ``LLMConfig`` for every ``gemini-*`` model that supports ``generateContent``.

        Context windows for all models are resolved concurrently.
        """
        from letta.llm_api.google_ai_client import google_ai_get_model_list_async

        raw_models = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)

        # Keep only text-generation models from the Gemini family.
        model_names = [
            name for name in self._model_names_supporting(raw_models, "generateContent") if name.startswith("gemini-")
        ]

        async def create_config(model: str) -> LLMConfig:
            context_window = await self.get_model_context_window_async(model)
            return LLMConfig(
                model=model,
                model_endpoint_type="google_ai",
                model_endpoint=self.base_url,
                context_window=context_window,
                handle=self.get_handle(model),
                # max_tokens is fixed at 8192 for every listed model.
                max_tokens=8192,
                provider_name=self.name,
                provider_category=self.provider_category,
            )

        # gather() preserves input order and returns a list of the results.
        return await asyncio.gather(*(create_config(model) for model in model_names))

    async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
        """Fetch the model list and return an ``EmbeddingConfig`` per embedding-capable model."""
        from letta.llm_api.google_ai_client import google_ai_get_model_list_async

        model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
        return self._list_embedding_models(model_options)

    def _list_embedding_models(self, model_options) -> list[EmbeddingConfig]:
        """Build an ``EmbeddingConfig`` for each entry of ``model_options`` that supports ``embedContent``."""
        return [
            EmbeddingConfig(
                embedding_model=model,
                embedding_endpoint_type="google_ai",
                embedding_endpoint=self.base_url,
                # NOTE(review): the dimension is hard-coded for every model — confirm
                # it holds for all embedding models the API returns.
                embedding_dim=768,
                embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,  # NOTE: max is 2048
                handle=self.get_handle(model, is_embedding=True),
                batch_size=1024,
            )
            for model in self._model_names_supporting(model_options, "embedContent")
        ]

    def get_model_context_window(self, model_name: str) -> int | None:
        """Return the context window for ``model_name`` (deprecated synchronous variant).

        A value in ``LLM_MAX_TOKENS`` takes precedence; otherwise the Google AI
        client helper is queried. Emits a ``DeprecationWarning`` on every call.
        """
        import warnings

        # stacklevel=2 attributes the warning to the caller rather than this method.
        warnings.warn(
            "This is deprecated, use get_model_context_window_async when possible.",
            DeprecationWarning,
            stacklevel=2,
        )
        from letta.llm_api.google_ai_client import google_ai_get_model_context_window

        if model_name in LLM_MAX_TOKENS:
            return LLM_MAX_TOKENS[model_name]
        return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)

    async def get_model_context_window_async(self, model_name: str) -> int | None:
        """Return the context window for ``model_name``.

        A value in ``LLM_MAX_TOKENS`` takes precedence; otherwise the async
        Google AI client helper is awaited.
        """
        from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async

        if model_name in LLM_MAX_TOKENS:
            return LLM_MAX_TOKENS[model_name]
        return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)