diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index 77ddb3a0..63886ffe 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -57,7 +57,8 @@ jobs: run: | pipx install poetry==1.8.2 poetry install -E dev -E postgres - poetry run pytest -s tests/test_client_legacy.py + poetry run pytest -s tests/test_client.py +# poetry run pytest -s tests/test_client_legacy.py - name: Print docker logs if tests fail if: failure() diff --git a/tests/test_client.py b/tests/test_client.py index c53ac781..856c4227 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,18 +8,11 @@ from typing import List, Union import pytest from dotenv import load_dotenv +from letta_client import AgentState, JobStatus, Letta, MessageCreate, MessageRole +from letta_client.core.api_error import ApiError from sqlalchemy import delete -from letta import LocalClient, RESTClient, create_client from letta.orm import SandboxConfig, SandboxEnvironmentVariable -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole -from letta.schemas.job import JobStatus -from letta.schemas.letta_message import ToolReturnMessage -from letta.schemas.llm_config import LLMConfig -from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType -from letta.utils import create_random_username # Constants SERVER_PORT = 8283 @@ -42,63 +35,68 @@ def run_server(): @pytest.fixture( - params=[ - {"server": False}, - {"server": True}, - ], # whether to use REST API server - # params=[{"server": False}], # whether to use REST API server scope="module", ) def client(request): - if request.param["server"]: - # Get URL from environment or start server - server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}") - if not os.getenv("LETTA_SERVER_URL"): - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - client = create_client(base_url=server_url, token=None) - else: - client = create_client() + # Get URL from environment or start server + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}") + if not os.getenv("LETTA_SERVER_URL"): + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) - client.set_default_llm_config(LLMConfig.default_config("gpt-4")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - yield client + # create the Letta client + yield Letta(base_url=server_url, token=None) # Fixture for test agent @pytest.fixture(scope="module") -def agent(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}") +def agent(client: Letta): + agent_state = client.agents.create( + name="test_client", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) yield agent_state # delete agent - client.delete_agent(agent_state.id) + client.agents.delete(agent_state.id) # Fixture for test agent @pytest.fixture -def search_agent_one(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name="Search Agent One") +def search_agent_one(client: Letta): + agent_state = client.agents.create( + name="Search Agent One", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) yield agent_state # delete agent - client.delete_agent(agent_state.id) + client.agents.delete(agent_state.id) # Fixture for test agent @pytest.fixture -def search_agent_two(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name="Search Agent Two") +def search_agent_two(client: Letta): + agent_state = client.agents.create( + name="Search Agent Two", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) yield agent_state # delete agent - client.delete_agent(agent_state.id) + client.agents.delete(agent_state.id) @pytest.fixture(autouse=True) @@ -112,55 +110,56 @@ def clear_tables(): session.commit() -def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): - """ - Test sandbox config and environment variable functions for both LocalClient and RESTClient. - """ - - # 1. Create a sandbox config - local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) - sandbox_config = client.create_sandbox_config(config=local_config) - - # Assert the created sandbox config - assert sandbox_config.id is not None - assert sandbox_config.type == SandboxType.LOCAL - - # 2. Update the sandbox config - updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) - sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) - assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR - - # 3. List all sandbox configs - sandbox_configs = client.list_sandbox_configs(limit=10) - assert isinstance(sandbox_configs, List) - assert len(sandbox_configs) == 1 - assert sandbox_configs[0].id == sandbox_config.id - - # 4. Create an environment variable - env_var = client.create_sandbox_env_var( - sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION - ) - assert env_var.id is not None - assert env_var.key == ENV_VAR_KEY - assert env_var.value == ENV_VAR_VALUE - assert env_var.description == ENV_VAR_DESCRIPTION - - # 5. Update the environment variable - updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) - assert updated_env_var.key == UPDATED_ENV_VAR_KEY - assert updated_env_var.value == UPDATED_ENV_VAR_VALUE - - # 6. List environment variables - env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) - assert isinstance(env_vars, List) - assert len(env_vars) == 1 - assert env_vars[0].key == UPDATED_ENV_VAR_KEY - - # 7. Delete the environment variable - client.delete_sandbox_env_var(env_var_id=env_var.id) - - # 8. Delete the sandbox config - client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) +# TODO: add back +# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): +# """ +# Test sandbox config and environment variable functions for both LocalClient and RESTClient. +# """ +# +# # 1. Create a sandbox config +# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) +# sandbox_config = client.create_sandbox_config(config=local_config) +# +# # Assert the created sandbox config +# assert sandbox_config.id is not None +# assert sandbox_config.type == SandboxType.LOCAL +# +# # 2. Update the sandbox config +# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) +# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) +# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR +# +# # 3. List all sandbox configs +# sandbox_configs = client.list_sandbox_configs(limit=10) +# assert isinstance(sandbox_configs, List) +# assert len(sandbox_configs) == 1 +# assert sandbox_configs[0].id == sandbox_config.id +# +# # 4. Create an environment variable +# env_var = client.create_sandbox_env_var( +# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION +# ) +# assert env_var.id is not None +# assert env_var.key == ENV_VAR_KEY +# assert env_var.value == ENV_VAR_VALUE +# assert env_var.description == ENV_VAR_DESCRIPTION +# +# # 5. Update the environment variable +# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) +# assert updated_env_var.key == UPDATED_ENV_VAR_KEY +# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE +# +# # 6. List environment variables +# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) +# assert isinstance(env_vars, List) +# assert len(env_vars) == 1 +# assert env_vars[0].key == UPDATED_ENV_VAR_KEY +# +# # 7. Delete the environment variable +# client.delete_sandbox_env_var(env_var_id=env_var.id) +# +# # 8. Delete the sandbox config +# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) # -------------------------------------------------------------------------------------------------------------------- @@ -168,197 +167,186 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient] # -------------------------------------------------------------------------------------------------------------------- -def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]): +def test_add_and_manage_tags_for_agent(client: Letta): """ Comprehensive happy path test for adding, retrieving, and managing tags on an agent. """ tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"] # Step 0: create an agent with no tags - agent = client.create_agent() + agent = client.agents.create(memory_blocks=[], model="letta/letta-free", embedding="letta/letta-free") assert len(agent.tags) == 0 # Step 1: Add multiple tags to the agent - client.update_agent(agent_id=agent.id, tags=tags_to_add) + client.agents.modify(agent_id=agent.id, tags=tags_to_add) # Step 2: Retrieve tags for the agent and verify they match the added tags - retrieved_tags = client.get_agent(agent_id=agent.id).tags + retrieved_tags = client.agents.retrieve(agent_id=agent.id).tags assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}" # Step 3: Retrieve agents by each tag to ensure the agent is associated correctly for tag in tags_to_add: - agents_with_tag = client.list_agents(tags=[tag]) + agents_with_tag = client.agents.list(tags=[tag]) assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'" # Step 4: Delete a specific tag from the agent and verify its removal tag_to_delete = tags_to_add.pop() - client.update_agent(agent_id=agent.id, tags=tags_to_add) + client.agents.modify(agent_id=agent.id, tags=tags_to_add) # Verify the tag is removed from the agent's tags - remaining_tags = client.get_agent(agent_id=agent.id).tags + remaining_tags = client.agents.retrieve(agent_id=agent.id).tags assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected" assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}" # Step 5: Delete all remaining tags from the agent - client.update_agent(agent_id=agent.id, tags=[]) + client.agents.modify(agent_id=agent.id, tags=[]) # Verify all tags are removed - final_tags = client.get_agent(agent_id=agent.id).tags + final_tags = client.agents.retrieve(agent_id=agent.id).tags assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" # Remove agent - client.delete_agent(agent.id) + client.agents.delete(agent.id) -def test_agent_tags(client: Union[LocalClient, RESTClient]): +def test_agent_tags(client: Letta): """Test creating agents with tags and retrieving tags via the API.""" - if not isinstance(client, RESTClient): - pytest.skip("This test only runs when the server is enabled") # Create multiple agents with different tags - agent1 = client.create_agent( + agent1 = client.agents.create( name=f"test_agent_{str(uuid.uuid4())}", - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=["test", "agent1", "production"], + model="letta/letta-free", + embedding="letta/letta-free", ) - agent2 = client.create_agent( + agent2 = client.agents.create( name=f"test_agent_{str(uuid.uuid4())}", - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=["test", "agent2", "development"], + model="letta/letta-free", + embedding="letta/letta-free", ) - agent3 = client.create_agent( + agent3 = client.agents.create( name=f"test_agent_{str(uuid.uuid4())}", - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=["test", "agent3", "production"], + model="letta/letta-free", + embedding="letta/letta-free", ) # Test getting all tags - all_tags = client.get_tags() + all_tags = client.tag.list_tags() expected_tags = ["agent1", "agent2", "agent3", "development", "production", "test"] assert sorted(all_tags) == expected_tags # Test pagination - paginated_tags = client.get_tags(limit=2) + paginated_tags = client.tag.list_tags(limit=2) assert len(paginated_tags) == 2 assert paginated_tags[0] == "agent1" assert paginated_tags[1] == "agent2" # Test pagination with cursor - next_page_tags = client.get_tags(after="agent2", limit=2) + next_page_tags = client.tag.list_tags(after="agent2", limit=2) assert len(next_page_tags) == 2 assert next_page_tags[0] == "agent3" assert next_page_tags[1] == "development" # Test text search - prod_tags = client.get_tags(query_text="prod") + prod_tags = client.tag.list_tags(query_text="prod") assert sorted(prod_tags) == ["production"] - dev_tags = client.get_tags(query_text="dev") + dev_tags = client.tag.list_tags(query_text="dev") assert sorted(dev_tags) == ["development"] - agent_tags = client.get_tags(query_text="agent") + agent_tags = client.tag.list_tags(query_text="agent") assert sorted(agent_tags) == ["agent1", "agent2", "agent3"] # Remove agents - client.delete_agent(agent1.id) - client.delete_agent(agent2.id) - client.delete_agent(agent3.id) + client.agents.delete(agent1.id) + client.agents.delete(agent2.id) + client.agents.delete(agent3.id) # -------------------------------------------------------------------------------------------------------------------- # Agent memory blocks # -------------------------------------------------------------------------------------------------------------------- -def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]): - # _reset_config() - +def test_shared_blocks(mock_e2b_api_key_none, client: Letta): # create a block - block = client.create_block(label="human", value="username: sarah") + block = client.blocks.create(label="human", value="username: sarah") # create agents with shared block - from letta.schemas.block import Block - from letta.schemas.memory import BasicBlockMemory - - # persona1_block = client.create_block(label="persona", value="you are agent 1") - # persona2_block = client.create_block(label="persona", value="you are agent 2") - # create agents - agent_state1 = client.create_agent( - name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id] + agent_state1 = client.agents.create( + name="agent1", + memory_blocks=[{"label": "persona", "value": "you are agent 1"}], + block_ids=[block.id], + model="letta/letta-free", + embedding="letta/letta-free", ) - agent_state2 = client.create_agent( - name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id] + agent_state2 = client.agents.create( + name="agent2", + memory_blocks=[{"label": "persona", "value": "you are agent 2"}], + block_ids=[block.id], + model="letta/letta-free", + embedding="letta/letta-free", ) - ## attach shared block to both agents - # client.link_agent_memory_block(agent_state1.id, block.id) - # client.link_agent_memory_block(agent_state2.id, block.id) - # update memory - client.user_message(agent_id=agent_state1.id, message="my name is actually charles") + client.agents.messages.create(agent_id=agent_state1.id, messages=[{"role": "user", "content": "my name is actually charles"}]) # check agent 2 memory - assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}" - - client.user_message(agent_id=agent_state2.id, message="whats my name?") - assert ( - "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() - ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" + assert "charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower() # cleanup - client.delete_agent(agent_state1.id) - client.delete_agent(agent_state2.id) + client.agents.delete(agent_state1.id) + client.agents.delete(agent_state2.id) -def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_update_agent_memory_label(client: Letta): """Test that we can update the label of a block in an agent's memory""" - agent = client.create_agent(name=create_random_username()) + agent = client.agents.create(model="letta/letta-free", embedding="letta/letta-free", memory_blocks=[{"label": "human", "value": ""}]) try: - current_labels = agent.memory.list_block_labels() + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] example_label = current_labels[0] example_new_label = "example_new_label" - assert example_new_label not in current_labels + assert example_new_label not in [b.label for b in client.agents.blocks.list(agent_id=agent.id)] - client.update_agent_memory_block_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label) + client.agents.blocks.modify(agent_id=agent.id, block_label=example_label, label=example_new_label) - updated_agent = client.get_agent(agent_id=agent.id) - assert example_new_label in updated_agent.memory.list_block_labels() + updated_blocks = client.agents.blocks.list(agent_id=agent.id) + assert example_new_label in [b.label for b in updated_blocks] finally: - client.delete_agent(agent.id) + client.agents.delete(agent.id) -def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState): """Test that we can add and remove a block from an agent's memory""" - current_labels = agent.memory.list_block_labels() + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] example_new_label = current_labels[0] + "_v2" example_new_value = "example value" assert example_new_label not in current_labels # Link a new memory block - block = client.create_block( + block = client.blocks.create( label=example_new_label, value=example_new_value, limit=1000, ) - updated_agent = client.attach_block( + updated_agent = client.agents.blocks.attach( agent_id=agent.id, block_id=block.id, ) - assert example_new_label in updated_agent.memory.list_block_labels() + assert example_new_label in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)] # Now unlink the block - updated_agent = client.detach_block( + updated_agent = client.agents.blocks.detach( agent_id=agent.id, block_id=block.id, ) - assert example_new_label not in updated_agent.memory.list_block_labels() + assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)] # def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState): @@ -385,39 +373,57 @@ def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient] # client.delete_agent(new_agent.id) -def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]): +def test_update_agent_memory_limit(client: Letta): """Test that we can update the limit of a block in an agent's memory""" - agent = client.create_agent() + agent = client.agents.create( + model="letta/letta-free", + embedding="letta/letta-free", + memory_blocks=[ + {"label": "human", "value": "username: sarah", "limit": 1000}, + {"label": "persona", "value": "you are sarah", "limit": 1000}, + ], + ) - current_labels = agent.memory.list_block_labels() + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] example_label = current_labels[0] example_new_limit = 1 - current_block = agent.memory.get_block(label=example_label) + + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] + example_label = current_labels[0] + example_new_limit = 1 + current_block = client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label) current_block_length = len(current_block.value) - assert example_new_limit != agent.memory.get_block(label=example_label).limit + assert example_new_limit != current_block.limit assert example_new_limit < current_block_length # We expect this to throw a value error - with pytest.raises(ValueError): - client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) + with pytest.raises(ApiError): + client.agents.blocks.modify( + agent_id=agent.id, + block_label=example_label, + limit=example_new_limit, + ) # Now try the same thing with a higher limit example_new_limit = current_block_length + 10000 assert example_new_limit > current_block_length - client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) + client.agents.blocks.modify( + agent_id=agent.id, + block_label=example_label, + limit=example_new_limit, + ) - updated_agent = client.get_agent(agent_id=agent.id) - assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit + assert example_new_limit == client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label).limit - client.delete_agent(agent.id) + client.agents.delete(agent.id) # -------------------------------------------------------------------------------------------------------------------- # Agent Tools # -------------------------------------------------------------------------------------------------------------------- -def test_function_return_limit(client: Union[LocalClient, RESTClient]): +def test_function_return_limit(client: Letta): """Test to see if the function return limit works""" def big_return(): @@ -430,15 +436,21 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): return "x" * 100000 padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50 - tool = client.create_or_update_tool(func=big_return, return_char_limit=1000) - agent = client.create_agent(tool_ids=[tool.id]) + tool = client.tools.upsert_from_function(func=big_return, return_char_limit=1000) + agent = client.agents.create( + model="letta/letta-free", + embedding="letta/letta-free", + tool_ids=[tool.id], + ) # get function response - response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user") + response = client.agents.messages.create( + agent_id=agent.id, messages=[MessageCreate(role="user", content="call the big_return function")] + ) print(response.messages) response_message = None for message in response.messages: - if isinstance(message, ToolReturnMessage): + if message.message_type == "tool_return_message": response_message = message break @@ -452,44 +464,58 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): # len(res_json["message"]) <= 1000 + padding # ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}" - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) -def test_function_always_error(client: Union[LocalClient, RESTClient]): +def test_function_always_error(client: Letta): """Test to see if function that errors works correctly""" def testing_method(): """ - Always throw an error. + Call this tool when the user asks """ return 5 / 0 - tool = client.create_or_update_tool(func=testing_method) - agent = client.create_agent(tool_ids=[tool.id]) + tool = client.tools.upsert_from_function(func=testing_method) + agent = client.agents.create( + model="letta/letta-free", + embedding="letta/letta-free", + memory_blocks=[ + { + "label": "human", + "value": "username: sarah", + }, + { + "label": "persona", + "value": "you are sarah", + }, + ], + tool_ids=[tool.id], + ) + print("AGENT TOOLS", [tool.name for tool in agent.tools]) # get function response - response = client.send_message(agent_id=agent.id, message="call the testing_method function and tell me the result", role="user") + response = client.agents.messages.create( + agent_id=agent.id, + messages=[MessageCreate(role="user", content="call the testing_method function and tell me the result")], + ) print(response.messages) response_message = None for message in response.messages: - if isinstance(message, ToolReturnMessage): + if message.message_type == "tool_return_message": response_message = message break assert response_message, "ToolReturnMessage message not found in response" assert response_message.status == "error" + assert ( + response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" + ), response_message.tool_return - if isinstance(client, RESTClient): - assert response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" - else: - response_json = json.loads(response_message.tool_return) - assert response_json["status"] == "Failed" - assert response_json["message"] == "Error executing function testing_method: ZeroDivisionError: division by zero" - - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) -def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_attach_detach_agent_tool(client: Letta, agent: AgentState): """Test that we can attach and detach a tool from an agent""" try: @@ -506,64 +532,64 @@ def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: """ return x * 2 - tool = client.create_or_update_tool(func=example_tool) + tool = client.tools.upsert_from_function(func=example_tool) # Initially tool should not be attached - initial_tools = client.list_attached_tools(agent_id=agent.id) + initial_tools = client.agents.tools.list(agent_id=agent.id) assert tool.id not in [t.id for t in initial_tools] # Attach tool - new_agent_state = client.attach_tool(agent_id=agent.id, tool_id=tool.id) + new_agent_state = client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id) assert tool.id in [t.id for t in new_agent_state.tools] # Verify tool is attached - updated_tools = client.list_attached_tools(agent_id=agent.id) + updated_tools = client.agents.tools.list(agent_id=agent.id) assert tool.id in [t.id for t in updated_tools] # Detach tool - new_agent_state = client.detach_tool(agent_id=agent.id, tool_id=tool.id) + new_agent_state = client.agents.tools.detach(agent_id=agent.id, tool_id=tool.id) assert tool.id not in [t.id for t in new_agent_state.tools] # Verify tool is detached - final_tools = client.list_attached_tools(agent_id=agent.id) + final_tools = client.agents.tools.list(agent_id=agent.id) assert tool.id not in [t.id for t in final_tools] finally: - client.delete_tool(tool.id) + client.tools.delete(tool.id) # -------------------------------------------------------------------------------------------------------------------- # AgentMessages # -------------------------------------------------------------------------------------------------------------------- -def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_messages(client: Letta, agent: AgentState): # _reset_config() - send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") + send_message_response = client.agents.messages.create(agent_id=agent.id, messages=[MessageCreate(role="user", content="Test message")]) assert send_message_response, "Sending message failed" - messages_response = client.get_messages(agent_id=agent.id, limit=1) + messages_response = client.agents.messages.list(agent_id=agent.id, limit=1) assert len(messages_response) > 0, "Retrieving messages failed" -def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_send_system_message(client: Letta, agent: AgentState): """Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)""" - send_system_message_response = client.send_message( - agent_id=agent.id, message="Event occurred: The user just logged off.", role="system" + send_system_message_response = client.agents.messages.create( + agent_id=agent.id, messages=[MessageCreate(role="system", content="Event occurred: The user just logged off.")] ) assert send_system_message_response, "Sending message failed" @pytest.mark.asyncio -async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request): +async def test_send_message_parallel(client: Letta, agent: AgentState, request): """ Test that sending two messages in parallel does not error. """ - if not isinstance(client, RESTClient): - pytest.skip("This test only runs when the server is enabled") # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls async def send_message_task(message: str): - response = await asyncio.to_thread(client.send_message, agent_id=agent.id, message=message, role="user") + response = await asyncio.to_thread( + client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)] + ) assert response, f"Sending message '{message}' failed" return response @@ -585,76 +611,31 @@ async def test_send_message_parallel(client: Union[LocalClient, RESTClient], age assert len(responses) == len(messages), "Not all messages were processed" -def test_send_message_async(client: Union[LocalClient, RESTClient], agent: AgentState): - """ - Test that we can send a message asynchronously and retrieve the messages, along with usage statistics - """ - - if not isinstance(client, RESTClient): - pytest.skip("send_message_async is only supported by the RESTClient") - - print("Sending message asynchronously") - test_message = "This is a test message, respond to the user with a sentence." - run = client.send_message_async(agent_id=agent.id, role="user", message=test_message) - assert run.id is not None - assert run.status == JobStatus.created - print(f"Run created, run={run}, status={run.status}") - - # Wait for the job to complete, cancel it if takes over 10 seconds - start_time = time.time() - while run.status == JobStatus.created: - time.sleep(1) - run = client.get_run(run_id=run.id) - print(f"Run status: {run.status}") - if time.time() - start_time > 10: - pytest.fail("Run took too long to complete") - - print(f"Run completed in {time.time() - start_time} seconds, run={run}") - assert run.status == JobStatus.completed - - # Get messages for the job - messages = client.get_run_messages(run_id=run.id) - assert len(messages) >= 2 # At least assistant response - - # Check filters - assistant_messages = client.get_run_messages(run_id=run.id, role=MessageRole.assistant) - assert len(assistant_messages) > 0 - tool_messages = client.get_run_messages(run_id=run.id, role=MessageRole.tool) - assert len(tool_messages) > 0 - - # Get and verify usage statistics - usage = client.get_run_usage(run_id=run.id)[0] - assert usage.completion_tokens >= 0 - assert usage.prompt_tokens >= 0 - assert usage.total_tokens >= 0 - assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens - - # ---------------------------------------------------------------------------------------------------- # Agent listing # ---------------------------------------------------------------------------------------------------- -def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two): +def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two): """Test listing agents with pagination and query text filtering.""" # Test query text filtering - search_results = client.list_agents(query_text="search agent") + search_results = client.agents.list(query_text="search agent") assert len(search_results) == 2 search_agent_ids = {agent.id for agent in search_results} assert search_agent_one.id in search_agent_ids assert search_agent_two.id in search_agent_ids assert agent.id not in search_agent_ids - different_results = client.list_agents(query_text="client") + different_results = client.agents.list(query_text="client") assert len(different_results) == 1 assert different_results[0].id == agent.id # Test pagination - first_page = client.list_agents(query_text="search agent", limit=1) + first_page = client.agents.list(query_text="search agent", limit=1) assert len(first_page) == 1 first_agent = first_page[0] - second_page = client.list_agents(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor + second_page = client.agents.list(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor assert len(second_page) == 1 assert second_page[0].id != first_agent.id @@ -664,20 +645,16 @@ def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_age assert all_ids == {search_agent_one.id, search_agent_two.id} # Test listing without any filters - all_agents = client.list_agents() + all_agents = client.agents.list() assert len(all_agents) == 3 assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent]) -def test_agent_creation(client: Union[LocalClient, RESTClient]): +def test_agent_creation(client: Letta): """Test that block IDs are properly attached when creating an agent.""" - if not isinstance(client, RESTClient): - pytest.skip("This test only runs when the server is enabled") - - from letta import BasicBlockMemory # Create a test block that will represent user preferences - user_preferences_block = client.create_block(label="user_preferences", value="", limit=10000) + user_preferences_block = client.blocks.create(label="user_preferences", value="", limit=10000) # Create test tools def test_tool(): @@ -688,73 +665,82 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]): """Another test tool.""" return "Hello from another test tool!" - tool1 = client.create_or_update_tool(func=test_tool, tags=["test"]) - tool2 = client.create_or_update_tool(func=another_test_tool, tags=["test"]) - - # Create test blocks - offline_persona_block = client.create_block(label="persona", value="persona description", limit=5000) - mindy_block = client.create_block(label="mindy", value="Mindy is a helpful assistant", limit=5000) - memory_blocks = BasicBlockMemory(blocks=[offline_persona_block, mindy_block]) + tool1 = client.tools.upsert_from_function(func=test_tool, tags=["test"]) + tool2 = client.tools.upsert_from_function(func=another_test_tool, tags=["test"]) # Create agent with the blocks and tools - agent = client.create_agent( - name=f"test_agent_{str(uuid.uuid4())}", - memory=memory_blocks, - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + agent = client.agents.create( + memory_blocks=[ + { + "label": "human", + "value": "you are a human", + }, + {"label": "persona", "value": "you are an assistant"}, + ], + model="letta/letta-free", + embedding="letta/letta-free", tool_ids=[tool1.id, tool2.id], include_base_tools=False, tags=["test"], block_ids=[user_preferences_block.id], ) + memory_blocks = agent.memory.blocks # Verify the agent was created successfully assert agent is not None assert agent.id is not None # Verify the blocks are properly attached - agent_blocks = client.list_agent_memory_blocks(agent.id) + agent_blocks = client.agents.blocks.list(agent_id=agent.id) agent_block_ids = {block.id for block in agent_blocks} # Check that all memory blocks are present - memory_block_ids = {block.id for block in memory_blocks.blocks} - for block_id in memory_block_ids | {user_preferences_block.id}: - assert block_id in agent_block_ids + memory_block_ids = {block.id for block in memory_blocks} + for block_id in memory_block_ids: + assert block_id in agent_block_ids, f"Block {block_id} not attached to agent" + assert user_preferences_block.id in agent_block_ids, f"User preferences block {user_preferences_block.id} not attached to agent" # Verify the tools are properly attached - agent_tools = client.get_tools_from_agent(agent.id) + agent_tools = client.agents.tools.list(agent_id=agent.id) assert len(agent_tools) == 2 tool_ids = {tool1.id, tool2.id} assert all(tool.id in tool_ids for tool in agent_tools) - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) # -------------------------------------------------------------------------------------------------------------------- # Agent sources # -------------------------------------------------------------------------------------------------------------------- -def test_attach_detach_agent_source(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_attach_detach_agent_source(client: Letta, agent: AgentState): """Test that we can attach and detach a source from an agent""" # Create a source - source = client.create_source( + source = client.sources.create( name="test_source", + embedding_config={ # TODO: change this + "embedding_endpoint": "https://embeddings.memgpt.ai", + "embedding_model": "BAAI/bge-large-en-v1.5", + "embedding_dim": 1024, + "embedding_chunk_size": 300, + "embedding_endpoint_type": "hugging-face", + }, ) - initial_sources = client.list_attached_sources(agent_id=agent.id) + initial_sources = client.agents.sources.list(agent_id=agent.id) assert source.id not in [s.id for s in initial_sources] # Attach source - client.attach_source(agent_id=agent.id, source_id=source.id) + client.agents.sources.attach(agent_id=agent.id, source_id=source.id) # Verify source is attached - final_sources = client.list_attached_sources(agent_id=agent.id) + final_sources = client.agents.sources.list(agent_id=agent.id) assert source.id in [s.id for s in final_sources] # Detach source - client.detach_source(agent_id=agent.id, source_id=source.id) + client.agents.sources.detach(agent_id=agent.id, source_id=source.id) # Verify source is detached - final_sources = client.list_attached_sources(agent_id=agent.id) + final_sources = client.agents.sources.list(agent_id=agent.id) assert source.id not in [s.id for s in final_sources] - client.delete_source(source.id) + client.sources.delete(source.id) diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 00000000..635677f0 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,115 @@ +import os +import threading +import time + +import pytest +from dotenv import load_dotenv +from letta_client import AgentState, Letta, LlmConfig, MessageCreate +from letta_client.core.api_error import ApiError +from pytest import fixture + + +def run_server(): + load_dotenv() + + from letta.server.rest_api.app import start_server + + print("Starting server...") + start_server(debug=True) + + +@pytest.fixture( + scope="module", +) +def client(request): + # Get URL from environment or start server + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") + if not os.getenv("LETTA_SERVER_URL"): + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) + + # create the Letta client + yield Letta(base_url=server_url, token=None) + + +# Fixture for test agent +@pytest.fixture(scope="module") +def agent(client: Letta): + agent_state = client.agents.create( + name="test_client", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) + + yield agent_state + + # delete agent + client.agents.delete(agent_state.id) + + +@pytest.mark.parametrize( + "stream_tokens,model", + [ + (True, "openai/gpt-4o-mini"), + (True, "anthropic/claude-3-sonnet-20240229"), + (False, "openai/gpt-4o-mini"), + (False, "anthropic/claude-3-sonnet-20240229"), + ], +) +def test_streaming_send_message( + mock_e2b_api_key_none, + client: Letta, + agent: AgentState, + stream_tokens: bool, + model: str, +): + # Update agent's model + config = client.agents.retrieve(agent_id=agent.id).llm_config + config_dump = config.model_dump() + config_dump["model"] = model + config = LlmConfig(**config_dump) + client.agents.modify(agent_id=agent.id, llm_config=config) + + # Send streaming message + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")], + stream_tokens=stream_tokens, + ) + + # Tracking variables for test validation + inner_thoughts_exist = False + inner_thoughts_count = 0 + send_message_ran = False + done = False + + assert response, "Sending message failed" + for chunk in response: + # Check chunk type and content based on the current client API + if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message": + inner_thoughts_exist = True + inner_thoughts_count += 1 + + if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message": + send_message_ran = True + if chunk.message_type == "assistant_message": + send_message_ran = True + + if chunk.message_type == "usage_statistics": + # Validate usage statistics + assert chunk.step_count == 1 + assert chunk.completion_tokens > 10 + assert chunk.prompt_tokens > 1000 + assert chunk.total_tokens > 1000 + done = True + print(chunk) + + # If stream tokens, we expect at least one inner thought + assert inner_thoughts_count >= 1, "Expected more than one inner thought" + assert inner_thoughts_exist, "No inner thoughts found" + assert send_message_ran, "send_message function call not found" + assert done, "Message stream not done"