fix: Clean up some legacy code and fix Groq provider (#1950)

This commit is contained in:
Matthew Zhou
2024-10-28 14:13:11 -07:00
committed by GitHub
parent 969300fb56
commit 9439e5e255
7 changed files with 29 additions and 76 deletions

View File

@@ -2,7 +2,7 @@ import os
import threading
import time
import uuid
from typing import Union
from typing import List, Union
import pytest
from dotenv import load_dotenv
@@ -18,6 +18,7 @@ from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -262,21 +263,6 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
assert human.value == "Human text", "Creating human failed"
def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
models_response = client.list_models()
print("MODELS", models_response)
embeddings_response = client.list_embedding_models()
print("EMBEDDINGS", embeddings_response)
# TODO: add back
# config_response = client.get_config()
# TODO: ensure config is the same as the one in the server
# print("CONFIG", config_response)
def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
tools = client.list_tools()
visited_ids = {t.id: False for t in tools}
@@ -503,11 +489,20 @@ def test_organization(client: RESTClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
def test_model_configs(client: Union[LocalClient, RESTClient]):
# _reset_config()
def test_list_llm_models(client: RESTClient):
"""Test that if the user's env has the right api keys set, at least one model appears in the model list"""
model_configs = client.list_models()
print("MODEL CONFIGS", model_configs)
def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
return any(model.model_endpoint_type == target_type for model in models)
embedding_configs = client.list_embedding_models()
print("EMBEDDING CONFIGS", embedding_configs)
models = client.list_llm_configs()
if model_settings.groq_api_key:
assert has_model_endpoint_type(models, "groq")
if model_settings.azure_api_key:
assert has_model_endpoint_type(models, "azure")
if model_settings.openai_api_key:
assert has_model_endpoint_type(models, "openai")
if model_settings.gemini_api_key:
assert has_model_endpoint_type(models, "google_ai")
if model_settings.anthropic_api_key:
assert has_model_endpoint_type(models, "anthropic")