diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 779233e9..24a9b7db 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,6 +4,7 @@ env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} on: push: diff --git a/letta/providers.py b/letta/providers.py index ccb6c97c..ac617062 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -182,7 +182,7 @@ class GroqProvider(OpenAIProvider): class GoogleAIProvider(Provider): # gemini api_key: str = Field(..., description="API key for the Google AI API.") - service_endpoint: str = "generativelanguage" + service_endpoint: str = "generativelanguage" # TODO: remove once old functions are refactored to just use base_url base_url: str = "https://generativelanguage.googleapis.com" def list_llm_models(self): @@ -190,12 +190,15 @@ class GoogleAIProvider(Provider): # TODO: use base_url instead model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key) + # filter by 'generateContent' models + model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] + + # filter by model names model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + # TODO remove manual filtering for gemini-pro model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)] - # TODO: add context windows - # model_options = ["gemini-pro"] configs = [] for model in model_options: @@ -210,7 +213,27 @@ class GoogleAIProvider(Provider): return configs def list_embedding_models(self): - return [] + from letta.llm_api.google_ai import google_ai_get_model_list + + # TODO: use base_url instead + model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key) + # filter by 'generateContent' models + model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + configs = [] + for model in model_options: + configs.append( + EmbeddingConfig( + embedding_model=model, + embedding_endpoint_type="google_ai", + embedding_endpoint=self.base_url, + embedding_dim=768, + embedding_chunk_size=300, # NOTE: max is 2048 + ) + ) + return configs def get_model_context_window(self, model_name: str): from letta.llm_api.google_ai import google_ai_get_model_context_window diff --git a/tests/test_providers.py b/tests/test_providers.py index 94058b37..fecacd79 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,6 +1,6 @@ import os -from letta.providers import AnthropicProvider, OpenAIProvider +from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider def test_openai(): @@ -30,13 +30,17 @@ def test_anthropic(): # print(models) # # -# def test_googleai(): -# provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY")) -# models = provider.list_llm_models() -# print(models) +def test_googleai(): + provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY")) + models = provider.list_llm_models() + print(models) + + provider.list_embedding_models() + + # # -# test_googleai() +test_googleai() # test_ollama() # test_groq() # test_openai()