61 lines
3.0 KiB
Python
61 lines
3.0 KiB
Python
from typing import Literal
|
|
|
|
from pydantic import Field
|
|
|
|
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import ProviderCategory, ProviderType
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.providers.base import Provider
|
|
|
|
|
|
# TODO (cliandy): GoogleVertexProvider uses hardcoded models vs Gemini fetches from API
|
|
class GoogleVertexProvider(Provider):
    """Provider backed by the Google Vertex AI API.

    Unlike the Gemini provider, the available models are not fetched from the
    API: both the LLM and embedding listings come from hardcoded constants in
    ``letta.llm_api.google_constants`` (see TODO above the class).
    """

    # Discriminator pinning this provider record to the google_vertex type.
    provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
    google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")

    def _base_endpoint(self) -> str:
        """Build the regional Vertex AI endpoint URL for this project/location.

        Shared by both model listings so the URL cannot drift between them.
        """
        return (
            f"https://{self.google_cloud_location}-aiplatform.googleapis.com"
            f"/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}"
        )

    def get_default_max_output_tokens(self, model_name: str) -> int:
        """Get the default max output tokens for Google Vertex models.

        Newer Gemini generations (names containing "2.5"/"2-5", or starting
        with "gemini-3") get a 64K output budget; everything else gets 8K.
        """
        if "2.5" in model_name or "2-5" in model_name or model_name.startswith("gemini-3"):
            return 65536
        return 8192  # default for google vertex

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """Return an LLMConfig for every hardcoded Google Vertex chat model."""
        from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH

        endpoint = self._base_endpoint()
        return [
            LLMConfig(
                model=model,
                model_endpoint_type="google_vertex",
                model_endpoint=endpoint,
                context_window=context_length,
                handle=self.get_handle(model),
                max_tokens=self.get_default_max_output_tokens(model),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
            for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items()
        ]

    async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
        """Return an EmbeddingConfig for every hardcoded Google Vertex embedding model."""
        # NOTE: constant name is misspelled upstream ("EMBEDING") — keep as-is.
        from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM

        endpoint = self._base_endpoint()
        return [
            EmbeddingConfig(
                embedding_model=model,
                embedding_endpoint_type="google_vertex",
                embedding_endpoint=endpoint,
                embedding_dim=dim,
                embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,  # NOTE: max is 2048
                handle=self.get_handle(model, is_embedding=True),
                batch_size=1024,
            )
            for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items()
        ]
|