# letta-server/tests/test_endpoints.py
import json
import os
import uuid
from letta import LocalClient, RESTClient, create_client
from letta.agent import Agent
from letta.config import LettaConfig
from letta.embeddings import embedding_model
from letta.llm_api.llm_api_tools import create
from letta.prompts import gpt_system
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
messages = [
    Message(role="system", text=gpt_system.get_system_text("memgpt_chat")),
    Message(role="user", text="How are you?"),
]

# defaults (letta hosted)
embedding_config_path = "configs/embedding_model_configs/letta-hosted.json"
llm_config_path = "configs/llm_model_configs/letta-hosted.json"

# directories
embedding_config_dir = "configs/embedding_model_configs"
llm_config_dir = "configs/llm_model_configs"
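
# Each config JSON simply holds keyword arguments for the matching schema
# (it is unpacked directly into LLMConfig(**...) / EmbeddingConfig(**...) below).
# An LLM config might look roughly like this (sketch; exact fields follow the
# LLMConfig schema and the files shipped under configs/):
# {
#   "model": "gpt-4",
#   "model_endpoint_type": "openai",
#   "model_endpoint": "https://api.openai.com/v1",
#   "context_window": 8192
# }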

# Deterministic uuid used as the agent name for these tests
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "letta-endpoint-tests"))


def clean_up_agent(client: LocalClient | RESTClient):
    """Delete any agent left over from a previous run of these tests."""
    # Clear all agents created under the test name
    for agent_state in client.list_agents():
        if agent_state.name == agent_uuid:
            client.delete_agent(agent_id=agent_state.id)
            print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")


def run_llm_endpoint(filename):
    """Load an LLM config, build an agent from it, and issue one completion request."""
    with open(filename, "r") as f:
        config_data = json.load(f)
    print(config_data)
    llm_config = LLMConfig(**config_data)
    with open(embedding_config_path, "r") as f:
        embedding_config = EmbeddingConfig(**json.load(f))

    # setup config so create_client() picks up these defaults
    config = LettaConfig()
    config.default_llm_config = llm_config
    config.default_embedding_config = embedding_config
    config.save()

    client = create_client()
    clean_up_agent(client)
    agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config)
    tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools]
    agent = Agent(
        interface=None,
        tools=tools,
        agent_state=agent_state,
        # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
        first_message_verify_mono=True,
    )

    response = create(
        llm_config=llm_config,
        user_id=uuid.UUID(int=1),  # dummy user_id
        # messages=agent_state.messages,
        messages=agent._messages,
        functions=agent.functions,
        functions_python=agent.functions_python,
    )
    client.delete_agent(agent_state.id)
    assert response is not None
    print(response)


def run_embedding_endpoint(filename):
    """Load an embedding config and embed a short query against the endpoint."""
    # load JSON file
    with open(filename, "r") as f:
        config_data = json.load(f)
    print(config_data)
    embedding_config = EmbeddingConfig(**config_data)
    model = embedding_model(embedding_config)
    query_text = "hello"
    query_vec = model.get_text_embedding(query_text)
    print("vector dim", len(query_vec))
    assert query_vec is not None
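
# Per-provider checks below. Each expects the matching config JSON to exist under
# configs/, and hosted providers additionally need credentials available
# (typically via environment variables such as OPENAI_API_KEY).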
def test_llm_endpoint_openai():
    filename = os.path.join(llm_config_dir, "gpt-4.json")
    run_llm_endpoint(filename)


def test_embedding_endpoint_openai():
    filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json")
    run_embedding_endpoint(filename)


def test_llm_endpoint_letta_hosted():
    filename = os.path.join(llm_config_dir, "letta-hosted.json")
    run_llm_endpoint(filename)


def test_embedding_endpoint_letta_hosted():
    filename = os.path.join(embedding_config_dir, "letta-hosted.json")
    run_embedding_endpoint(filename)


def test_embedding_endpoint_local():
    filename = os.path.join(embedding_config_dir, "local.json")
    run_embedding_endpoint(filename)


def test_llm_endpoint_ollama():
    filename = os.path.join(llm_config_dir, "ollama.json")
    run_llm_endpoint(filename)


def test_embedding_endpoint_ollama():
    filename = os.path.join(embedding_config_dir, "ollama.json")
    run_embedding_endpoint(filename)


def test_llm_endpoint_anthropic():
    filename = os.path.join(llm_config_dir, "anthropic.json")
    run_llm_endpoint(filename)


def test_llm_endpoint_groq():
    filename = os.path.join(llm_config_dir, "groq.json")
    run_llm_endpoint(filename)
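

# To exercise a single provider locally (assuming the matching config JSON and any
# required credentials are in place), run the file through pytest, e.g.:
#   pytest letta-server/tests/test_endpoints.py -k test_llm_endpoint_openai -s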