feat: list out embedding models for Google AI provider (#1839)

Author: Sarah Wooders
Date: 2024-10-08 11:28:24 -07:00
Committed by: GitHub
Parent: 5e294158af
Commit: 91287a76c9
3 changed files with 38 additions and 10 deletions

View File

@@ -4,6 +4,7 @@ env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
   ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+  GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
 
 on:
   push:
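The same key has to be present locally when running the provider tests outside CI. A minimal guard, assuming the GEMINI_API_KEY variable name used above:

    import os

    # Fail fast if the key the workflow injects from secrets is missing locally
    if os.getenv("GEMINI_API_KEY") is None:
        raise RuntimeError("GEMINI_API_KEY is not set; export it before running the Google AI tests")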

View File

@@ -182,7 +182,7 @@ class GroqProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     api_key: str = Field(..., description="API key for the Google AI API.")
-    service_endpoint: str = "generativelanguage"
+    service_endpoint: str = "generativelanguage"  # TODO: remove once old functions are refactored to just use base_url
     base_url: str = "https://generativelanguage.googleapis.com"
 
     def list_llm_models(self):
@@ -190,12 +190,15 @@ class GoogleAIProvider(Provider):
 
         # TODO: use base_url instead
         model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        # filter by 'generateContent' models
+        model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
+        # filter by model names
         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
 
         # TODO remove manual filtering for gemini-pro
         model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
         # TODO: add context windows
         # model_options = ["gemini-pro"]
         configs = []
         for model in model_options:
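For reference, the filtering in this hunk assumes the Google AI models.list response shape, where each entry carries a "name" like "models/gemini-1.5-pro" and a "supportedGenerationMethods" list. A standalone sketch of the same pipeline against mocked data (the model names here are illustrative):

    # Mocked subset of a Google AI models.list response (shape assumed from the filters above)
    model_options = [
        {"name": "models/gemini-1.5-pro", "supportedGenerationMethods": ["generateContent"]},
        {"name": "models/text-embedding-004", "supportedGenerationMethods": ["embedContent"]},
        {"name": "models/gemini-1.0-pro", "supportedGenerationMethods": ["generateContent"]},
    ]

    # Same steps as list_llm_models: keep chat-capable models, strip the "models/" prefix,
    # then keep only the gemini "-pro" variants
    model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
    model_options = [str(m["name"]) for m in model_options]
    model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
    model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]

    print(model_options)  # ['gemini-1.5-pro', 'gemini-1.0-pro']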
@@ -210,7 +213,27 @@ class GoogleAIProvider(Provider):
         return configs
 
     def list_embedding_models(self):
-        return []
+        from letta.llm_api.google_ai import google_ai_get_model_list
+
+        # TODO: use base_url instead
+        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        # filter by 'embedContent' models
+        model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
+        model_options = [str(m["name"]) for m in model_options]
+        model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+        configs = []
+        for model in model_options:
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model,
+                    embedding_endpoint_type="google_ai",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs
 
     def get_model_context_window(self, model_name: str):
         from letta.llm_api.google_ai import google_ai_get_model_context_window
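With list_embedding_models implemented, embedding models can be enumerated the same way as LLM models. A minimal usage sketch, assuming GEMINI_API_KEY is exported (this mirrors the test change below):

    import os

    from letta.providers import GoogleAIProvider

    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
    for config in provider.list_embedding_models():
        # each EmbeddingConfig points embedding_endpoint at the provider's base_url
        print(config.embedding_model, config.embedding_dim)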

View File

@@ -1,6 +1,6 @@
 import os
 
-from letta.providers import AnthropicProvider, OpenAIProvider
+from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider
 
 
 def test_openai():
@@ -30,13 +30,17 @@ def test_anthropic():
 #     print(models)
 #
 #
-# def test_googleai():
-#     provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
-#     models = provider.list_llm_models()
-#     print(models)
+def test_googleai():
+    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
+    models = provider.list_llm_models()
+    print(models)
+
+    provider.list_embedding_models()
+
+
 #
 #
-# test_googleai()
+test_googleai()
 # test_ollama()
 # test_groq()
 # test_openai()