diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c15d1a90..de9a4792 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -5,6 +5,7 @@ env:
   COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
   ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
   GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
+  GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
 
 on:
   push:
diff --git a/letta/client/client.py b/letta/client/client.py
index 6d02d74d..30945530 100644
--- a/letta/client/client.py
+++ b/letta/client/client.py
@@ -276,10 +276,10 @@ class AbstractClient(object):
     ) -> List[Message]:
         raise NotImplementedError
 
-    def list_models(self) -> List[LLMConfig]:
+    def list_model_configs(self) -> List[LLMConfig]:
         raise NotImplementedError
 
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_configs(self) -> List[EmbeddingConfig]:
         raise NotImplementedError
 
 
@@ -1234,32 +1234,6 @@ class RESTClient(AbstractClient):
         assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
         return Source(**response.json())
 
-    # server configuration commands
-
-    def list_models(self):
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list models: {response.text}")
-        return [LLMConfig(**model) for model in response.json()]
-
-    def list_embedding_models(self):
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list embedding models: {response.text}")
-        return [EmbeddingConfig(**model) for model in response.json()]
-
     # tools
 
     def get_tool_id(self, tool_name: str):
@@ -2572,24 +2546,6 @@ class LocalClient(AbstractClient):
             return_message_object=True,
         )
 
-    def list_models(self) -> List[LLMConfig]:
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        return self.server.list_models()
-
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        return [self.server.server_embedding_config]
-
     def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
         """
         List available blocks
diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 7408b25b..7af35721 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -313,7 +313,6 @@ def create(
             stream_interface.stream_start()
         try:
             # groq uses the openai chat completions API, so this component should be reusable
-            assert model_settings.groq_api_key is not None, "Groq key is missing"
             response = openai_chat_completions_request(
                 url=llm_config.model_endpoint,
                 api_key=model_settings.groq_api_key,
diff --git a/letta/providers.py b/letta/providers.py
index c4b878b6..6fa98327 100644
--- a/letta/providers.py
+++ b/letta/providers.py
@@ -313,7 +313,7 @@ class GroqProvider(OpenAIProvider):
                 continue
             configs.append(
                 LLMConfig(
-                    model=model["id"], model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["context_window"]
+                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
                 )
             )
         return configs
diff --git a/letta/server/server.py b/letta/server/server.py
index 00e1dc37..5ebd77b3 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -43,6 +43,7 @@ from letta.providers import (
     AnthropicProvider,
     AzureProvider,
     GoogleAIProvider,
+    GroqProvider,
     LettaProvider,
     OllamaProvider,
     OpenAIProvider,
@@ -298,6 +299,8 @@ class SyncServer(Server):
                     api_version=model_settings.azure_api_version,
                 )
             )
+        if model_settings.groq_api_key:
+            self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
         if model_settings.vllm_api_base:
             # vLLM exposes both a /chat/completions and a /completions endpoint
             self._enabled_providers.append(
diff --git a/tests/test_client.py b/tests/test_client.py
index 718f6045..3fd015c5 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -2,7 +2,7 @@ import os
 import threading
 import time
 import uuid
-from typing import Union
+from typing import List, Union
 
 import pytest
 from dotenv import load_dotenv
@@ -18,6 +18,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
+from letta.settings import model_settings
 from tests.helpers.client_helper import upload_file_using_client
 
 # from tests.utils import create_config
@@ -262,21 +263,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
     assert human.value == "Human text", "Creating human failed"
 
 
-def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
-    # _reset_config()
-
-    models_response = client.list_models()
-    print("MODELS", models_response)
-
-    embeddings_response = client.list_embedding_models()
-    print("EMBEDDINGS", embeddings_response)
-
-    # TODO: add back
-    # config_response = client.get_config()
-    # TODO: ensure config is the same as the one in the server
-    # print("CONFIG", config_response)
-
-
 def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
     tools = client.list_tools()
     visited_ids = {t.id: False for t in tools}
@@ -503,11 +489,20 @@ def test_organization(client: RESTClient):
         pytest.skip("Skipping test_organization because LocalClient does not support organizations")
 
 
-def test_model_configs(client: Union[LocalClient, RESTClient]):
-    # _reset_config()
+def test_list_llm_models(client: RESTClient):
+    """Test that if the user's env has the right api keys set, at least one model appears in the model list"""
 
-    model_configs = client.list_models()
-    print("MODEL CONFIGS", model_configs)
+    def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
+        return any(model.model_endpoint_type == target_type for model in models)
 
-    embedding_configs = client.list_embedding_models()
-    print("EMBEDDING CONFIGS", embedding_configs)
+    models = client.list_llm_configs()
+    if model_settings.groq_api_key:
+        assert has_model_endpoint_type(models, "groq")
+    if model_settings.azure_api_key:
+        assert has_model_endpoint_type(models, "azure")
+    if model_settings.openai_api_key:
+        assert has_model_endpoint_type(models, "openai")
+    if model_settings.gemini_api_key:
+        assert has_model_endpoint_type(models, "google_ai")
+    if model_settings.anthropic_api_key:
+        assert has_model_endpoint_type(models, "anthropic")
diff --git a/tests/test_providers.py b/tests/test_providers.py
index e1e15be2..21f0c9ff 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -4,6 +4,7 @@ from letta.providers import (
     AnthropicProvider,
     AzureProvider,
     GoogleAIProvider,
+    GroqProvider,
     MistralProvider,
     OllamaProvider,
     OpenAIProvider,
@@ -27,12 +28,10 @@ def test_anthropic():
     print(models)
 
 
-# def test_groq():
-#     provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
-#     models = provider.list_llm_models()
-#     print(models)
-#
-#
+def test_groq():
+    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
+    models = provider.list_llm_models()
+    print(models)
 
 
 def test_azure():