feat: list out embedding models for Google AI provider (#1839)
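Note: the sketch below shows how the new method is expected to be called once this change lands; it mirrors the test added at the bottom of this diff. GoogleAIProvider, list_embedding_models, and the GEMINI_API_KEY variable come from the diff itself; everything else (the loop and the printed fields) is illustrative only.

import os

from letta.providers import GoogleAIProvider

# Assumes a Gemini API key is exported as GEMINI_API_KEY, as in the CI workflow below.
provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))

# New in this commit: one EmbeddingConfig per Google AI model that supports "embedContent".
for config in provider.list_embedding_models():
    print(config.embedding_model, config.embedding_dim)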
.github/workflows/tests.yml
@@ -4,6 +4,7 @@ env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
  ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+  GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}

on:
  push:

letta/providers.py
@@ -182,7 +182,7 @@ class GroqProvider(OpenAIProvider):
class GoogleAIProvider(Provider):
    # gemini
    api_key: str = Field(..., description="API key for the Google AI API.")
-    service_endpoint: str = "generativelanguage"
+    service_endpoint: str = "generativelanguage"  # TODO: remove once old functions are refactored to just use base_url
+    base_url: str = "https://generativelanguage.googleapis.com"

    def list_llm_models(self):
@@ -190,12 +190,15 @@ class GoogleAIProvider(Provider):

        # TODO: use base_url instead
        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
        # filter by 'generateContent' models
        model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
        model_options = [str(m["name"]) for m in model_options]

        # filter by model names
        model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]

        # TODO remove manual filtering for gemini-pro
        model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
        # TODO: add context windows
        # model_options = ["gemini-pro"]

        configs = []
        for model in model_options:
@@ -210,7 +213,27 @@ class GoogleAIProvider(Provider):
        return configs

    def list_embedding_models(self):
-        return []
+        from letta.llm_api.google_ai import google_ai_get_model_list
+
+        # TODO: use base_url instead
+        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        # filter by 'embedContent' models
+        model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
+        model_options = [str(m["name"]) for m in model_options]
+        model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+        configs = []
+        for model in model_options:
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model,
+                    embedding_endpoint_type="google_ai",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs

    def get_model_context_window(self, model_name: str):
        from letta.llm_api.google_ai import google_ai_get_model_context_window
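For context, list_embedding_models relies on google_ai_get_model_list, which wraps Google's Generative Language API ListModels endpoint; that endpoint is the source of the supportedGenerationMethods field filtered on above. Below is a rough standalone sketch of that call using only the requests library; the function name and the skipped pagination are assumptions for illustration, not letta's actual implementation.

import os

import requests


def list_google_embedding_models(base_url: str, api_key: str) -> list[str]:
    # ListModels endpoint of the Generative Language API; ignores pagination (pageToken) for brevity.
    resp = requests.get(f"{base_url}/v1beta/models", params={"key": api_key})
    resp.raise_for_status()
    models = resp.json().get("models", [])
    # Same filtering idea as the diff: keep models that support "embedContent"
    # and strip the "models/" prefix from their resource names.
    return [
        m["name"].removeprefix("models/")
        for m in models
        if "embedContent" in m.get("supportedGenerationMethods", [])
    ]


print(list_google_embedding_models("https://generativelanguage.googleapis.com", os.environ["GEMINI_API_KEY"]))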

tests/test_providers.py
@@ -1,6 +1,6 @@
import os

-from letta.providers import AnthropicProvider, OpenAIProvider
+from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider


def test_openai():
@@ -30,13 +30,17 @@ def test_anthropic():
    # print(models)
#
#
-# def test_googleai():
-#     provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
-#     models = provider.list_llm_models()
-#     print(models)
+def test_googleai():
+    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
+    models = provider.list_llm_models()
+    print(models)
+
+    provider.list_embedding_models()
+
+
#
#
-# test_googleai()
+test_googleai()
# test_ollama()
# test_groq()
# test_openai()