fix: Clean up some legacy code and fix Groq provider (#1950)
This commit is contained in:
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -5,6 +5,7 @@ env:
|
||||
COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
|
||||
on:
|
||||
push:
|
||||
|
||||
@@ -276,10 +276,10 @@ class AbstractClient(object):
|
||||
) -> List[Message]:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_models(self) -> List[LLMConfig]:
|
||||
def list_model_configs(self) -> List[LLMConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
def list_embedding_configs(self) -> List[EmbeddingConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1234,32 +1234,6 @@ class RESTClient(AbstractClient):
|
||||
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
||||
return Source(**response.json())
|
||||
|
||||
# server configuration commands
|
||||
|
||||
def list_models(self):
|
||||
"""
|
||||
List available LLM models
|
||||
|
||||
Returns:
|
||||
models (List[LLMConfig]): List of LLM models
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/models", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list models: {response.text}")
|
||||
return [LLMConfig(**model) for model in response.json()]
|
||||
|
||||
def list_embedding_models(self):
|
||||
"""
|
||||
List available embedding models
|
||||
|
||||
Returns:
|
||||
models (List[EmbeddingConfig]): List of embedding models
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/models/embedding", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list embedding models: {response.text}")
|
||||
return [EmbeddingConfig(**model) for model in response.json()]
|
||||
|
||||
# tools
|
||||
|
||||
def get_tool_id(self, tool_name: str):
|
||||
@@ -2572,24 +2546,6 @@ class LocalClient(AbstractClient):
|
||||
return_message_object=True,
|
||||
)
|
||||
|
||||
def list_models(self) -> List[LLMConfig]:
|
||||
"""
|
||||
List available LLM models
|
||||
|
||||
Returns:
|
||||
models (List[LLMConfig]): List of LLM models
|
||||
"""
|
||||
return self.server.list_models()
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
"""
|
||||
List available embedding models
|
||||
|
||||
Returns:
|
||||
models (List[EmbeddingConfig]): List of embedding models
|
||||
"""
|
||||
return [self.server.server_embedding_config]
|
||||
|
||||
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
|
||||
"""
|
||||
List available blocks
|
||||
|
||||
@@ -313,7 +313,6 @@ def create(
|
||||
stream_interface.stream_start()
|
||||
try:
|
||||
# groq uses the openai chat completions API, so this component should be reusable
|
||||
assert model_settings.groq_api_key is not None, "Groq key is missing"
|
||||
response = openai_chat_completions_request(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=model_settings.groq_api_key,
|
||||
|
||||
@@ -313,7 +313,7 @@ class GroqProvider(OpenAIProvider):
|
||||
continue
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model["id"], model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["context_window"]
|
||||
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
@@ -43,6 +43,7 @@ from letta.providers import (
|
||||
AnthropicProvider,
|
||||
AzureProvider,
|
||||
GoogleAIProvider,
|
||||
GroqProvider,
|
||||
LettaProvider,
|
||||
OllamaProvider,
|
||||
OpenAIProvider,
|
||||
@@ -298,6 +299,8 @@ class SyncServer(Server):
|
||||
api_version=model_settings.azure_api_version,
|
||||
)
|
||||
)
|
||||
if model_settings.groq_api_key:
|
||||
self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
|
||||
if model_settings.vllm_api_base:
|
||||
# vLLM exposes both a /chat/completions and a /completions endpoint
|
||||
self._enabled_providers.append(
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Union
|
||||
from typing import List, Union
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@@ -18,6 +18,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.settings import model_settings
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
@@ -262,21 +263,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
|
||||
assert human.value == "Human text", "Creating human failed"
|
||||
|
||||
|
||||
def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
models_response = client.list_models()
|
||||
print("MODELS", models_response)
|
||||
|
||||
embeddings_response = client.list_embedding_models()
|
||||
print("EMBEDDINGS", embeddings_response)
|
||||
|
||||
# TODO: add back
|
||||
# config_response = client.get_config()
|
||||
# TODO: ensure config is the same as the one in the server
|
||||
# print("CONFIG", config_response)
|
||||
|
||||
|
||||
def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
tools = client.list_tools()
|
||||
visited_ids = {t.id: False for t in tools}
|
||||
@@ -503,11 +489,20 @@ def test_organization(client: RESTClient):
|
||||
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||
|
||||
|
||||
def test_model_configs(client: Union[LocalClient, RESTClient]):
|
||||
# _reset_config()
|
||||
def test_list_llm_models(client: RESTClient):
|
||||
"""Test that if the user's env has the right api keys set, at least one model appears in the model list"""
|
||||
|
||||
model_configs = client.list_models()
|
||||
print("MODEL CONFIGS", model_configs)
|
||||
def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
|
||||
return any(model.model_endpoint_type == target_type for model in models)
|
||||
|
||||
embedding_configs = client.list_embedding_models()
|
||||
print("EMBEDDING CONFIGS", embedding_configs)
|
||||
models = client.list_llm_configs()
|
||||
if model_settings.groq_api_key:
|
||||
assert has_model_endpoint_type(models, "groq")
|
||||
if model_settings.azure_api_key:
|
||||
assert has_model_endpoint_type(models, "azure")
|
||||
if model_settings.openai_api_key:
|
||||
assert has_model_endpoint_type(models, "openai")
|
||||
if model_settings.gemini_api_key:
|
||||
assert has_model_endpoint_type(models, "google_ai")
|
||||
if model_settings.anthropic_api_key:
|
||||
assert has_model_endpoint_type(models, "anthropic")
|
||||
|
||||
@@ -4,6 +4,7 @@ from letta.providers import (
|
||||
AnthropicProvider,
|
||||
AzureProvider,
|
||||
GoogleAIProvider,
|
||||
GroqProvider,
|
||||
MistralProvider,
|
||||
OllamaProvider,
|
||||
OpenAIProvider,
|
||||
@@ -27,12 +28,10 @@ def test_anthropic():
|
||||
print(models)
|
||||
|
||||
|
||||
# def test_groq():
|
||||
# provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
|
||||
# models = provider.list_llm_models()
|
||||
# print(models)
|
||||
#
|
||||
#
|
||||
def test_groq():
|
||||
provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
|
||||
|
||||
def test_azure():
|
||||
|
||||
Reference in New Issue
Block a user