chore: migrate tests to new client (#1290)

This commit is contained in:
Sarah Wooders
2025-03-14 15:17:28 -07:00
committed by GitHub
parent 40c70b46cf
commit 56679d2cea
3 changed files with 395 additions and 293 deletions

View File

@@ -57,7 +57,8 @@ jobs:
run: |
pipx install poetry==1.8.2
poetry install -E dev -E postgres
poetry run pytest -s tests/test_client_legacy.py
poetry run pytest -s tests/test_client.py
# poetry run pytest -s tests/test_client_legacy.py
- name: Print docker logs if tests fail
if: failure()

View File

@@ -8,18 +8,11 @@ from typing import List, Union
import pytest
from dotenv import load_dotenv
from letta_client import AgentState, JobStatus, Letta, MessageCreate, MessageRole
from letta_client.core.api_error import ApiError
from sqlalchemy import delete
from letta import LocalClient, RESTClient, create_client
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.job import JobStatus
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
from letta.utils import create_random_username
# Constants
SERVER_PORT = 8283
@@ -42,63 +35,68 @@ def run_server():
@pytest.fixture(
params=[
{"server": False},
{"server": True},
], # whether to use REST API server
# params=[{"server": False}], # whether to use REST API server
scope="module",
)
def client(request):
if request.param["server"]:
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
client = create_client(base_url=server_url, token=None)
else:
client = create_client()
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
# create the Letta client
yield Letta(base_url=server_url, token=None)
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
def agent(client: Letta):
agent_state = client.agents.create(
name="test_client",
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
)
yield agent_state
# delete agent
client.delete_agent(agent_state.id)
client.agents.delete(agent_state.id)
# Fixture for test agent
@pytest.fixture
def search_agent_one(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name="Search Agent One")
def search_agent_one(client: Letta):
agent_state = client.agents.create(
name="Search Agent One",
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
)
yield agent_state
# delete agent
client.delete_agent(agent_state.id)
client.agents.delete(agent_state.id)
# Fixture for test agent
@pytest.fixture
def search_agent_two(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name="Search Agent Two")
def search_agent_two(client: Letta):
agent_state = client.agents.create(
name="Search Agent Two",
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
)
yield agent_state
# delete agent
client.delete_agent(agent_state.id)
client.agents.delete(agent_state.id)
@pytest.fixture(autouse=True)
@@ -112,55 +110,56 @@ def clear_tables():
session.commit()
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
"""
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
"""
# 1. Create a sandbox config
local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR)
sandbox_config = client.create_sandbox_config(config=local_config)
# Assert the created sandbox config
assert sandbox_config.id is not None
assert sandbox_config.type == SandboxType.LOCAL
# 2. Update the sandbox config
updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR)
sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config)
assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR
# 3. List all sandbox configs
sandbox_configs = client.list_sandbox_configs(limit=10)
assert isinstance(sandbox_configs, List)
assert len(sandbox_configs) == 1
assert sandbox_configs[0].id == sandbox_config.id
# 4. Create an environment variable
env_var = client.create_sandbox_env_var(
sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION
)
assert env_var.id is not None
assert env_var.key == ENV_VAR_KEY
assert env_var.value == ENV_VAR_VALUE
assert env_var.description == ENV_VAR_DESCRIPTION
# 5. Update the environment variable
updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE)
assert updated_env_var.key == UPDATED_ENV_VAR_KEY
assert updated_env_var.value == UPDATED_ENV_VAR_VALUE
# 6. List environment variables
env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id)
assert isinstance(env_vars, List)
assert len(env_vars) == 1
assert env_vars[0].key == UPDATED_ENV_VAR_KEY
# 7. Delete the environment variable
client.delete_sandbox_env_var(env_var_id=env_var.id)
# 8. Delete the sandbox config
client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
# TODO: add back
# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
# """
# Test sandbox config and environment variable functions for both LocalClient and RESTClient.
# """
#
# # 1. Create a sandbox config
# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR)
# sandbox_config = client.create_sandbox_config(config=local_config)
#
# # Assert the created sandbox config
# assert sandbox_config.id is not None
# assert sandbox_config.type == SandboxType.LOCAL
#
# # 2. Update the sandbox config
# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR)
# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config)
# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR
#
# # 3. List all sandbox configs
# sandbox_configs = client.list_sandbox_configs(limit=10)
# assert isinstance(sandbox_configs, List)
# assert len(sandbox_configs) == 1
# assert sandbox_configs[0].id == sandbox_config.id
#
# # 4. Create an environment variable
# env_var = client.create_sandbox_env_var(
# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION
# )
# assert env_var.id is not None
# assert env_var.key == ENV_VAR_KEY
# assert env_var.value == ENV_VAR_VALUE
# assert env_var.description == ENV_VAR_DESCRIPTION
#
# # 5. Update the environment variable
# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE)
# assert updated_env_var.key == UPDATED_ENV_VAR_KEY
# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE
#
# # 6. List environment variables
# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id)
# assert isinstance(env_vars, List)
# assert len(env_vars) == 1
# assert env_vars[0].key == UPDATED_ENV_VAR_KEY
#
# # 7. Delete the environment variable
# client.delete_sandbox_env_var(env_var_id=env_var.id)
#
# # 8. Delete the sandbox config
# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
# --------------------------------------------------------------------------------------------------------------------
@@ -168,197 +167,186 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]
# --------------------------------------------------------------------------------------------------------------------
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]):
def test_add_and_manage_tags_for_agent(client: Letta):
"""
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
"""
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
# Step 0: create an agent with no tags
agent = client.create_agent()
agent = client.agents.create(memory_blocks=[], model="letta/letta-free", embedding="letta/letta-free")
assert len(agent.tags) == 0
# Step 1: Add multiple tags to the agent
client.update_agent(agent_id=agent.id, tags=tags_to_add)
client.agents.modify(agent_id=agent.id, tags=tags_to_add)
# Step 2: Retrieve tags for the agent and verify they match the added tags
retrieved_tags = client.get_agent(agent_id=agent.id).tags
retrieved_tags = client.agents.retrieve(agent_id=agent.id).tags
assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}"
# Step 3: Retrieve agents by each tag to ensure the agent is associated correctly
for tag in tags_to_add:
agents_with_tag = client.list_agents(tags=[tag])
agents_with_tag = client.agents.list(tags=[tag])
assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'"
# Step 4: Delete a specific tag from the agent and verify its removal
tag_to_delete = tags_to_add.pop()
client.update_agent(agent_id=agent.id, tags=tags_to_add)
client.agents.modify(agent_id=agent.id, tags=tags_to_add)
# Verify the tag is removed from the agent's tags
remaining_tags = client.get_agent(agent_id=agent.id).tags
remaining_tags = client.agents.retrieve(agent_id=agent.id).tags
assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected"
assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}"
# Step 5: Delete all remaining tags from the agent
client.update_agent(agent_id=agent.id, tags=[])
client.agents.modify(agent_id=agent.id, tags=[])
# Verify all tags are removed
final_tags = client.get_agent(agent_id=agent.id).tags
final_tags = client.agents.retrieve(agent_id=agent.id).tags
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
# Remove agent
client.delete_agent(agent.id)
client.agents.delete(agent.id)
def test_agent_tags(client: Union[LocalClient, RESTClient]):
def test_agent_tags(client: Letta):
"""Test creating agents with tags and retrieving tags via the API."""
if not isinstance(client, RESTClient):
pytest.skip("This test only runs when the server is enabled")
# Create multiple agents with different tags
agent1 = client.create_agent(
agent1 = client.agents.create(
name=f"test_agent_{str(uuid.uuid4())}",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tags=["test", "agent1", "production"],
model="letta/letta-free",
embedding="letta/letta-free",
)
agent2 = client.create_agent(
agent2 = client.agents.create(
name=f"test_agent_{str(uuid.uuid4())}",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tags=["test", "agent2", "development"],
model="letta/letta-free",
embedding="letta/letta-free",
)
agent3 = client.create_agent(
agent3 = client.agents.create(
name=f"test_agent_{str(uuid.uuid4())}",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tags=["test", "agent3", "production"],
model="letta/letta-free",
embedding="letta/letta-free",
)
# Test getting all tags
all_tags = client.get_tags()
all_tags = client.tag.list_tags()
expected_tags = ["agent1", "agent2", "agent3", "development", "production", "test"]
assert sorted(all_tags) == expected_tags
# Test pagination
paginated_tags = client.get_tags(limit=2)
paginated_tags = client.tag.list_tags(limit=2)
assert len(paginated_tags) == 2
assert paginated_tags[0] == "agent1"
assert paginated_tags[1] == "agent2"
# Test pagination with cursor
next_page_tags = client.get_tags(after="agent2", limit=2)
next_page_tags = client.tag.list_tags(after="agent2", limit=2)
assert len(next_page_tags) == 2
assert next_page_tags[0] == "agent3"
assert next_page_tags[1] == "development"
# Test text search
prod_tags = client.get_tags(query_text="prod")
prod_tags = client.tag.list_tags(query_text="prod")
assert sorted(prod_tags) == ["production"]
dev_tags = client.get_tags(query_text="dev")
dev_tags = client.tag.list_tags(query_text="dev")
assert sorted(dev_tags) == ["development"]
agent_tags = client.get_tags(query_text="agent")
agent_tags = client.tag.list_tags(query_text="agent")
assert sorted(agent_tags) == ["agent1", "agent2", "agent3"]
# Remove agents
client.delete_agent(agent1.id)
client.delete_agent(agent2.id)
client.delete_agent(agent3.id)
client.agents.delete(agent1.id)
client.agents.delete(agent2.id)
client.agents.delete(agent3.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent memory blocks
# --------------------------------------------------------------------------------------------------------------------
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
# _reset_config()
def test_shared_blocks(mock_e2b_api_key_none, client: Letta):
# create a block
block = client.create_block(label="human", value="username: sarah")
block = client.blocks.create(label="human", value="username: sarah")
# create agents with shared block
from letta.schemas.block import Block
from letta.schemas.memory import BasicBlockMemory
# persona1_block = client.create_block(label="persona", value="you are agent 1")
# persona2_block = client.create_block(label="persona", value="you are agent 2")
# create agents
agent_state1 = client.create_agent(
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
agent_state1 = client.agents.create(
name="agent1",
memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
block_ids=[block.id],
model="letta/letta-free",
embedding="letta/letta-free",
)
agent_state2 = client.create_agent(
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
agent_state2 = client.agents.create(
name="agent2",
memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
block_ids=[block.id],
model="letta/letta-free",
embedding="letta/letta-free",
)
## attach shared block to both agents
# client.link_agent_memory_block(agent_state1.id, block.id)
# client.link_agent_memory_block(agent_state2.id, block.id)
# update memory
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
client.agents.messages.create(agent_id=agent_state1.id, messages=[{"role": "user", "content": "my name is actually charles"}])
# check agent 2 memory
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
client.user_message(agent_id=agent_state2.id, message="whats my name?")
assert (
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
assert "charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower()
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)
client.agents.delete(agent_state1.id)
client.agents.delete(agent_state2.id)
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_update_agent_memory_label(client: Letta):
"""Test that we can update the label of a block in an agent's memory"""
agent = client.create_agent(name=create_random_username())
agent = client.agents.create(model="letta/letta-free", embedding="letta/letta-free", memory_blocks=[{"label": "human", "value": ""}])
try:
current_labels = agent.memory.list_block_labels()
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
example_label = current_labels[0]
example_new_label = "example_new_label"
assert example_new_label not in current_labels
assert example_new_label not in [b.label for b in client.agents.blocks.list(agent_id=agent.id)]
client.update_agent_memory_block_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label)
client.agents.blocks.modify(agent_id=agent.id, block_label=example_label, label=example_new_label)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_label in updated_agent.memory.list_block_labels()
updated_blocks = client.agents.blocks.list(agent_id=agent.id)
assert example_new_label in [b.label for b in updated_blocks]
finally:
client.delete_agent(agent.id)
client.agents.delete(agent.id)
def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState):
"""Test that we can add and remove a block from an agent's memory"""
current_labels = agent.memory.list_block_labels()
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
example_new_label = current_labels[0] + "_v2"
example_new_value = "example value"
assert example_new_label not in current_labels
# Link a new memory block
block = client.create_block(
block = client.blocks.create(
label=example_new_label,
value=example_new_value,
limit=1000,
)
updated_agent = client.attach_block(
updated_agent = client.agents.blocks.attach(
agent_id=agent.id,
block_id=block.id,
)
assert example_new_label in updated_agent.memory.list_block_labels()
assert example_new_label in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
# Now unlink the block
updated_agent = client.detach_block(
updated_agent = client.agents.blocks.detach(
agent_id=agent.id,
block_id=block.id,
)
assert example_new_label not in updated_agent.memory.list_block_labels()
assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
@@ -385,39 +373,57 @@ def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient]
# client.delete_agent(new_agent.id)
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]):
def test_update_agent_memory_limit(client: Letta):
"""Test that we can update the limit of a block in an agent's memory"""
agent = client.create_agent()
agent = client.agents.create(
model="letta/letta-free",
embedding="letta/letta-free",
memory_blocks=[
{"label": "human", "value": "username: sarah", "limit": 1000},
{"label": "persona", "value": "you are sarah", "limit": 1000},
],
)
current_labels = agent.memory.list_block_labels()
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
example_label = current_labels[0]
example_new_limit = 1
current_block = agent.memory.get_block(label=example_label)
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
example_label = current_labels[0]
example_new_limit = 1
current_block = client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label)
current_block_length = len(current_block.value)
assert example_new_limit != agent.memory.get_block(label=example_label).limit
assert example_new_limit != current_block.limit
assert example_new_limit < current_block_length
# We expect this to throw a value error
with pytest.raises(ValueError):
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
with pytest.raises(ApiError):
client.agents.blocks.modify(
agent_id=agent.id,
block_label=example_label,
limit=example_new_limit,
)
# Now try the same thing with a higher limit
example_new_limit = current_block_length + 10000
assert example_new_limit > current_block_length
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
client.agents.blocks.modify(
agent_id=agent.id,
block_label=example_label,
limit=example_new_limit,
)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
assert example_new_limit == client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label).limit
client.delete_agent(agent.id)
client.agents.delete(agent.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent Tools
# --------------------------------------------------------------------------------------------------------------------
def test_function_return_limit(client: Union[LocalClient, RESTClient]):
def test_function_return_limit(client: Letta):
"""Test to see if the function return limit works"""
def big_return():
@@ -430,15 +436,21 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
return "x" * 100000
padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50
tool = client.create_or_update_tool(func=big_return, return_char_limit=1000)
agent = client.create_agent(tool_ids=[tool.id])
tool = client.tools.upsert_from_function(func=big_return, return_char_limit=1000)
agent = client.agents.create(
model="letta/letta-free",
embedding="letta/letta-free",
tool_ids=[tool.id],
)
# get function response
response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user")
response = client.agents.messages.create(
agent_id=agent.id, messages=[MessageCreate(role="user", content="call the big_return function")]
)
print(response.messages)
response_message = None
for message in response.messages:
if isinstance(message, ToolReturnMessage):
if message.message_type == "tool_return_message":
response_message = message
break
@@ -452,44 +464,58 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
# len(res_json["message"]) <= 1000 + padding
# ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
client.delete_agent(agent_id=agent.id)
client.agents.delete(agent_id=agent.id)
def test_function_always_error(client: Union[LocalClient, RESTClient]):
def test_function_always_error(client: Letta):
"""Test to see if function that errors works correctly"""
def testing_method():
"""
Always throw an error.
Call this tool when the user asks
"""
return 5 / 0
tool = client.create_or_update_tool(func=testing_method)
agent = client.create_agent(tool_ids=[tool.id])
tool = client.tools.upsert_from_function(func=testing_method)
agent = client.agents.create(
model="letta/letta-free",
embedding="letta/letta-free",
memory_blocks=[
{
"label": "human",
"value": "username: sarah",
},
{
"label": "persona",
"value": "you are sarah",
},
],
tool_ids=[tool.id],
)
print("AGENT TOOLS", [tool.name for tool in agent.tools])
# get function response
response = client.send_message(agent_id=agent.id, message="call the testing_method function and tell me the result", role="user")
response = client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="call the testing_method function and tell me the result")],
)
print(response.messages)
response_message = None
for message in response.messages:
if isinstance(message, ToolReturnMessage):
if message.message_type == "tool_return_message":
response_message = message
break
assert response_message, "ToolReturnMessage message not found in response"
assert response_message.status == "error"
assert (
response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero"
), response_message.tool_return
if isinstance(client, RESTClient):
assert response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero"
else:
response_json = json.loads(response_message.tool_return)
assert response_json["status"] == "Failed"
assert response_json["message"] == "Error executing function testing_method: ZeroDivisionError: division by zero"
client.delete_agent(agent_id=agent.id)
client.agents.delete(agent_id=agent.id)
def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_attach_detach_agent_tool(client: Letta, agent: AgentState):
"""Test that we can attach and detach a tool from an agent"""
try:
@@ -506,64 +532,64 @@ def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent:
"""
return x * 2
tool = client.create_or_update_tool(func=example_tool)
tool = client.tools.upsert_from_function(func=example_tool)
# Initially tool should not be attached
initial_tools = client.list_attached_tools(agent_id=agent.id)
initial_tools = client.agents.tools.list(agent_id=agent.id)
assert tool.id not in [t.id for t in initial_tools]
# Attach tool
new_agent_state = client.attach_tool(agent_id=agent.id, tool_id=tool.id)
new_agent_state = client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id)
assert tool.id in [t.id for t in new_agent_state.tools]
# Verify tool is attached
updated_tools = client.list_attached_tools(agent_id=agent.id)
updated_tools = client.agents.tools.list(agent_id=agent.id)
assert tool.id in [t.id for t in updated_tools]
# Detach tool
new_agent_state = client.detach_tool(agent_id=agent.id, tool_id=tool.id)
new_agent_state = client.agents.tools.detach(agent_id=agent.id, tool_id=tool.id)
assert tool.id not in [t.id for t in new_agent_state.tools]
# Verify tool is detached
final_tools = client.list_attached_tools(agent_id=agent.id)
final_tools = client.agents.tools.list(agent_id=agent.id)
assert tool.id not in [t.id for t in final_tools]
finally:
client.delete_tool(tool.id)
client.tools.delete(tool.id)
# --------------------------------------------------------------------------------------------------------------------
# AgentMessages
# --------------------------------------------------------------------------------------------------------------------
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_messages(client: Letta, agent: AgentState):
# _reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
send_message_response = client.agents.messages.create(agent_id=agent.id, messages=[MessageCreate(role="user", content="Test message")])
assert send_message_response, "Sending message failed"
messages_response = client.get_messages(agent_id=agent.id, limit=1)
messages_response = client.agents.messages.list(agent_id=agent.id, limit=1)
assert len(messages_response) > 0, "Retrieving messages failed"
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_send_system_message(client: Letta, agent: AgentState):
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
send_system_message_response = client.send_message(
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
send_system_message_response = client.agents.messages.create(
agent_id=agent.id, messages=[MessageCreate(role="system", content="Event occurred: The user just logged off.")]
)
assert send_system_message_response, "Sending message failed"
@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
async def test_send_message_parallel(client: Letta, agent: AgentState, request):
"""
Test that sending two messages in parallel does not error.
"""
if not isinstance(client, RESTClient):
pytest.skip("This test only runs when the server is enabled")
# Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
async def send_message_task(message: str):
response = await asyncio.to_thread(client.send_message, agent_id=agent.id, message=message, role="user")
response = await asyncio.to_thread(
client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)]
)
assert response, f"Sending message '{message}' failed"
return response
@@ -585,76 +611,31 @@ async def test_send_message_parallel(client: Union[LocalClient, RESTClient], age
assert len(responses) == len(messages), "Not all messages were processed"
def test_send_message_async(client: Union[LocalClient, RESTClient], agent: AgentState):
"""
Test that we can send a message asynchronously and retrieve the messages, along with usage statistics
"""
if not isinstance(client, RESTClient):
pytest.skip("send_message_async is only supported by the RESTClient")
print("Sending message asynchronously")
test_message = "This is a test message, respond to the user with a sentence."
run = client.send_message_async(agent_id=agent.id, role="user", message=test_message)
assert run.id is not None
assert run.status == JobStatus.created
print(f"Run created, run={run}, status={run.status}")
# Wait for the job to complete, cancel it if takes over 10 seconds
start_time = time.time()
while run.status == JobStatus.created:
time.sleep(1)
run = client.get_run(run_id=run.id)
print(f"Run status: {run.status}")
if time.time() - start_time > 10:
pytest.fail("Run took too long to complete")
print(f"Run completed in {time.time() - start_time} seconds, run={run}")
assert run.status == JobStatus.completed
# Get messages for the job
messages = client.get_run_messages(run_id=run.id)
assert len(messages) >= 2 # At least assistant response
# Check filters
assistant_messages = client.get_run_messages(run_id=run.id, role=MessageRole.assistant)
assert len(assistant_messages) > 0
tool_messages = client.get_run_messages(run_id=run.id, role=MessageRole.tool)
assert len(tool_messages) > 0
# Get and verify usage statistics
usage = client.get_run_usage(run_id=run.id)[0]
assert usage.completion_tokens >= 0
assert usage.prompt_tokens >= 0
assert usage.total_tokens >= 0
assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens
# ----------------------------------------------------------------------------------------------------
# Agent listing
# ----------------------------------------------------------------------------------------------------
def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two):
def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two):
"""Test listing agents with pagination and query text filtering."""
# Test query text filtering
search_results = client.list_agents(query_text="search agent")
search_results = client.agents.list(query_text="search agent")
assert len(search_results) == 2
search_agent_ids = {agent.id for agent in search_results}
assert search_agent_one.id in search_agent_ids
assert search_agent_two.id in search_agent_ids
assert agent.id not in search_agent_ids
different_results = client.list_agents(query_text="client")
different_results = client.agents.list(query_text="client")
assert len(different_results) == 1
assert different_results[0].id == agent.id
# Test pagination
first_page = client.list_agents(query_text="search agent", limit=1)
first_page = client.agents.list(query_text="search agent", limit=1)
assert len(first_page) == 1
first_agent = first_page[0]
second_page = client.list_agents(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor
second_page = client.agents.list(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor
assert len(second_page) == 1
assert second_page[0].id != first_agent.id
@@ -664,20 +645,16 @@ def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_age
assert all_ids == {search_agent_one.id, search_agent_two.id}
# Test listing without any filters
all_agents = client.list_agents()
all_agents = client.agents.list()
assert len(all_agents) == 3
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])
def test_agent_creation(client: Union[LocalClient, RESTClient]):
def test_agent_creation(client: Letta):
"""Test that block IDs are properly attached when creating an agent."""
if not isinstance(client, RESTClient):
pytest.skip("This test only runs when the server is enabled")
from letta import BasicBlockMemory
# Create a test block that will represent user preferences
user_preferences_block = client.create_block(label="user_preferences", value="", limit=10000)
user_preferences_block = client.blocks.create(label="user_preferences", value="", limit=10000)
# Create test tools
def test_tool():
@@ -688,73 +665,82 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]):
"""Another test tool."""
return "Hello from another test tool!"
tool1 = client.create_or_update_tool(func=test_tool, tags=["test"])
tool2 = client.create_or_update_tool(func=another_test_tool, tags=["test"])
# Create test blocks
offline_persona_block = client.create_block(label="persona", value="persona description", limit=5000)
mindy_block = client.create_block(label="mindy", value="Mindy is a helpful assistant", limit=5000)
memory_blocks = BasicBlockMemory(blocks=[offline_persona_block, mindy_block])
tool1 = client.tools.upsert_from_function(func=test_tool, tags=["test"])
tool2 = client.tools.upsert_from_function(func=another_test_tool, tags=["test"])
# Create agent with the blocks and tools
agent = client.create_agent(
name=f"test_agent_{str(uuid.uuid4())}",
memory=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
agent = client.agents.create(
memory_blocks=[
{
"label": "human",
"value": "you are a human",
},
{"label": "persona", "value": "you are an assistant"},
],
model="letta/letta-free",
embedding="letta/letta-free",
tool_ids=[tool1.id, tool2.id],
include_base_tools=False,
tags=["test"],
block_ids=[user_preferences_block.id],
)
memory_blocks = agent.memory.blocks
# Verify the agent was created successfully
assert agent is not None
assert agent.id is not None
# Verify the blocks are properly attached
agent_blocks = client.list_agent_memory_blocks(agent.id)
agent_blocks = client.agents.blocks.list(agent_id=agent.id)
agent_block_ids = {block.id for block in agent_blocks}
# Check that all memory blocks are present
memory_block_ids = {block.id for block in memory_blocks.blocks}
for block_id in memory_block_ids | {user_preferences_block.id}:
assert block_id in agent_block_ids
memory_block_ids = {block.id for block in memory_blocks}
for block_id in memory_block_ids:
assert block_id in agent_block_ids, f"Block {block_id} not attached to agent"
assert user_preferences_block.id in agent_block_ids, f"User preferences block {user_preferences_block.id} not attached to agent"
# Verify the tools are properly attached
agent_tools = client.get_tools_from_agent(agent.id)
agent_tools = client.agents.tools.list(agent_id=agent.id)
assert len(agent_tools) == 2
tool_ids = {tool1.id, tool2.id}
assert all(tool.id in tool_ids for tool in agent_tools)
client.delete_agent(agent_id=agent.id)
client.agents.delete(agent_id=agent.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent sources
# --------------------------------------------------------------------------------------------------------------------
def test_attach_detach_agent_source(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_attach_detach_agent_source(client: Letta, agent: AgentState):
"""Test that we can attach and detach a source from an agent"""
# Create a source
source = client.create_source(
source = client.sources.create(
name="test_source",
embedding_config={ # TODO: change this
"embedding_endpoint": "https://embeddings.memgpt.ai",
"embedding_model": "BAAI/bge-large-en-v1.5",
"embedding_dim": 1024,
"embedding_chunk_size": 300,
"embedding_endpoint_type": "hugging-face",
},
)
initial_sources = client.list_attached_sources(agent_id=agent.id)
initial_sources = client.agents.sources.list(agent_id=agent.id)
assert source.id not in [s.id for s in initial_sources]
# Attach source
client.attach_source(agent_id=agent.id, source_id=source.id)
client.agents.sources.attach(agent_id=agent.id, source_id=source.id)
# Verify source is attached
final_sources = client.list_attached_sources(agent_id=agent.id)
final_sources = client.agents.sources.list(agent_id=agent.id)
assert source.id in [s.id for s in final_sources]
# Detach source
client.detach_source(agent_id=agent.id, source_id=source.id)
client.agents.sources.detach(agent_id=agent.id, source_id=source.id)
# Verify source is detached
final_sources = client.list_attached_sources(agent_id=agent.id)
final_sources = client.agents.sources.list(agent_id=agent.id)
assert source.id not in [s.id for s in final_sources]
client.delete_source(source.id)
client.sources.delete(source.id)

115
tests/test_streaming.py Normal file
View File

@@ -0,0 +1,115 @@
import os
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import AgentState, Letta, LlmConfig, MessageCreate
from letta_client.core.api_error import ApiError
from pytest import fixture
def run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(
scope="module",
)
def client(request):
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create the Letta client
yield Letta(base_url=server_url, token=None)
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client: Letta):
agent_state = client.agents.create(
name="test_client",
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
)
yield agent_state
# delete agent
client.agents.delete(agent_state.id)
@pytest.mark.parametrize(
"stream_tokens,model",
[
(True, "openai/gpt-4o-mini"),
(True, "anthropic/claude-3-sonnet-20240229"),
(False, "openai/gpt-4o-mini"),
(False, "anthropic/claude-3-sonnet-20240229"),
],
)
def test_streaming_send_message(
mock_e2b_api_key_none,
client: Letta,
agent: AgentState,
stream_tokens: bool,
model: str,
):
# Update agent's model
config = client.agents.retrieve(agent_id=agent.id).llm_config
config_dump = config.model_dump()
config_dump["model"] = model
config = LlmConfig(**config_dump)
client.agents.modify(agent_id=agent.id, llm_config=config)
# Send streaming message
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")],
stream_tokens=stream_tokens,
)
# Tracking variables for test validation
inner_thoughts_exist = False
inner_thoughts_count = 0
send_message_ran = False
done = False
assert response, "Sending message failed"
for chunk in response:
# Check chunk type and content based on the current client API
if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message":
inner_thoughts_exist = True
inner_thoughts_count += 1
if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message":
send_message_ran = True
if chunk.message_type == "assistant_message":
send_message_ran = True
if chunk.message_type == "usage_statistics":
# Validate usage statistics
assert chunk.step_count == 1
assert chunk.completion_tokens > 10
assert chunk.prompt_tokens > 1000
assert chunk.total_tokens > 1000
done = True
print(chunk)
# If stream tokens, we expect at least one inner thought
assert inner_thoughts_count >= 1, "Expected more than one inner thought"
assert inner_thoughts_exist, "No inner thoughts found"
assert send_message_ran, "send_message function call not found"
assert done, "Message stream not done"