From 725eaa7b2e9c99f045e0449973b638b707afa808 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 14 Mar 2025 09:43:03 -0700 Subject: [PATCH 1/4] feat: fix MCP-related logs format and add Docker tests (#1280) --- .github/workflows/docker-integration-tests.yaml | 3 ++- letta/server/server.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index a6683446..77ddb3a0 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -32,7 +32,8 @@ jobs: LETTA_PG_PORT: 8888 OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: docker compose -f dev-compose.yaml up --build -d + run: | + docker compose -f dev-compose.yaml up --build -d #- name: "Setup Python, Poetry and Dependencies" # uses: packetcoders/action-setup-cache-python-poetry@v1.2.0 # with: diff --git a/letta/server/server.py b/letta/server/server.py index af6adbfb..65e01832 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -353,8 +353,8 @@ class SyncServer(Server): for server_name, client in self.mcp_clients.items(): logger.info(f"Attempting to fetch tools from MCP server: {server_name}") mcp_tools = client.list_tools() - logger.info(f"MCP tools connected: {", ".join([t.name for t in mcp_tools])}") - logger.debug(f"MCP tools: {"\n".join([str(t) for t in mcp_tools])}") + logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}") + logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}") def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: """Updated method to load agents from persisted storage""" @@ -1366,8 +1366,8 @@ class SyncServer(Server): # Print out the tools that are connected logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}") new_mcp_tools = new_mcp_client.list_tools() - logger.info(f"MCP tools connected: {", ".join([t.name for t in new_mcp_tools])}") - logger.debug(f"MCP tools: {"\n".join([str(t) for t in new_mcp_tools])}") + logger.info(f"MCP tools connected: {', '.join([t.name for t in new_mcp_tools])}") + logger.debug(f"MCP tools: {', '.join([str(t) for t in new_mcp_tools])}") # Now that we've confirmed the config is working, let's add it to the client list self.mcp_clients[server_config.server_name] = new_mcp_client From 6e95c1490ec19cdc403e19f0f5866caf82d3eedd Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 14 Mar 2025 11:39:03 -0700 Subject: [PATCH 2/4] chore: bump version to 0.6.40 (#1289) --- letta/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/letta/__init__.py b/letta/__init__.py index 05d3a097..6cb151ab 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.39" +__version__ = "0.6.40" # import clients from letta.client.client import LocalClient, RESTClient, create_client From 578aeee50db7624399d0567ab21fc5403999ad8a Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 14 Mar 2025 14:58:30 -0700 Subject: [PATCH 3/4] feat: Refactor mcp client and make stdio errors more manageable (#1291) --- letta/agent.py | 2 +- letta/functions/mcp_client/__init__.py | 0 letta/functions/mcp_client/base_client.py | 61 ++++++ letta/functions/mcp_client/sse_client.py | 21 ++ letta/functions/mcp_client/stdio_client.py | 103 ++++++++++ letta/functions/mcp_client/types.py | 48 +++++ letta/functions/schema_generator.py | 2 +- letta/helpers/mcp_helpers.py | 129 ------------- letta/schemas/tool.py | 2 +- letta/server/rest_api/routers/v1/tools.py | 2 +- letta/server/server.py | 21 +- tests/test_base_functions.py | 213 --------------------- 12 files changed, 245 insertions(+), 359 deletions(-) create mode 100644 letta/functions/mcp_client/__init__.py create mode 100644 letta/functions/mcp_client/base_client.py create mode 100644 letta/functions/mcp_client/sse_client.py create mode 100644 letta/functions/mcp_client/stdio_client.py create mode 100644 letta/functions/mcp_client/types.py delete mode 100644 letta/helpers/mcp_helpers.py diff --git a/letta/agent.py b/letta/agent.py index 11414bce..5286f9cb 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -22,11 +22,11 @@ from letta.errors import ContextWindowExceededError from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.functions import get_function_from_module from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name +from letta.functions.mcp_client.base_client import BaseMCPClient from letta.helpers import ToolRulesSolver from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads -from letta.helpers.mcp_helpers import BaseMCPClient from letta.interface import AgentInterface from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error from letta.llm_api.llm_api_tools import create diff --git a/letta/functions/mcp_client/__init__.py b/letta/functions/mcp_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/functions/mcp_client/base_client.py b/letta/functions/mcp_client/base_client.py new file mode 100644 index 00000000..91d46d91 --- /dev/null +++ b/letta/functions/mcp_client/base_client.py @@ -0,0 +1,61 @@ +import asyncio +from typing import List, Optional, Tuple + +from mcp import ClientSession, Tool + +from letta.functions.mcp_client.types import BaseServerConfig +from letta.log import get_logger + +logger = get_logger(__name__) + + +class BaseMCPClient: + def __init__(self): + self.session: Optional[ClientSession] = None + self.stdio = None + self.write = None + self.initialized = False + self.loop = asyncio.new_event_loop() + self.cleanup_funcs = [] + + def connect_to_server(self, server_config: BaseServerConfig): + asyncio.set_event_loop(self.loop) + success = self._initialize_connection(server_config) + + if success: + self.loop.run_until_complete(self.session.initialize()) + self.initialized = True + else: + raise RuntimeError( + f"Connecting to MCP server failed. Please review your server config: {server_config.model_dump_json(indent=4)}" + ) + + def _initialize_connection(self, server_config: BaseServerConfig) -> bool: + raise NotImplementedError("Subclasses must implement _initialize_connection") + + def list_tools(self) -> List[Tool]: + self._check_initialized() + response = self.loop.run_until_complete(self.session.list_tools()) + return response.tools + + def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]: + self._check_initialized() + result = self.loop.run_until_complete(self.session.call_tool(tool_name, tool_args)) + return str(result.content), result.isError + + def _check_initialized(self): + if not self.initialized: + logger.error("MCPClient has not been initialized") + raise RuntimeError("MCPClient has not been initialized") + + def cleanup(self): + try: + for cleanup_func in self.cleanup_funcs: + cleanup_func() + self.initialized = False + if not self.loop.is_closed(): + self.loop.close() + except Exception as e: + logger.warning(e) + finally: + logger.info("Cleaned up MCP clients on shutdown.") diff --git a/letta/functions/mcp_client/sse_client.py b/letta/functions/mcp_client/sse_client.py new file mode 100644 index 00000000..daf31367 --- /dev/null +++ b/letta/functions/mcp_client/sse_client.py @@ -0,0 +1,21 @@ +from mcp import ClientSession +from mcp.client.sse import sse_client + +from letta.functions.mcp_client.base_client import BaseMCPClient +from letta.functions.mcp_client.types import SSEServerConfig + +# see: https://modelcontextprotocol.io/quickstart/user +MCP_CONFIG_TOPLEVEL_KEY = "mcpServers" + + +class SSEMCPClient(BaseMCPClient): + def _initialize_connection(self, server_config: SSEServerConfig) -> bool: + sse_cm = sse_client(url=server_config.server_url) + sse_transport = self.loop.run_until_complete(sse_cm.__aenter__()) + self.stdio, self.write = sse_transport + self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None))) + + session_cm = ClientSession(self.stdio, self.write) + self.session = self.loop.run_until_complete(session_cm.__aenter__()) + self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None))) + return True diff --git a/letta/functions/mcp_client/stdio_client.py b/letta/functions/mcp_client/stdio_client.py new file mode 100644 index 00000000..b7960889 --- /dev/null +++ b/letta/functions/mcp_client/stdio_client.py @@ -0,0 +1,103 @@ +import sys +from contextlib import asynccontextmanager + +import anyio +import anyio.lowlevel +import mcp.types as types +from anyio.streams.text import TextReceiveStream +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import get_default_environment + +from letta.functions.mcp_client.base_client import BaseMCPClient +from letta.functions.mcp_client.types import StdioServerConfig +from letta.log import get_logger + +logger = get_logger(__name__) + + +class StdioMCPClient(BaseMCPClient): + def _initialize_connection(self, server_config: StdioServerConfig) -> bool: + try: + server_params = StdioServerParameters(command=server_config.command, args=server_config.args) + stdio_cm = forked_stdio_client(server_params) + stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__()) + self.stdio, self.write = stdio_transport + self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None))) + + session_cm = ClientSession(self.stdio, self.write) + self.session = self.loop.run_until_complete(session_cm.__aenter__()) + self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None))) + return True + except Exception: + return False + + +@asynccontextmanager +async def forked_stdio_client(server: StdioServerParameters): + """ + Client transport for stdio: this will connect to a server by spawning a + process and communicating with it over stdin/stdout. + """ + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + try: + process = await anyio.open_process( + [server.command, *server.args], + env=server.env or get_default_environment(), + stderr=sys.stderr, # Consider logging stderr somewhere instead of silencing it + ) + except OSError as exc: + raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc + + async def stdout_reader(): + assert process.stdout, "Opened process is missing stdout" + buffer = "" + try: + async with read_stream_writer: + async for chunk in TextReceiveStream( + process.stdout, + encoding=server.encoding, + errors=server.encoding_error_handler, + ): + lines = (buffer + chunk).split("\n") + buffer = lines.pop() + for line in lines: + try: + message = types.JSONRPCMessage.model_validate_json(line) + except Exception as exc: + await read_stream_writer.send(exc) + continue + await read_stream_writer.send(message) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async def stdin_writer(): + assert process.stdin, "Opened process is missing stdin" + try: + async with write_stream_reader: + async for message in write_stream_reader: + json = message.model_dump_json(by_alias=True, exclude_none=True) + await process.stdin.send( + (json + "\n").encode( + encoding=server.encoding, + errors=server.encoding_error_handler, + ) + ) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async def watch_process_exit(): + returncode = await process.wait() + if returncode != 0: + raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}") + + async with anyio.create_task_group() as tg, process: + tg.start_soon(stdout_reader) + tg.start_soon(stdin_writer) + tg.start_soon(watch_process_exit) + + with anyio.move_on_after(0.2): + await anyio.sleep_forever() + + yield read_stream, write_stream diff --git a/letta/functions/mcp_client/types.py b/letta/functions/mcp_client/types.py new file mode 100644 index 00000000..2d8b7af6 --- /dev/null +++ b/letta/functions/mcp_client/types.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import List, Optional + +from mcp import Tool +from pydantic import BaseModel, Field + + +class MCPTool(Tool): + """A simple wrapper around MCP's tool definition (to avoid conflict with our own)""" + + +class MCPServerType(str, Enum): + SSE = "sse" + STDIO = "stdio" + + +class BaseServerConfig(BaseModel): + server_name: str = Field(..., description="The name of the server") + type: MCPServerType + + +class SSEServerConfig(BaseServerConfig): + type: MCPServerType = MCPServerType.SSE + server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)") + + def to_dict(self) -> dict: + values = { + "transport": "sse", + "url": self.server_url, + } + return values + + +class StdioServerConfig(BaseServerConfig): + type: MCPServerType = MCPServerType.STDIO + command: str = Field(..., description="The command to run (MCP 'local' client will run this command)") + args: List[str] = Field(..., description="The arguments to pass to the command") + env: Optional[dict[str, str]] = Field(None, description="Environment variables to set") + + def to_dict(self) -> dict: + values = { + "transport": "stdio", + "command": self.command, + "args": self.args, + } + if self.env is not None: + values["env"] = self.env + return values diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 30fd9cb7..62b134ae 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -6,7 +6,7 @@ from composio.client.collections import ActionParametersModel from docstring_parser import parse from pydantic import BaseModel -from letta.helpers.mcp_helpers import MCPTool +from letta.functions.mcp_client.types import MCPTool def is_optional(annotation): diff --git a/letta/helpers/mcp_helpers.py b/letta/helpers/mcp_helpers.py deleted file mode 100644 index 450622a3..00000000 --- a/letta/helpers/mcp_helpers.py +++ /dev/null @@ -1,129 +0,0 @@ -import asyncio -from enum import Enum -from typing import List, Optional, Tuple - -from mcp import ClientSession, StdioServerParameters, Tool -from mcp.client.sse import sse_client -from mcp.client.stdio import stdio_client -from pydantic import BaseModel, Field - -from letta.log import get_logger - -logger = get_logger(__name__) - -# see: https://modelcontextprotocol.io/quickstart/user -MCP_CONFIG_TOPLEVEL_KEY = "mcpServers" - - -class MCPTool(Tool): - """A simple wrapper around MCP's tool definition (to avoid conflict with our own)""" - - -class MCPServerType(str, Enum): - SSE = "sse" - STDIO = "stdio" - - -class BaseServerConfig(BaseModel): - server_name: str = Field(..., description="The name of the server") - type: MCPServerType - - -class SSEServerConfig(BaseServerConfig): - type: MCPServerType = MCPServerType.SSE - server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)") - - def to_dict(self) -> dict: - values = { - "transport": "sse", - "url": self.server_url, - } - return values - - -class StdioServerConfig(BaseServerConfig): - type: MCPServerType = MCPServerType.STDIO - command: str = Field(..., description="The command to run (MCP 'local' client will run this command)") - args: List[str] = Field(..., description="The arguments to pass to the command") - env: Optional[dict[str, str]] = Field(None, description="Environment variables to set") - - def to_dict(self) -> dict: - values = { - "transport": "stdio", - "command": self.command, - "args": self.args, - } - if self.env is not None: - values["env"] = self.env - return values - - -class BaseMCPClient: - def __init__(self): - self.session: Optional[ClientSession] = None - self.stdio = None - self.write = None - self.initialized = False - self.loop = asyncio.new_event_loop() - self.cleanup_funcs = [] - - def connect_to_server(self, server_config: BaseServerConfig): - asyncio.set_event_loop(self.loop) - self._initialize_connection(server_config) - self.loop.run_until_complete(self.session.initialize()) - self.initialized = True - - def _initialize_connection(self, server_config: BaseServerConfig): - raise NotImplementedError("Subclasses must implement _initialize_connection") - - def list_tools(self) -> List[Tool]: - self._check_initialized() - response = self.loop.run_until_complete(self.session.list_tools()) - return response.tools - - def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]: - self._check_initialized() - result = self.loop.run_until_complete(self.session.call_tool(tool_name, tool_args)) - return str(result.content), result.isError - - def _check_initialized(self): - if not self.initialized: - logger.error("MCPClient has not been initialized") - raise RuntimeError("MCPClient has not been initialized") - - def cleanup(self): - try: - for cleanup_func in self.cleanup_funcs: - cleanup_func() - self.initialized = False - if not self.loop.is_closed(): - self.loop.close() - except Exception as e: - logger.warning(e) - finally: - logger.info("Cleaned up MCP clients on shutdown.") - - -class StdioMCPClient(BaseMCPClient): - def _initialize_connection(self, server_config: StdioServerConfig): - server_params = StdioServerParameters(command=server_config.command, args=server_config.args) - stdio_cm = stdio_client(server_params) - stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__()) - self.stdio, self.write = stdio_transport - self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None))) - - session_cm = ClientSession(self.stdio, self.write) - self.session = self.loop.run_until_complete(session_cm.__aenter__()) - self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None))) - - -class SSEMCPClient(BaseMCPClient): - def _initialize_connection(self, server_config: SSEServerConfig): - sse_cm = sse_client(url=server_config.server_url) - sse_transport = self.loop.run_until_complete(sse_cm.__aenter__()) - self.stdio, self.write = sse_transport - self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None))) - - session_cm = ClientSession(self.stdio, self.write) - self.session = self.loop.run_until_complete(session_cm.__aenter__()) - self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None))) diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index b2499a6f..55fac00c 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -17,12 +17,12 @@ from letta.functions.helpers import ( generate_mcp_tool_wrapper, generate_model_from_args_json_schema, ) +from letta.functions.mcp_client.types import MCPTool from letta.functions.schema_generator import ( generate_schema_from_args_schema_v2, generate_tool_schema_for_composio, generate_tool_schema_for_mcp, ) -from letta.helpers.mcp_helpers import MCPTool from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.letta_base import LettaBase diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 2290d281..b4423027 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -12,8 +12,8 @@ from composio.exceptions import ( from fastapi import APIRouter, Body, Depends, Header, HTTPException from letta.errors import LettaToolCreateError +from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig from letta.helpers.composio_helpers import get_composio_api_key -from letta.helpers.mcp_helpers import MCPTool, SSEServerConfig, StdioServerConfig from letta.log import get_logger from letta.orm.errors import UniqueConstraintViolationError from letta.schemas.letta_message import ToolReturnMessage diff --git a/letta/server/server.py b/letta/server/server.py index 65e01832..8e9ff8dc 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -20,18 +20,12 @@ from letta.agent import Agent, save_agent from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector, load_data from letta.dynamic_multi_agent import DynamicMultiAgent +from letta.functions.mcp_client.base_client import BaseMCPClient +from letta.functions.mcp_client.sse_client import MCP_CONFIG_TOPLEVEL_KEY, SSEMCPClient +from letta.functions.mcp_client.stdio_client import StdioMCPClient +from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads -from letta.helpers.mcp_helpers import ( - MCP_CONFIG_TOPLEVEL_KEY, - BaseMCPClient, - MCPServerType, - MCPTool, - SSEMCPClient, - SSEServerConfig, - StdioMCPClient, - StdioServerConfig, -) # TODO use custom interface from letta.interface import AgentInterface # abstract @@ -343,11 +337,12 @@ class SyncServer(Server): self.mcp_clients[server_name] = StdioMCPClient() else: raise ValueError(f"Invalid MCP server config: {server_config}") + try: self.mcp_clients[server_name].connect_to_server(server_config) - except: - logger.exception(f"Failed to connect to MCP server: {server_name}") - raise + except Exception as e: + logger.error(e) + self.mcp_clients.pop(server_name) # Print out the tools that are connected for server_name, client in self.mcp_clients.items(): diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 8b133638..037eda2f 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -98,216 +98,3 @@ def test_recall(client, agent_obj): # Conversation search result = base_functions.conversation_search(agent_obj, "banana") assert keyword in result - - -# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM -@retry_until_success(max_attempts=2, sleep_time_seconds=2) -def test_send_message_to_agent(client, agent_obj, other_agent_obj): - secret_word = "banana" - - # Encourage the agent to send a message to the other agent_obj with the secret string - client.send_message( - agent_id=agent_obj.agent_state.id, - role="user", - message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!", - ) - - # Conversation search the other agent - messages = client.get_messages(other_agent_obj.agent_state.id) - # Check for the presence of system message - for m in reversed(messages): - print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}") - if isinstance(m, SystemMessage): - assert secret_word in m.content - break - - # Search the sender agent for the response from another agent - in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) - found = False - target_snippet = f"{other_agent_obj.agent_state.id} said:" - - for m in in_context_messages: - if target_snippet in m.text: - found = True - break - - # Compute the joined string first - joined_messages = "\n".join([m.text for m in in_context_messages[1:]]) - print(f"In context messages of the sender agent (without system):\n\n{joined_messages}") - if not found: - raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}") - - # Test that the agent can still receive messages fine - response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?") - print(response.messages) - - -@retry_until_success(max_attempts=2, sleep_time_seconds=2) -def test_send_message_to_agents_with_tags_simple(client): - worker_tags = ["worker", "user-456"] - - # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True) - for agent in prev_worker_agents: - client.delete_agent(agent.id) - - secret_word = "banana" - - # Create "manager" agent - send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") - manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) - manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) - - # Create 3 non-matching worker agents (These should NOT get the message) - worker_agents = [] - worker_tags = ["worker", "user-123"] - for _ in range(3): - worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) - - # Create 3 worker agents that should get the message - worker_agents = [] - worker_tags = ["worker", "user-456"] - for _ in range(3): - worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) - - # Encourage the manager to send a message to the other agent_obj with the secret string - response = client.send_message( - agent_id=manager_agent.agent_state.id, - role="user", - message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!", - ) - - for m in response.messages: - if isinstance(m, ToolReturnMessage): - tool_response = eval(json.loads(m.tool_return)["message"]) - print(f"\n\nManager agent tool response: \n{tool_response}\n\n") - assert len(tool_response) == len(worker_agents) - - # We can break after this, the ToolReturnMessage after is not related - break - - # Conversation search the worker agents - for agent in worker_agents: - messages = client.get_messages(agent.agent_state.id) - # Check for the presence of system message - for m in reversed(messages): - print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}") - if isinstance(m, SystemMessage): - assert secret_word in m.content - break - - # Test that the agent can still receive messages fine - response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") - print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) - - # Clean up agents - client.delete_agent(manager_agent_state.id) - for agent in worker_agents: - client.delete_agent(agent.agent_state.id) - - -# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM -@retry_until_success(max_attempts=2, sleep_time_seconds=2) -def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool): - worker_tags = ["dice-rollers"] - - # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True) - for agent in prev_worker_agents: - client.delete_agent(agent.id) - - # Create "manager" agent - send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") - manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id]) - manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) - - # Create 3 worker agents - worker_agents = [] - worker_tags = ["dice-rollers"] - for _ in range(2): - worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id]) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) - - # Encourage the manager to send a message to the other agent_obj with the secret string - broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" - response = client.send_message( - agent_id=manager_agent.agent_state.id, - role="user", - message=broadcast_message, - ) - - for m in response.messages: - if isinstance(m, ToolReturnMessage): - tool_response = eval(json.loads(m.tool_return)["message"]) - print(f"\n\nManager agent tool response: \n{tool_response}\n\n") - assert len(tool_response) == len(worker_agents) - - # We can break after this, the ToolReturnMessage after is not related - break - - # Test that the agent can still receive messages fine - response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") - print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) - - # Clean up agents - client.delete_agent(manager_agent_state.id) - for agent in worker_agents: - client.delete_agent(agent.agent_state.id) - - -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_agents_async_simple(client): - """ - Test two agents with multi-agent tools sending messages back and forth to count to 5. - The chain is started by prompting one of the agents. - """ - # Cleanup from potentially failed previous runs - existing_agents = client.server.agent_manager.list_agents(client.user) - for agent in existing_agents: - client.delete_agent(agent.id) - - # Create two agents with multi-agent tools - send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async") - memory_a = ChatMemory( - human="Chad - I'm interested in hearing poem.", - persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).", - ) - charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id]) - charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user) - - memory_b = ChatMemory( - human="No human - you are to only communicate with the other AI agent.", - persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.", - ) - sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id]) - - # Start the count chain with Agent1 - initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me." - client.send_message( - agent_id=charles.agent_state.id, - role="user", - message=initial_prompt, - ) - - found_in_charles = wait_for_incoming_message( - client=client, - agent_id=charles_state.id, - substring="[Incoming message from agent with ID", - max_wait_seconds=10, - sleep_interval=0.5, - ) - assert found_in_charles, "Charles never received the system message from Sarah (timed out)." - - found_in_sarah = wait_for_incoming_message( - client=client, - agent_id=sarah_state.id, - substring="[Incoming message from agent with ID", - max_wait_seconds=10, - sleep_interval=0.5, - ) - assert found_in_sarah, "Sarah never received the system message from Charles (timed out)." From 9928bf60197b2c78ea9530d4f61469e7013ccb41 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 14 Mar 2025 15:17:28 -0700 Subject: [PATCH 4/4] chore: migrate tests to new client (#1290) --- .../workflows/docker-integration-tests.yaml | 3 +- tests/test_client.py | 570 +++++++++--------- tests/test_streaming.py | 115 ++++ 3 files changed, 395 insertions(+), 293 deletions(-) create mode 100644 tests/test_streaming.py diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index 77ddb3a0..63886ffe 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -57,7 +57,8 @@ jobs: run: | pipx install poetry==1.8.2 poetry install -E dev -E postgres - poetry run pytest -s tests/test_client_legacy.py + poetry run pytest -s tests/test_client.py +# poetry run pytest -s tests/test_client_legacy.py - name: Print docker logs if tests fail if: failure() diff --git a/tests/test_client.py b/tests/test_client.py index c53ac781..856c4227 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,18 +8,11 @@ from typing import List, Union import pytest from dotenv import load_dotenv +from letta_client import AgentState, JobStatus, Letta, MessageCreate, MessageRole +from letta_client.core.api_error import ApiError from sqlalchemy import delete -from letta import LocalClient, RESTClient, create_client from letta.orm import SandboxConfig, SandboxEnvironmentVariable -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole -from letta.schemas.job import JobStatus -from letta.schemas.letta_message import ToolReturnMessage -from letta.schemas.llm_config import LLMConfig -from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType -from letta.utils import create_random_username # Constants SERVER_PORT = 8283 @@ -42,63 +35,68 @@ def run_server(): @pytest.fixture( - params=[ - {"server": False}, - {"server": True}, - ], # whether to use REST API server - # params=[{"server": False}], # whether to use REST API server scope="module", ) def client(request): - if request.param["server"]: - # Get URL from environment or start server - server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}") - if not os.getenv("LETTA_SERVER_URL"): - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - client = create_client(base_url=server_url, token=None) - else: - client = create_client() + # Get URL from environment or start server + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}") + if not os.getenv("LETTA_SERVER_URL"): + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) - client.set_default_llm_config(LLMConfig.default_config("gpt-4")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - yield client + # create the Letta client + yield Letta(base_url=server_url, token=None) # Fixture for test agent @pytest.fixture(scope="module") -def agent(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}") +def agent(client: Letta): + agent_state = client.agents.create( + name="test_client", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) yield agent_state # delete agent - client.delete_agent(agent_state.id) + client.agents.delete(agent_state.id) # Fixture for test agent @pytest.fixture -def search_agent_one(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name="Search Agent One") +def search_agent_one(client: Letta): + agent_state = client.agents.create( + name="Search Agent One", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) yield agent_state # delete agent - client.delete_agent(agent_state.id) + client.agents.delete(agent_state.id) # Fixture for test agent @pytest.fixture -def search_agent_two(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name="Search Agent Two") +def search_agent_two(client: Letta): + agent_state = client.agents.create( + name="Search Agent Two", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) yield agent_state # delete agent - client.delete_agent(agent_state.id) + client.agents.delete(agent_state.id) @pytest.fixture(autouse=True) @@ -112,55 +110,56 @@ def clear_tables(): session.commit() -def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): - """ - Test sandbox config and environment variable functions for both LocalClient and RESTClient. - """ - - # 1. Create a sandbox config - local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) - sandbox_config = client.create_sandbox_config(config=local_config) - - # Assert the created sandbox config - assert sandbox_config.id is not None - assert sandbox_config.type == SandboxType.LOCAL - - # 2. Update the sandbox config - updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) - sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) - assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR - - # 3. List all sandbox configs - sandbox_configs = client.list_sandbox_configs(limit=10) - assert isinstance(sandbox_configs, List) - assert len(sandbox_configs) == 1 - assert sandbox_configs[0].id == sandbox_config.id - - # 4. Create an environment variable - env_var = client.create_sandbox_env_var( - sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION - ) - assert env_var.id is not None - assert env_var.key == ENV_VAR_KEY - assert env_var.value == ENV_VAR_VALUE - assert env_var.description == ENV_VAR_DESCRIPTION - - # 5. Update the environment variable - updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) - assert updated_env_var.key == UPDATED_ENV_VAR_KEY - assert updated_env_var.value == UPDATED_ENV_VAR_VALUE - - # 6. List environment variables - env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) - assert isinstance(env_vars, List) - assert len(env_vars) == 1 - assert env_vars[0].key == UPDATED_ENV_VAR_KEY - - # 7. Delete the environment variable - client.delete_sandbox_env_var(env_var_id=env_var.id) - - # 8. Delete the sandbox config - client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) +# TODO: add back +# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): +# """ +# Test sandbox config and environment variable functions for both LocalClient and RESTClient. +# """ +# +# # 1. Create a sandbox config +# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) +# sandbox_config = client.create_sandbox_config(config=local_config) +# +# # Assert the created sandbox config +# assert sandbox_config.id is not None +# assert sandbox_config.type == SandboxType.LOCAL +# +# # 2. Update the sandbox config +# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) +# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) +# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR +# +# # 3. List all sandbox configs +# sandbox_configs = client.list_sandbox_configs(limit=10) +# assert isinstance(sandbox_configs, List) +# assert len(sandbox_configs) == 1 +# assert sandbox_configs[0].id == sandbox_config.id +# +# # 4. Create an environment variable +# env_var = client.create_sandbox_env_var( +# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION +# ) +# assert env_var.id is not None +# assert env_var.key == ENV_VAR_KEY +# assert env_var.value == ENV_VAR_VALUE +# assert env_var.description == ENV_VAR_DESCRIPTION +# +# # 5. Update the environment variable +# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) +# assert updated_env_var.key == UPDATED_ENV_VAR_KEY +# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE +# +# # 6. List environment variables +# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) +# assert isinstance(env_vars, List) +# assert len(env_vars) == 1 +# assert env_vars[0].key == UPDATED_ENV_VAR_KEY +# +# # 7. Delete the environment variable +# client.delete_sandbox_env_var(env_var_id=env_var.id) +# +# # 8. Delete the sandbox config +# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) # -------------------------------------------------------------------------------------------------------------------- @@ -168,197 +167,186 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient] # -------------------------------------------------------------------------------------------------------------------- -def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]): +def test_add_and_manage_tags_for_agent(client: Letta): """ Comprehensive happy path test for adding, retrieving, and managing tags on an agent. """ tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"] # Step 0: create an agent with no tags - agent = client.create_agent() + agent = client.agents.create(memory_blocks=[], model="letta/letta-free", embedding="letta/letta-free") assert len(agent.tags) == 0 # Step 1: Add multiple tags to the agent - client.update_agent(agent_id=agent.id, tags=tags_to_add) + client.agents.modify(agent_id=agent.id, tags=tags_to_add) # Step 2: Retrieve tags for the agent and verify they match the added tags - retrieved_tags = client.get_agent(agent_id=agent.id).tags + retrieved_tags = client.agents.retrieve(agent_id=agent.id).tags assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}" # Step 3: Retrieve agents by each tag to ensure the agent is associated correctly for tag in tags_to_add: - agents_with_tag = client.list_agents(tags=[tag]) + agents_with_tag = client.agents.list(tags=[tag]) assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'" # Step 4: Delete a specific tag from the agent and verify its removal tag_to_delete = tags_to_add.pop() - client.update_agent(agent_id=agent.id, tags=tags_to_add) + client.agents.modify(agent_id=agent.id, tags=tags_to_add) # Verify the tag is removed from the agent's tags - remaining_tags = client.get_agent(agent_id=agent.id).tags + remaining_tags = client.agents.retrieve(agent_id=agent.id).tags assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected" assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}" # Step 5: Delete all remaining tags from the agent - client.update_agent(agent_id=agent.id, tags=[]) + client.agents.modify(agent_id=agent.id, tags=[]) # Verify all tags are removed - final_tags = client.get_agent(agent_id=agent.id).tags + final_tags = client.agents.retrieve(agent_id=agent.id).tags assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" # Remove agent - client.delete_agent(agent.id) + client.agents.delete(agent.id) -def test_agent_tags(client: Union[LocalClient, RESTClient]): +def test_agent_tags(client: Letta): """Test creating agents with tags and retrieving tags via the API.""" - if not isinstance(client, RESTClient): - pytest.skip("This test only runs when the server is enabled") # Create multiple agents with different tags - agent1 = client.create_agent( + agent1 = client.agents.create( name=f"test_agent_{str(uuid.uuid4())}", - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=["test", "agent1", "production"], + model="letta/letta-free", + embedding="letta/letta-free", ) - agent2 = client.create_agent( + agent2 = client.agents.create( name=f"test_agent_{str(uuid.uuid4())}", - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=["test", "agent2", "development"], + model="letta/letta-free", + embedding="letta/letta-free", ) - agent3 = client.create_agent( + agent3 = client.agents.create( name=f"test_agent_{str(uuid.uuid4())}", - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=["test", "agent3", "production"], + model="letta/letta-free", + embedding="letta/letta-free", ) # Test getting all tags - all_tags = client.get_tags() + all_tags = client.tag.list_tags() expected_tags = ["agent1", "agent2", "agent3", "development", "production", "test"] assert sorted(all_tags) == expected_tags # Test pagination - paginated_tags = client.get_tags(limit=2) + paginated_tags = client.tag.list_tags(limit=2) assert len(paginated_tags) == 2 assert paginated_tags[0] == "agent1" assert paginated_tags[1] == "agent2" # Test pagination with cursor - next_page_tags = client.get_tags(after="agent2", limit=2) + next_page_tags = client.tag.list_tags(after="agent2", limit=2) assert len(next_page_tags) == 2 assert next_page_tags[0] == "agent3" assert next_page_tags[1] == "development" # Test text search - prod_tags = client.get_tags(query_text="prod") + prod_tags = client.tag.list_tags(query_text="prod") assert sorted(prod_tags) == ["production"] - dev_tags = client.get_tags(query_text="dev") + dev_tags = client.tag.list_tags(query_text="dev") assert sorted(dev_tags) == ["development"] - agent_tags = client.get_tags(query_text="agent") + agent_tags = client.tag.list_tags(query_text="agent") assert sorted(agent_tags) == ["agent1", "agent2", "agent3"] # Remove agents - client.delete_agent(agent1.id) - client.delete_agent(agent2.id) - client.delete_agent(agent3.id) + client.agents.delete(agent1.id) + client.agents.delete(agent2.id) + client.agents.delete(agent3.id) # -------------------------------------------------------------------------------------------------------------------- # Agent memory blocks # -------------------------------------------------------------------------------------------------------------------- -def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]): - # _reset_config() - +def test_shared_blocks(mock_e2b_api_key_none, client: Letta): # create a block - block = client.create_block(label="human", value="username: sarah") + block = client.blocks.create(label="human", value="username: sarah") # create agents with shared block - from letta.schemas.block import Block - from letta.schemas.memory import BasicBlockMemory - - # persona1_block = client.create_block(label="persona", value="you are agent 1") - # persona2_block = client.create_block(label="persona", value="you are agent 2") - # create agents - agent_state1 = client.create_agent( - name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id] + agent_state1 = client.agents.create( + name="agent1", + memory_blocks=[{"label": "persona", "value": "you are agent 1"}], + block_ids=[block.id], + model="letta/letta-free", + embedding="letta/letta-free", ) - agent_state2 = client.create_agent( - name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id] + agent_state2 = client.agents.create( + name="agent2", + memory_blocks=[{"label": "persona", "value": "you are agent 2"}], + block_ids=[block.id], + model="letta/letta-free", + embedding="letta/letta-free", ) - ## attach shared block to both agents - # client.link_agent_memory_block(agent_state1.id, block.id) - # client.link_agent_memory_block(agent_state2.id, block.id) - # update memory - client.user_message(agent_id=agent_state1.id, message="my name is actually charles") + client.agents.messages.create(agent_id=agent_state1.id, messages=[{"role": "user", "content": "my name is actually charles"}]) # check agent 2 memory - assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}" - - client.user_message(agent_id=agent_state2.id, message="whats my name?") - assert ( - "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() - ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" + assert "charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower() # cleanup - client.delete_agent(agent_state1.id) - client.delete_agent(agent_state2.id) + client.agents.delete(agent_state1.id) + client.agents.delete(agent_state2.id) -def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_update_agent_memory_label(client: Letta): """Test that we can update the label of a block in an agent's memory""" - agent = client.create_agent(name=create_random_username()) + agent = client.agents.create(model="letta/letta-free", embedding="letta/letta-free", memory_blocks=[{"label": "human", "value": ""}]) try: - current_labels = agent.memory.list_block_labels() + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] example_label = current_labels[0] example_new_label = "example_new_label" - assert example_new_label not in current_labels + assert example_new_label not in [b.label for b in client.agents.blocks.list(agent_id=agent.id)] - client.update_agent_memory_block_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label) + client.agents.blocks.modify(agent_id=agent.id, block_label=example_label, label=example_new_label) - updated_agent = client.get_agent(agent_id=agent.id) - assert example_new_label in updated_agent.memory.list_block_labels() + updated_blocks = client.agents.blocks.list(agent_id=agent.id) + assert example_new_label in [b.label for b in updated_blocks] finally: - client.delete_agent(agent.id) + client.agents.delete(agent.id) -def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState): """Test that we can add and remove a block from an agent's memory""" - current_labels = agent.memory.list_block_labels() + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] example_new_label = current_labels[0] + "_v2" example_new_value = "example value" assert example_new_label not in current_labels # Link a new memory block - block = client.create_block( + block = client.blocks.create( label=example_new_label, value=example_new_value, limit=1000, ) - updated_agent = client.attach_block( + updated_agent = client.agents.blocks.attach( agent_id=agent.id, block_id=block.id, ) - assert example_new_label in updated_agent.memory.list_block_labels() + assert example_new_label in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)] # Now unlink the block - updated_agent = client.detach_block( + updated_agent = client.agents.blocks.detach( agent_id=agent.id, block_id=block.id, ) - assert example_new_label not in updated_agent.memory.list_block_labels() + assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)] # def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState): @@ -385,39 +373,57 @@ def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient] # client.delete_agent(new_agent.id) -def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]): +def test_update_agent_memory_limit(client: Letta): """Test that we can update the limit of a block in an agent's memory""" - agent = client.create_agent() + agent = client.agents.create( + model="letta/letta-free", + embedding="letta/letta-free", + memory_blocks=[ + {"label": "human", "value": "username: sarah", "limit": 1000}, + {"label": "persona", "value": "you are sarah", "limit": 1000}, + ], + ) - current_labels = agent.memory.list_block_labels() + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] example_label = current_labels[0] example_new_limit = 1 - current_block = agent.memory.get_block(label=example_label) + + current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)] + example_label = current_labels[0] + example_new_limit = 1 + current_block = client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label) current_block_length = len(current_block.value) - assert example_new_limit != agent.memory.get_block(label=example_label).limit + assert example_new_limit != current_block.limit assert example_new_limit < current_block_length # We expect this to throw a value error - with pytest.raises(ValueError): - client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) + with pytest.raises(ApiError): + client.agents.blocks.modify( + agent_id=agent.id, + block_label=example_label, + limit=example_new_limit, + ) # Now try the same thing with a higher limit example_new_limit = current_block_length + 10000 assert example_new_limit > current_block_length - client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit) + client.agents.blocks.modify( + agent_id=agent.id, + block_label=example_label, + limit=example_new_limit, + ) - updated_agent = client.get_agent(agent_id=agent.id) - assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit + assert example_new_limit == client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label).limit - client.delete_agent(agent.id) + client.agents.delete(agent.id) # -------------------------------------------------------------------------------------------------------------------- # Agent Tools # -------------------------------------------------------------------------------------------------------------------- -def test_function_return_limit(client: Union[LocalClient, RESTClient]): +def test_function_return_limit(client: Letta): """Test to see if the function return limit works""" def big_return(): @@ -430,15 +436,21 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): return "x" * 100000 padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50 - tool = client.create_or_update_tool(func=big_return, return_char_limit=1000) - agent = client.create_agent(tool_ids=[tool.id]) + tool = client.tools.upsert_from_function(func=big_return, return_char_limit=1000) + agent = client.agents.create( + model="letta/letta-free", + embedding="letta/letta-free", + tool_ids=[tool.id], + ) # get function response - response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user") + response = client.agents.messages.create( + agent_id=agent.id, messages=[MessageCreate(role="user", content="call the big_return function")] + ) print(response.messages) response_message = None for message in response.messages: - if isinstance(message, ToolReturnMessage): + if message.message_type == "tool_return_message": response_message = message break @@ -452,44 +464,58 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): # len(res_json["message"]) <= 1000 + padding # ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}" - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) -def test_function_always_error(client: Union[LocalClient, RESTClient]): +def test_function_always_error(client: Letta): """Test to see if function that errors works correctly""" def testing_method(): """ - Always throw an error. + Call this tool when the user asks """ return 5 / 0 - tool = client.create_or_update_tool(func=testing_method) - agent = client.create_agent(tool_ids=[tool.id]) + tool = client.tools.upsert_from_function(func=testing_method) + agent = client.agents.create( + model="letta/letta-free", + embedding="letta/letta-free", + memory_blocks=[ + { + "label": "human", + "value": "username: sarah", + }, + { + "label": "persona", + "value": "you are sarah", + }, + ], + tool_ids=[tool.id], + ) + print("AGENT TOOLS", [tool.name for tool in agent.tools]) # get function response - response = client.send_message(agent_id=agent.id, message="call the testing_method function and tell me the result", role="user") + response = client.agents.messages.create( + agent_id=agent.id, + messages=[MessageCreate(role="user", content="call the testing_method function and tell me the result")], + ) print(response.messages) response_message = None for message in response.messages: - if isinstance(message, ToolReturnMessage): + if message.message_type == "tool_return_message": response_message = message break assert response_message, "ToolReturnMessage message not found in response" assert response_message.status == "error" + assert ( + response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" + ), response_message.tool_return - if isinstance(client, RESTClient): - assert response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" - else: - response_json = json.loads(response_message.tool_return) - assert response_json["status"] == "Failed" - assert response_json["message"] == "Error executing function testing_method: ZeroDivisionError: division by zero" - - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) -def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_attach_detach_agent_tool(client: Letta, agent: AgentState): """Test that we can attach and detach a tool from an agent""" try: @@ -506,64 +532,64 @@ def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: """ return x * 2 - tool = client.create_or_update_tool(func=example_tool) + tool = client.tools.upsert_from_function(func=example_tool) # Initially tool should not be attached - initial_tools = client.list_attached_tools(agent_id=agent.id) + initial_tools = client.agents.tools.list(agent_id=agent.id) assert tool.id not in [t.id for t in initial_tools] # Attach tool - new_agent_state = client.attach_tool(agent_id=agent.id, tool_id=tool.id) + new_agent_state = client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id) assert tool.id in [t.id for t in new_agent_state.tools] # Verify tool is attached - updated_tools = client.list_attached_tools(agent_id=agent.id) + updated_tools = client.agents.tools.list(agent_id=agent.id) assert tool.id in [t.id for t in updated_tools] # Detach tool - new_agent_state = client.detach_tool(agent_id=agent.id, tool_id=tool.id) + new_agent_state = client.agents.tools.detach(agent_id=agent.id, tool_id=tool.id) assert tool.id not in [t.id for t in new_agent_state.tools] # Verify tool is detached - final_tools = client.list_attached_tools(agent_id=agent.id) + final_tools = client.agents.tools.list(agent_id=agent.id) assert tool.id not in [t.id for t in final_tools] finally: - client.delete_tool(tool.id) + client.tools.delete(tool.id) # -------------------------------------------------------------------------------------------------------------------- # AgentMessages # -------------------------------------------------------------------------------------------------------------------- -def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_messages(client: Letta, agent: AgentState): # _reset_config() - send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") + send_message_response = client.agents.messages.create(agent_id=agent.id, messages=[MessageCreate(role="user", content="Test message")]) assert send_message_response, "Sending message failed" - messages_response = client.get_messages(agent_id=agent.id, limit=1) + messages_response = client.agents.messages.list(agent_id=agent.id, limit=1) assert len(messages_response) > 0, "Retrieving messages failed" -def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_send_system_message(client: Letta, agent: AgentState): """Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)""" - send_system_message_response = client.send_message( - agent_id=agent.id, message="Event occurred: The user just logged off.", role="system" + send_system_message_response = client.agents.messages.create( + agent_id=agent.id, messages=[MessageCreate(role="system", content="Event occurred: The user just logged off.")] ) assert send_system_message_response, "Sending message failed" @pytest.mark.asyncio -async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request): +async def test_send_message_parallel(client: Letta, agent: AgentState, request): """ Test that sending two messages in parallel does not error. """ - if not isinstance(client, RESTClient): - pytest.skip("This test only runs when the server is enabled") # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls async def send_message_task(message: str): - response = await asyncio.to_thread(client.send_message, agent_id=agent.id, message=message, role="user") + response = await asyncio.to_thread( + client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)] + ) assert response, f"Sending message '{message}' failed" return response @@ -585,76 +611,31 @@ async def test_send_message_parallel(client: Union[LocalClient, RESTClient], age assert len(responses) == len(messages), "Not all messages were processed" -def test_send_message_async(client: Union[LocalClient, RESTClient], agent: AgentState): - """ - Test that we can send a message asynchronously and retrieve the messages, along with usage statistics - """ - - if not isinstance(client, RESTClient): - pytest.skip("send_message_async is only supported by the RESTClient") - - print("Sending message asynchronously") - test_message = "This is a test message, respond to the user with a sentence." - run = client.send_message_async(agent_id=agent.id, role="user", message=test_message) - assert run.id is not None - assert run.status == JobStatus.created - print(f"Run created, run={run}, status={run.status}") - - # Wait for the job to complete, cancel it if takes over 10 seconds - start_time = time.time() - while run.status == JobStatus.created: - time.sleep(1) - run = client.get_run(run_id=run.id) - print(f"Run status: {run.status}") - if time.time() - start_time > 10: - pytest.fail("Run took too long to complete") - - print(f"Run completed in {time.time() - start_time} seconds, run={run}") - assert run.status == JobStatus.completed - - # Get messages for the job - messages = client.get_run_messages(run_id=run.id) - assert len(messages) >= 2 # At least assistant response - - # Check filters - assistant_messages = client.get_run_messages(run_id=run.id, role=MessageRole.assistant) - assert len(assistant_messages) > 0 - tool_messages = client.get_run_messages(run_id=run.id, role=MessageRole.tool) - assert len(tool_messages) > 0 - - # Get and verify usage statistics - usage = client.get_run_usage(run_id=run.id)[0] - assert usage.completion_tokens >= 0 - assert usage.prompt_tokens >= 0 - assert usage.total_tokens >= 0 - assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens - - # ---------------------------------------------------------------------------------------------------- # Agent listing # ---------------------------------------------------------------------------------------------------- -def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two): +def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two): """Test listing agents with pagination and query text filtering.""" # Test query text filtering - search_results = client.list_agents(query_text="search agent") + search_results = client.agents.list(query_text="search agent") assert len(search_results) == 2 search_agent_ids = {agent.id for agent in search_results} assert search_agent_one.id in search_agent_ids assert search_agent_two.id in search_agent_ids assert agent.id not in search_agent_ids - different_results = client.list_agents(query_text="client") + different_results = client.agents.list(query_text="client") assert len(different_results) == 1 assert different_results[0].id == agent.id # Test pagination - first_page = client.list_agents(query_text="search agent", limit=1) + first_page = client.agents.list(query_text="search agent", limit=1) assert len(first_page) == 1 first_agent = first_page[0] - second_page = client.list_agents(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor + second_page = client.agents.list(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor assert len(second_page) == 1 assert second_page[0].id != first_agent.id @@ -664,20 +645,16 @@ def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_age assert all_ids == {search_agent_one.id, search_agent_two.id} # Test listing without any filters - all_agents = client.list_agents() + all_agents = client.agents.list() assert len(all_agents) == 3 assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent]) -def test_agent_creation(client: Union[LocalClient, RESTClient]): +def test_agent_creation(client: Letta): """Test that block IDs are properly attached when creating an agent.""" - if not isinstance(client, RESTClient): - pytest.skip("This test only runs when the server is enabled") - - from letta import BasicBlockMemory # Create a test block that will represent user preferences - user_preferences_block = client.create_block(label="user_preferences", value="", limit=10000) + user_preferences_block = client.blocks.create(label="user_preferences", value="", limit=10000) # Create test tools def test_tool(): @@ -688,73 +665,82 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]): """Another test tool.""" return "Hello from another test tool!" - tool1 = client.create_or_update_tool(func=test_tool, tags=["test"]) - tool2 = client.create_or_update_tool(func=another_test_tool, tags=["test"]) - - # Create test blocks - offline_persona_block = client.create_block(label="persona", value="persona description", limit=5000) - mindy_block = client.create_block(label="mindy", value="Mindy is a helpful assistant", limit=5000) - memory_blocks = BasicBlockMemory(blocks=[offline_persona_block, mindy_block]) + tool1 = client.tools.upsert_from_function(func=test_tool, tags=["test"]) + tool2 = client.tools.upsert_from_function(func=another_test_tool, tags=["test"]) # Create agent with the blocks and tools - agent = client.create_agent( - name=f"test_agent_{str(uuid.uuid4())}", - memory=memory_blocks, - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + agent = client.agents.create( + memory_blocks=[ + { + "label": "human", + "value": "you are a human", + }, + {"label": "persona", "value": "you are an assistant"}, + ], + model="letta/letta-free", + embedding="letta/letta-free", tool_ids=[tool1.id, tool2.id], include_base_tools=False, tags=["test"], block_ids=[user_preferences_block.id], ) + memory_blocks = agent.memory.blocks # Verify the agent was created successfully assert agent is not None assert agent.id is not None # Verify the blocks are properly attached - agent_blocks = client.list_agent_memory_blocks(agent.id) + agent_blocks = client.agents.blocks.list(agent_id=agent.id) agent_block_ids = {block.id for block in agent_blocks} # Check that all memory blocks are present - memory_block_ids = {block.id for block in memory_blocks.blocks} - for block_id in memory_block_ids | {user_preferences_block.id}: - assert block_id in agent_block_ids + memory_block_ids = {block.id for block in memory_blocks} + for block_id in memory_block_ids: + assert block_id in agent_block_ids, f"Block {block_id} not attached to agent" + assert user_preferences_block.id in agent_block_ids, f"User preferences block {user_preferences_block.id} not attached to agent" # Verify the tools are properly attached - agent_tools = client.get_tools_from_agent(agent.id) + agent_tools = client.agents.tools.list(agent_id=agent.id) assert len(agent_tools) == 2 tool_ids = {tool1.id, tool2.id} assert all(tool.id in tool_ids for tool in agent_tools) - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) # -------------------------------------------------------------------------------------------------------------------- # Agent sources # -------------------------------------------------------------------------------------------------------------------- -def test_attach_detach_agent_source(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_attach_detach_agent_source(client: Letta, agent: AgentState): """Test that we can attach and detach a source from an agent""" # Create a source - source = client.create_source( + source = client.sources.create( name="test_source", + embedding_config={ # TODO: change this + "embedding_endpoint": "https://embeddings.memgpt.ai", + "embedding_model": "BAAI/bge-large-en-v1.5", + "embedding_dim": 1024, + "embedding_chunk_size": 300, + "embedding_endpoint_type": "hugging-face", + }, ) - initial_sources = client.list_attached_sources(agent_id=agent.id) + initial_sources = client.agents.sources.list(agent_id=agent.id) assert source.id not in [s.id for s in initial_sources] # Attach source - client.attach_source(agent_id=agent.id, source_id=source.id) + client.agents.sources.attach(agent_id=agent.id, source_id=source.id) # Verify source is attached - final_sources = client.list_attached_sources(agent_id=agent.id) + final_sources = client.agents.sources.list(agent_id=agent.id) assert source.id in [s.id for s in final_sources] # Detach source - client.detach_source(agent_id=agent.id, source_id=source.id) + client.agents.sources.detach(agent_id=agent.id, source_id=source.id) # Verify source is detached - final_sources = client.list_attached_sources(agent_id=agent.id) + final_sources = client.agents.sources.list(agent_id=agent.id) assert source.id not in [s.id for s in final_sources] - client.delete_source(source.id) + client.sources.delete(source.id) diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 00000000..635677f0 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,115 @@ +import os +import threading +import time + +import pytest +from dotenv import load_dotenv +from letta_client import AgentState, Letta, LlmConfig, MessageCreate +from letta_client.core.api_error import ApiError +from pytest import fixture + + +def run_server(): + load_dotenv() + + from letta.server.rest_api.app import start_server + + print("Starting server...") + start_server(debug=True) + + +@pytest.fixture( + scope="module", +) +def client(request): + # Get URL from environment or start server + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") + if not os.getenv("LETTA_SERVER_URL"): + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) + + # create the Letta client + yield Letta(base_url=server_url, token=None) + + +# Fixture for test agent +@pytest.fixture(scope="module") +def agent(client: Letta): + agent_state = client.agents.create( + name="test_client", + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + ) + + yield agent_state + + # delete agent + client.agents.delete(agent_state.id) + + +@pytest.mark.parametrize( + "stream_tokens,model", + [ + (True, "openai/gpt-4o-mini"), + (True, "anthropic/claude-3-sonnet-20240229"), + (False, "openai/gpt-4o-mini"), + (False, "anthropic/claude-3-sonnet-20240229"), + ], +) +def test_streaming_send_message( + mock_e2b_api_key_none, + client: Letta, + agent: AgentState, + stream_tokens: bool, + model: str, +): + # Update agent's model + config = client.agents.retrieve(agent_id=agent.id).llm_config + config_dump = config.model_dump() + config_dump["model"] = model + config = LlmConfig(**config_dump) + client.agents.modify(agent_id=agent.id, llm_config=config) + + # Send streaming message + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")], + stream_tokens=stream_tokens, + ) + + # Tracking variables for test validation + inner_thoughts_exist = False + inner_thoughts_count = 0 + send_message_ran = False + done = False + + assert response, "Sending message failed" + for chunk in response: + # Check chunk type and content based on the current client API + if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message": + inner_thoughts_exist = True + inner_thoughts_count += 1 + + if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message": + send_message_ran = True + if chunk.message_type == "assistant_message": + send_message_ran = True + + if chunk.message_type == "usage_statistics": + # Validate usage statistics + assert chunk.step_count == 1 + assert chunk.completion_tokens > 10 + assert chunk.prompt_tokens > 1000 + assert chunk.total_tokens > 1000 + done = True + print(chunk) + + # If stream tokens, we expect at least one inner thought + assert inner_thoughts_count >= 1, "Expected more than one inner thought" + assert inner_thoughts_exist, "No inner thoughts found" + assert send_message_ran, "send_message function call not found" + assert done, "Message stream not done"