fix: Clean up some legacy code and fix Groq provider (#1950)

Matthew Zhou
2024-10-28 14:13:11 -07:00
committed by GitHub
parent 969300fb56
commit 9439e5e255
7 changed files with 29 additions and 76 deletions

View File

@@ -5,6 +5,7 @@ env:
   COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
   ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
   GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
+  GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
 on:
   push:

View File

@@ -276,10 +276,10 @@ class AbstractClient(object):
     ) -> List[Message]:
         raise NotImplementedError

-    def list_models(self) -> List[LLMConfig]:
+    def list_model_configs(self) -> List[LLMConfig]:
         raise NotImplementedError

-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_configs(self) -> List[EmbeddingConfig]:
         raise NotImplementedError
@@ -1234,32 +1234,6 @@ class RESTClient(AbstractClient):
         assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
         return Source(**response.json())

-    # server configuration commands
-
-    def list_models(self):
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list models: {response.text}")
-        return [LLMConfig(**model) for model in response.json()]
-
-    def list_embedding_models(self):
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
-        if response.status_code != 200:
-            raise ValueError(f"Failed to list embedding models: {response.text}")
-        return [EmbeddingConfig(**model) for model in response.json()]
-
     # tools
     def get_tool_id(self, tool_name: str):
@@ -2572,24 +2546,6 @@ class LocalClient(AbstractClient):
             return_message_object=True,
         )

-    def list_models(self) -> List[LLMConfig]:
-        """
-        List available LLM models
-
-        Returns:
-            models (List[LLMConfig]): List of LLM models
-        """
-        return self.server.list_models()
-
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
-        """
-        List available embedding models
-
-        Returns:
-            models (List[EmbeddingConfig]): List of embedding models
-        """
-        return [self.server.server_embedding_config]
-
     def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
         """
         List available blocks
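
Taken together with the AbstractClient rename above, this removes the legacy list_models/list_embedding_models surface entirely. A minimal before/after sketch for callers, assuming an existing client instance and using the new names introduced in this diff:

    # before this commit (methods now removed)
    # llm_configs = client.list_models()
    # embedding_configs = client.list_embedding_models()

    # after this commit
    llm_configs = client.list_model_configs()
    embedding_configs = client.list_embedding_configs()

Note that the updated test further down calls client.list_llm_configs(), whose definition this diff does not show; it is presumably the REST client's concrete name for the same listing.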

View File

@@ -313,7 +313,6 @@ def create(
             stream_interface.stream_start()
         try:
             # groq uses the openai chat completions API, so this component should be reusable
-            assert model_settings.groq_api_key is not None, "Groq key is missing"
             response = openai_chat_completions_request(
                 url=llm_config.model_endpoint,
                 api_key=model_settings.groq_api_key,
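
The comment in this hunk is the key fact: Groq exposes an OpenAI-compatible chat completions endpoint, which is why openai_chat_completions_request can be pointed at it directly. A minimal standalone sketch of the same idea; the base URL and model id below are illustrative assumptions, not taken from this diff:

    import os

    import requests

    # Groq speaks the OpenAI chat completions wire format, so a plain
    # OpenAI-style POST works; only the base URL and API key differ.
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",  # assumed base URL
        headers={"Authorization": f"Bearer {os.environ['GROQ_API_KEY']}"},
        json={
            "model": "llama3-8b-8192",  # illustrative model id
            "messages": [{"role": "user", "content": "Say hello."}],
        },
        timeout=30,
    )
    response.raise_for_status()
    print(response.json()["choices"][0]["message"]["content"])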

View File

@@ -313,7 +313,7 @@ class GroqProvider(OpenAIProvider):
                 continue
             configs.append(
                 LLMConfig(
-                    model=model["id"], model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["context_window"]
+                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
                 )
             )
         return configs
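
The one-word change above matters because the rest of the stack dispatches on model_endpoint_type: configs tagged "openai" would be routed down the generic OpenAI path rather than the Groq-specific one. A schematic sketch of that kind of dispatch, with stand-in types and handlers; the real routing lives in letta's llm_api layer and differs in detail:

    from dataclasses import dataclass

    @dataclass
    class LLMConfigStub:  # stand-in for letta.schemas.llm_config.LLMConfig
        model: str
        model_endpoint_type: str

    def call_openai(config: LLMConfigStub) -> str:  # stand-in handler
        return f"openai path: {config.model}"

    def call_groq(config: LLMConfigStub) -> str:  # stand-in handler
        return f"groq path: {config.model}"

    HANDLERS = {"openai": call_openai, "groq": call_groq}

    def route_request(config: LLMConfigStub) -> str:
        # Dispatch on the endpoint type; with the old "openai" tag, Groq
        # models would have been sent into the OpenAI branch instead.
        try:
            return HANDLERS[config.model_endpoint_type](config)
        except KeyError:
            raise ValueError(f"unknown endpoint type: {config.model_endpoint_type}")

    print(route_request(LLMConfigStub("llama3-8b-8192", "groq")))  # groq path: ...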

View File

@@ -43,6 +43,7 @@ from letta.providers import (
     AnthropicProvider,
     AzureProvider,
     GoogleAIProvider,
+    GroqProvider,
     LettaProvider,
     OllamaProvider,
     OpenAIProvider,
@@ -298,6 +299,8 @@ class SyncServer(Server):
                     api_version=model_settings.azure_api_version,
                 )
             )
+        if model_settings.groq_api_key:
+            self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
         if model_settings.vllm_api_base:
             # vLLM exposes both a /chat/completions and a /completions endpoint
             self._enabled_providers.append(
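
With this wiring, Groq support is opt-in: exporting GROQ_API_KEY before starting the server is enough for model_settings (which reads its values from the environment) to pick it up and register the provider, and when the key is absent the provider is skipped entirely, so Groq models never appear in the config listings.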

View File

@@ -2,7 +2,7 @@ import os
 import threading
 import time
 import uuid
-from typing import Union
+from typing import List, Union

 import pytest
 from dotenv import load_dotenv
@@ -18,6 +18,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
+from letta.settings import model_settings
 from tests.helpers.client_helper import upload_file_using_client

 # from tests.utils import create_config
@@ -262,21 +263,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
     assert human.value == "Human text", "Creating human failed"

-def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
-    # _reset_config()
-
-    models_response = client.list_models()
-    print("MODELS", models_response)
-
-    embeddings_response = client.list_embedding_models()
-    print("EMBEDDINGS", embeddings_response)
-
-    # TODO: add back
-    # config_response = client.get_config()
-    # TODO: ensure config is the same as the one in the server
-    # print("CONFIG", config_response)
-
 def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
     tools = client.list_tools()
     visited_ids = {t.id: False for t in tools}
@@ -503,11 +489,20 @@ def test_organization(client: RESTClient):
         pytest.skip("Skipping test_organization because LocalClient does not support organizations")

-def test_model_configs(client: Union[LocalClient, RESTClient]):
-    # _reset_config()
+def test_list_llm_models(client: RESTClient):
+    """Test that if the user's env has the right api keys set, at least one model appears in the model list"""

-    model_configs = client.list_models()
-    print("MODEL CONFIGS", model_configs)
+    def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
+        return any(model.model_endpoint_type == target_type for model in models)

-    embedding_configs = client.list_embedding_models()
-    print("EMBEDDING CONFIGS", embedding_configs)
+    models = client.list_llm_configs()
+    if model_settings.groq_api_key:
+        assert has_model_endpoint_type(models, "groq")
+    if model_settings.azure_api_key:
+        assert has_model_endpoint_type(models, "azure")
+    if model_settings.openai_api_key:
+        assert has_model_endpoint_type(models, "openai")
+    if model_settings.gemini_api_key:
+        assert has_model_endpoint_type(models, "google_ai")
+    if model_settings.anthropic_api_key:
+        assert has_model_endpoint_type(models, "anthropic")
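
Because each assertion is gated on the corresponding key in model_settings, the rewritten test passes vacuously in an environment with no provider keys configured; the workflow change at the top of this commit sets GROQ_API_KEY in CI so the Groq branch actually runs there.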

View File

@@ -4,6 +4,7 @@ from letta.providers import (
     AnthropicProvider,
     AzureProvider,
     GoogleAIProvider,
+    GroqProvider,
     MistralProvider,
     OllamaProvider,
     OpenAIProvider,
@@ -27,12 +28,10 @@ def test_anthropic():
     print(models)

-# def test_groq():
-#     provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
-#     models = provider.list_llm_models()
-#     print(models)
-#
-#
+def test_groq():
+    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
+    models = provider.list_llm_models()
+    print(models)

 def test_azure():