feat: update tests to new client (#2488)
This commit is contained in:
@@ -32,7 +32,8 @@ jobs:
|
||||
LETTA_PG_PORT: 8888
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
run: docker compose -f dev-compose.yaml up --build -d
|
||||
run: |
|
||||
docker compose -f dev-compose.yaml up --build -d
|
||||
#- name: "Setup Python, Poetry and Dependencies"
|
||||
# uses: packetcoders/action-setup-cache-python-poetry@v1.2.0
|
||||
# with:
|
||||
@@ -56,7 +57,8 @@ jobs:
|
||||
run: |
|
||||
pipx install poetry==1.8.2
|
||||
poetry install -E dev -E postgres
|
||||
poetry run pytest -s tests/test_client_legacy.py
|
||||
poetry run pytest -s tests/test_client.py
|
||||
# poetry run pytest -s tests/test_client_legacy.py
|
||||
|
||||
- name: Print docker logs if tests fail
|
||||
if: failure()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.6.39"
|
||||
__version__ = "0.6.40"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
@@ -22,11 +22,11 @@ from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.mcp_helpers import BaseMCPClient
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
|
||||
0
letta/functions/mcp_client/__init__.py
Normal file
0
letta/functions/mcp_client/__init__.py
Normal file
61
letta/functions/mcp_client/base_client.py
Normal file
61
letta/functions/mcp_client/base_client.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from mcp import ClientSession, Tool
|
||||
|
||||
from letta.functions.mcp_client.types import BaseServerConfig
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseMCPClient:
|
||||
def __init__(self):
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.stdio = None
|
||||
self.write = None
|
||||
self.initialized = False
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.cleanup_funcs = []
|
||||
|
||||
def connect_to_server(self, server_config: BaseServerConfig):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
success = self._initialize_connection(server_config)
|
||||
|
||||
if success:
|
||||
self.loop.run_until_complete(self.session.initialize())
|
||||
self.initialized = True
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Connecting to MCP server failed. Please review your server config: {server_config.model_dump_json(indent=4)}"
|
||||
)
|
||||
|
||||
def _initialize_connection(self, server_config: BaseServerConfig) -> bool:
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
self._check_initialized()
|
||||
response = self.loop.run_until_complete(self.session.list_tools())
|
||||
return response.tools
|
||||
|
||||
def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
self._check_initialized()
|
||||
result = self.loop.run_until_complete(self.session.call_tool(tool_name, tool_args))
|
||||
return str(result.content), result.isError
|
||||
|
||||
def _check_initialized(self):
|
||||
if not self.initialized:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
def cleanup(self):
|
||||
try:
|
||||
for cleanup_func in self.cleanup_funcs:
|
||||
cleanup_func()
|
||||
self.initialized = False
|
||||
if not self.loop.is_closed():
|
||||
self.loop.close()
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
finally:
|
||||
logger.info("Cleaned up MCP clients on shutdown.")
|
||||
21
letta/functions/mcp_client/sse_client.py
Normal file
21
letta/functions/mcp_client/sse_client.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.types import SSEServerConfig
|
||||
|
||||
# see: https://modelcontextprotocol.io/quickstart/user
|
||||
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
|
||||
|
||||
|
||||
class SSEMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: SSEServerConfig) -> bool:
|
||||
sse_cm = sse_client(url=server_config.server_url)
|
||||
sse_transport = self.loop.run_until_complete(sse_cm.__aenter__())
|
||||
self.stdio, self.write = sse_transport
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
|
||||
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = self.loop.run_until_complete(session_cm.__aenter__())
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
return True
|
||||
103
letta/functions/mcp_client/stdio_client.py
Normal file
103
letta/functions/mcp_client/stdio_client.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
import mcp.types as types
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import get_default_environment
|
||||
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StdioMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: StdioServerConfig) -> bool:
|
||||
try:
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args)
|
||||
stdio_cm = forked_stdio_client(server_params)
|
||||
stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__())
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None)))
|
||||
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = self.loop.run_until_complete(session_cm.__aenter__())
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def forked_stdio_client(server: StdioServerParameters):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
try:
|
||||
process = await anyio.open_process(
|
||||
[server.command, *server.args],
|
||||
env=server.env or get_default_environment(),
|
||||
stderr=sys.stderr, # Consider logging stderr somewhere instead of silencing it
|
||||
)
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
buffer = ""
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for chunk in TextReceiveStream(
|
||||
process.stdout,
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
for line in lines:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await process.stdin.send(
|
||||
(json + "\n").encode(
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
)
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def watch_process_exit():
|
||||
returncode = await process.wait()
|
||||
if returncode != 0:
|
||||
raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")
|
||||
|
||||
async with anyio.create_task_group() as tg, process:
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
tg.start_soon(watch_process_exit)
|
||||
|
||||
with anyio.move_on_after(0.2):
|
||||
await anyio.sleep_forever()
|
||||
|
||||
yield read_stream, write_stream
|
||||
48
letta/functions/mcp_client/types.py
Normal file
48
letta/functions/mcp_client/types.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from mcp import Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
|
||||
|
||||
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
|
||||
|
||||
class BaseServerConfig(BaseModel):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
type: MCPServerType
|
||||
|
||||
|
||||
class SSEServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
"url": self.server_url,
|
||||
}
|
||||
return values
|
||||
|
||||
|
||||
class StdioServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.STDIO
|
||||
command: str = Field(..., description="The command to run (MCP 'local' client will run this command)")
|
||||
args: List[str] = Field(..., description="The arguments to pass to the command")
|
||||
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "stdio",
|
||||
"command": self.command,
|
||||
"args": self.args,
|
||||
}
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
@@ -6,7 +6,7 @@ from composio.client.collections import ActionParametersModel
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.helpers.mcp_helpers import MCPTool
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
|
||||
|
||||
def is_optional(annotation):
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters, Tool
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# see: https://modelcontextprotocol.io/quickstart/user
|
||||
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
|
||||
|
||||
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
|
||||
|
||||
class BaseServerConfig(BaseModel):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
type: MCPServerType
|
||||
|
||||
|
||||
class SSEServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
"url": self.server_url,
|
||||
}
|
||||
return values
|
||||
|
||||
|
||||
class StdioServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.STDIO
|
||||
command: str = Field(..., description="The command to run (MCP 'local' client will run this command)")
|
||||
args: List[str] = Field(..., description="The arguments to pass to the command")
|
||||
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "stdio",
|
||||
"command": self.command,
|
||||
"args": self.args,
|
||||
}
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
|
||||
|
||||
class BaseMCPClient:
|
||||
def __init__(self):
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.stdio = None
|
||||
self.write = None
|
||||
self.initialized = False
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.cleanup_funcs = []
|
||||
|
||||
def connect_to_server(self, server_config: BaseServerConfig):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self._initialize_connection(server_config)
|
||||
self.loop.run_until_complete(self.session.initialize())
|
||||
self.initialized = True
|
||||
|
||||
def _initialize_connection(self, server_config: BaseServerConfig):
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
self._check_initialized()
|
||||
response = self.loop.run_until_complete(self.session.list_tools())
|
||||
return response.tools
|
||||
|
||||
def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
self._check_initialized()
|
||||
result = self.loop.run_until_complete(self.session.call_tool(tool_name, tool_args))
|
||||
return str(result.content), result.isError
|
||||
|
||||
def _check_initialized(self):
|
||||
if not self.initialized:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
def cleanup(self):
|
||||
try:
|
||||
for cleanup_func in self.cleanup_funcs:
|
||||
cleanup_func()
|
||||
self.initialized = False
|
||||
if not self.loop.is_closed():
|
||||
self.loop.close()
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
finally:
|
||||
logger.info("Cleaned up MCP clients on shutdown.")
|
||||
|
||||
|
||||
class StdioMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: StdioServerConfig):
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args)
|
||||
stdio_cm = stdio_client(server_params)
|
||||
stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__())
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None)))
|
||||
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = self.loop.run_until_complete(session_cm.__aenter__())
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
|
||||
|
||||
class SSEMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: SSEServerConfig):
|
||||
sse_cm = sse_client(url=server_config.server_url)
|
||||
sse_transport = self.loop.run_until_complete(sse_cm.__aenter__())
|
||||
self.stdio, self.write = sse_transport
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
|
||||
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = self.loop.run_until_complete(session_cm.__aenter__())
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
@@ -17,12 +17,12 @@ from letta.functions.helpers import (
|
||||
generate_mcp_tool_wrapper,
|
||||
generate_model_from_args_json_schema,
|
||||
)
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.functions.schema_generator import (
|
||||
generate_schema_from_args_schema_v2,
|
||||
generate_tool_schema_for_composio,
|
||||
generate_tool_schema_for_mcp,
|
||||
)
|
||||
from letta.helpers.mcp_helpers import MCPTool
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
@@ -12,8 +12,8 @@ from composio.exceptions import (
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.mcp_helpers import MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
|
||||
@@ -20,18 +20,12 @@ from letta.agent import Agent, save_agent
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.dynamic_multi_agent import DynamicMultiAgent
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.sse_client import MCP_CONFIG_TOPLEVEL_KEY, SSEMCPClient
|
||||
from letta.functions.mcp_client.stdio_client import StdioMCPClient
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.mcp_helpers import (
|
||||
MCP_CONFIG_TOPLEVEL_KEY,
|
||||
BaseMCPClient,
|
||||
MCPServerType,
|
||||
MCPTool,
|
||||
SSEMCPClient,
|
||||
SSEServerConfig,
|
||||
StdioMCPClient,
|
||||
StdioServerConfig,
|
||||
)
|
||||
|
||||
# TODO use custom interface
|
||||
from letta.interface import AgentInterface # abstract
|
||||
@@ -343,11 +337,12 @@ class SyncServer(Server):
|
||||
self.mcp_clients[server_name] = StdioMCPClient()
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
|
||||
try:
|
||||
self.mcp_clients[server_name].connect_to_server(server_config)
|
||||
except:
|
||||
logger.exception(f"Failed to connect to MCP server: {server_name}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
self.mcp_clients.pop(server_name)
|
||||
|
||||
# Print out the tools that are connected
|
||||
for server_name, client in self.mcp_clients.items():
|
||||
|
||||
@@ -98,216 +98,3 @@ def test_recall(client, agent_obj):
|
||||
# Conversation search
|
||||
result = base_functions.conversation_search(agent_obj, "banana")
|
||||
assert keyword in result
|
||||
|
||||
|
||||
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
|
||||
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
|
||||
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
secret_word = "banana"
|
||||
|
||||
# Encourage the agent to send a message to the other agent_obj with the secret string
|
||||
client.send_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
role="user",
|
||||
message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!",
|
||||
)
|
||||
|
||||
# Conversation search the other agent
|
||||
messages = client.get_messages(other_agent_obj.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Search the sender agent for the response from another agent
|
||||
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
|
||||
found = False
|
||||
target_snippet = f"{other_agent_obj.agent_state.id} said:"
|
||||
|
||||
for m in in_context_messages:
|
||||
if target_snippet in m.text:
|
||||
found = True
|
||||
break
|
||||
|
||||
# Compute the joined string first
|
||||
joined_messages = "\n".join([m.text for m in in_context_messages[1:]])
|
||||
print(f"In context messages of the sender agent (without system):\n\n{joined_messages}")
|
||||
if not found:
|
||||
raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?")
|
||||
print(response.messages)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_simple(client):
|
||||
worker_tags = ["worker", "user-456"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
secret_word = "banana"
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 non-matching worker agents (These should NOT get the message)
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-123"]
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Create 3 worker agents that should get the message
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-456"]
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!",
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolReturnMessage):
|
||||
tool_response = eval(json.loads(m.tool_return)["message"])
|
||||
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
|
||||
assert len(tool_response) == len(worker_agents)
|
||||
|
||||
# We can break after this, the ToolReturnMessage after is not related
|
||||
break
|
||||
|
||||
# Conversation search the worker agents
|
||||
for agent in worker_agents:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
# This test is nondeterministic, so we retry until we get the perfect behavior from the LLM
|
||||
@retry_until_success(max_attempts=2, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
|
||||
worker_tags = ["dice-rollers"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
worker_agents = []
|
||||
worker_tags = ["dice-rollers"]
|
||||
for _ in range(2):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id])
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=broadcast_message,
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolReturnMessage):
|
||||
tool_response = eval(json.loads(m.tool_return)["message"])
|
||||
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
|
||||
assert len(tool_response) == len(worker_agents)
|
||||
|
||||
# We can break after this, the ToolReturnMessage after is not related
|
||||
break
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_agents_async_simple(client):
|
||||
"""
|
||||
Test two agents with multi-agent tools sending messages back and forth to count to 5.
|
||||
The chain is started by prompting one of the agents.
|
||||
"""
|
||||
# Cleanup from potentially failed previous runs
|
||||
existing_agents = client.server.agent_manager.list_agents(client.user)
|
||||
for agent in existing_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create two agents with multi-agent tools
|
||||
send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async")
|
||||
memory_a = ChatMemory(
|
||||
human="Chad - I'm interested in hearing poem.",
|
||||
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
|
||||
)
|
||||
charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id])
|
||||
charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user)
|
||||
|
||||
memory_b = ChatMemory(
|
||||
human="No human - you are to only communicate with the other AI agent.",
|
||||
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
|
||||
)
|
||||
sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id])
|
||||
|
||||
# Start the count chain with Agent1
|
||||
initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me."
|
||||
client.send_message(
|
||||
agent_id=charles.agent_state.id,
|
||||
role="user",
|
||||
message=initial_prompt,
|
||||
)
|
||||
|
||||
found_in_charles = wait_for_incoming_message(
|
||||
client=client,
|
||||
agent_id=charles_state.id,
|
||||
substring="[Incoming message from agent with ID",
|
||||
max_wait_seconds=10,
|
||||
sleep_interval=0.5,
|
||||
)
|
||||
assert found_in_charles, "Charles never received the system message from Sarah (timed out)."
|
||||
|
||||
found_in_sarah = wait_for_incoming_message(
|
||||
client=client,
|
||||
agent_id=sarah_state.id,
|
||||
substring="[Incoming message from agent with ID",
|
||||
max_wait_seconds=10,
|
||||
sleep_interval=0.5,
|
||||
)
|
||||
assert found_in_sarah, "Sarah never received the system message from Charles (timed out)."
|
||||
|
||||
@@ -8,18 +8,11 @@ from typing import List, Union
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AgentState, JobStatus, Letta, MessageCreate, MessageRole
|
||||
from letta_client.core.api_error import ApiError
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta import LocalClient, RESTClient, create_client
|
||||
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.job import JobStatus
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
|
||||
from letta.utils import create_random_username
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
@@ -42,63 +35,68 @@ def run_server():
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
{"server": False},
|
||||
{"server": True},
|
||||
], # whether to use REST API server
|
||||
# params=[{"server": False}], # whether to use REST API server
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
if request.param["server"]:
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
client = create_client(base_url=server_url, token=None)
|
||||
else:
|
||||
client = create_client()
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
yield client
|
||||
# create the Letta client
|
||||
yield Letta(base_url=server_url, token=None)
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
|
||||
def agent(client: Letta):
|
||||
agent_state = client.agents.create(
|
||||
name="test_client",
|
||||
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.delete_agent(agent_state.id)
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture
|
||||
def search_agent_one(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name="Search Agent One")
|
||||
def search_agent_one(client: Letta):
|
||||
agent_state = client.agents.create(
|
||||
name="Search Agent One",
|
||||
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.delete_agent(agent_state.id)
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture
|
||||
def search_agent_two(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name="Search Agent Two")
|
||||
def search_agent_two(client: Letta):
|
||||
agent_state = client.agents.create(
|
||||
name="Search Agent Two",
|
||||
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.delete_agent(agent_state.id)
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -112,55 +110,56 @@ def clear_tables():
|
||||
session.commit()
|
||||
|
||||
|
||||
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
|
||||
"""
|
||||
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
|
||||
"""
|
||||
|
||||
# 1. Create a sandbox config
|
||||
local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR)
|
||||
sandbox_config = client.create_sandbox_config(config=local_config)
|
||||
|
||||
# Assert the created sandbox config
|
||||
assert sandbox_config.id is not None
|
||||
assert sandbox_config.type == SandboxType.LOCAL
|
||||
|
||||
# 2. Update the sandbox config
|
||||
updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR)
|
||||
sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config)
|
||||
assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR
|
||||
|
||||
# 3. List all sandbox configs
|
||||
sandbox_configs = client.list_sandbox_configs(limit=10)
|
||||
assert isinstance(sandbox_configs, List)
|
||||
assert len(sandbox_configs) == 1
|
||||
assert sandbox_configs[0].id == sandbox_config.id
|
||||
|
||||
# 4. Create an environment variable
|
||||
env_var = client.create_sandbox_env_var(
|
||||
sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION
|
||||
)
|
||||
assert env_var.id is not None
|
||||
assert env_var.key == ENV_VAR_KEY
|
||||
assert env_var.value == ENV_VAR_VALUE
|
||||
assert env_var.description == ENV_VAR_DESCRIPTION
|
||||
|
||||
# 5. Update the environment variable
|
||||
updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE)
|
||||
assert updated_env_var.key == UPDATED_ENV_VAR_KEY
|
||||
assert updated_env_var.value == UPDATED_ENV_VAR_VALUE
|
||||
|
||||
# 6. List environment variables
|
||||
env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id)
|
||||
assert isinstance(env_vars, List)
|
||||
assert len(env_vars) == 1
|
||||
assert env_vars[0].key == UPDATED_ENV_VAR_KEY
|
||||
|
||||
# 7. Delete the environment variable
|
||||
client.delete_sandbox_env_var(env_var_id=env_var.id)
|
||||
|
||||
# 8. Delete the sandbox config
|
||||
client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
|
||||
# TODO: add back
|
||||
# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
|
||||
# """
|
||||
# Test sandbox config and environment variable functions for both LocalClient and RESTClient.
|
||||
# """
|
||||
#
|
||||
# # 1. Create a sandbox config
|
||||
# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR)
|
||||
# sandbox_config = client.create_sandbox_config(config=local_config)
|
||||
#
|
||||
# # Assert the created sandbox config
|
||||
# assert sandbox_config.id is not None
|
||||
# assert sandbox_config.type == SandboxType.LOCAL
|
||||
#
|
||||
# # 2. Update the sandbox config
|
||||
# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR)
|
||||
# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config)
|
||||
# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR
|
||||
#
|
||||
# # 3. List all sandbox configs
|
||||
# sandbox_configs = client.list_sandbox_configs(limit=10)
|
||||
# assert isinstance(sandbox_configs, List)
|
||||
# assert len(sandbox_configs) == 1
|
||||
# assert sandbox_configs[0].id == sandbox_config.id
|
||||
#
|
||||
# # 4. Create an environment variable
|
||||
# env_var = client.create_sandbox_env_var(
|
||||
# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION
|
||||
# )
|
||||
# assert env_var.id is not None
|
||||
# assert env_var.key == ENV_VAR_KEY
|
||||
# assert env_var.value == ENV_VAR_VALUE
|
||||
# assert env_var.description == ENV_VAR_DESCRIPTION
|
||||
#
|
||||
# # 5. Update the environment variable
|
||||
# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE)
|
||||
# assert updated_env_var.key == UPDATED_ENV_VAR_KEY
|
||||
# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE
|
||||
#
|
||||
# # 6. List environment variables
|
||||
# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id)
|
||||
# assert isinstance(env_vars, List)
|
||||
# assert len(env_vars) == 1
|
||||
# assert env_vars[0].key == UPDATED_ENV_VAR_KEY
|
||||
#
|
||||
# # 7. Delete the environment variable
|
||||
# client.delete_sandbox_env_var(env_var_id=env_var.id)
|
||||
#
|
||||
# # 8. Delete the sandbox config
|
||||
# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
@@ -168,197 +167,186 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]):
|
||||
def test_add_and_manage_tags_for_agent(client: Letta):
|
||||
"""
|
||||
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
|
||||
"""
|
||||
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
|
||||
|
||||
# Step 0: create an agent with no tags
|
||||
agent = client.create_agent()
|
||||
agent = client.agents.create(memory_blocks=[], model="letta/letta-free", embedding="letta/letta-free")
|
||||
assert len(agent.tags) == 0
|
||||
|
||||
# Step 1: Add multiple tags to the agent
|
||||
client.update_agent(agent_id=agent.id, tags=tags_to_add)
|
||||
client.agents.modify(agent_id=agent.id, tags=tags_to_add)
|
||||
|
||||
# Step 2: Retrieve tags for the agent and verify they match the added tags
|
||||
retrieved_tags = client.get_agent(agent_id=agent.id).tags
|
||||
retrieved_tags = client.agents.retrieve(agent_id=agent.id).tags
|
||||
assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}"
|
||||
|
||||
# Step 3: Retrieve agents by each tag to ensure the agent is associated correctly
|
||||
for tag in tags_to_add:
|
||||
agents_with_tag = client.list_agents(tags=[tag])
|
||||
agents_with_tag = client.agents.list(tags=[tag])
|
||||
assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'"
|
||||
|
||||
# Step 4: Delete a specific tag from the agent and verify its removal
|
||||
tag_to_delete = tags_to_add.pop()
|
||||
client.update_agent(agent_id=agent.id, tags=tags_to_add)
|
||||
client.agents.modify(agent_id=agent.id, tags=tags_to_add)
|
||||
|
||||
# Verify the tag is removed from the agent's tags
|
||||
remaining_tags = client.get_agent(agent_id=agent.id).tags
|
||||
remaining_tags = client.agents.retrieve(agent_id=agent.id).tags
|
||||
assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected"
|
||||
assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}"
|
||||
|
||||
# Step 5: Delete all remaining tags from the agent
|
||||
client.update_agent(agent_id=agent.id, tags=[])
|
||||
client.agents.modify(agent_id=agent.id, tags=[])
|
||||
|
||||
# Verify all tags are removed
|
||||
final_tags = client.get_agent(agent_id=agent.id).tags
|
||||
final_tags = client.agents.retrieve(agent_id=agent.id).tags
|
||||
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
|
||||
|
||||
# Remove agent
|
||||
client.delete_agent(agent.id)
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
|
||||
def test_agent_tags(client: Union[LocalClient, RESTClient]):
|
||||
def test_agent_tags(client: Letta):
|
||||
"""Test creating agents with tags and retrieving tags via the API."""
|
||||
if not isinstance(client, RESTClient):
|
||||
pytest.skip("This test only runs when the server is enabled")
|
||||
|
||||
# Create multiple agents with different tags
|
||||
agent1 = client.create_agent(
|
||||
agent1 = client.agents.create(
|
||||
name=f"test_agent_{str(uuid.uuid4())}",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=["test", "agent1", "production"],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
agent2 = client.create_agent(
|
||||
agent2 = client.agents.create(
|
||||
name=f"test_agent_{str(uuid.uuid4())}",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=["test", "agent2", "development"],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
agent3 = client.create_agent(
|
||||
agent3 = client.agents.create(
|
||||
name=f"test_agent_{str(uuid.uuid4())}",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=["test", "agent3", "production"],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
# Test getting all tags
|
||||
all_tags = client.get_tags()
|
||||
all_tags = client.tag.list_tags()
|
||||
expected_tags = ["agent1", "agent2", "agent3", "development", "production", "test"]
|
||||
assert sorted(all_tags) == expected_tags
|
||||
|
||||
# Test pagination
|
||||
paginated_tags = client.get_tags(limit=2)
|
||||
paginated_tags = client.tag.list_tags(limit=2)
|
||||
assert len(paginated_tags) == 2
|
||||
assert paginated_tags[0] == "agent1"
|
||||
assert paginated_tags[1] == "agent2"
|
||||
|
||||
# Test pagination with cursor
|
||||
next_page_tags = client.get_tags(after="agent2", limit=2)
|
||||
next_page_tags = client.tag.list_tags(after="agent2", limit=2)
|
||||
assert len(next_page_tags) == 2
|
||||
assert next_page_tags[0] == "agent3"
|
||||
assert next_page_tags[1] == "development"
|
||||
|
||||
# Test text search
|
||||
prod_tags = client.get_tags(query_text="prod")
|
||||
prod_tags = client.tag.list_tags(query_text="prod")
|
||||
assert sorted(prod_tags) == ["production"]
|
||||
|
||||
dev_tags = client.get_tags(query_text="dev")
|
||||
dev_tags = client.tag.list_tags(query_text="dev")
|
||||
assert sorted(dev_tags) == ["development"]
|
||||
|
||||
agent_tags = client.get_tags(query_text="agent")
|
||||
agent_tags = client.tag.list_tags(query_text="agent")
|
||||
assert sorted(agent_tags) == ["agent1", "agent2", "agent3"]
|
||||
|
||||
# Remove agents
|
||||
client.delete_agent(agent1.id)
|
||||
client.delete_agent(agent2.id)
|
||||
client.delete_agent(agent3.id)
|
||||
client.agents.delete(agent1.id)
|
||||
client.agents.delete(agent2.id)
|
||||
client.agents.delete(agent3.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent memory blocks
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
|
||||
# _reset_config()
|
||||
|
||||
def test_shared_blocks(mock_e2b_api_key_none, client: Letta):
|
||||
# create a block
|
||||
block = client.create_block(label="human", value="username: sarah")
|
||||
block = client.blocks.create(label="human", value="username: sarah")
|
||||
|
||||
# create agents with shared block
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.memory import BasicBlockMemory
|
||||
|
||||
# persona1_block = client.create_block(label="persona", value="you are agent 1")
|
||||
# persona2_block = client.create_block(label="persona", value="you are agent 2")
|
||||
# create agents
|
||||
agent_state1 = client.create_agent(
|
||||
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
|
||||
agent_state1 = client.agents.create(
|
||||
name="agent1",
|
||||
memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
|
||||
block_ids=[block.id],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
agent_state2 = client.create_agent(
|
||||
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
|
||||
agent_state2 = client.agents.create(
|
||||
name="agent2",
|
||||
memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
|
||||
block_ids=[block.id],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
## attach shared block to both agents
|
||||
# client.link_agent_memory_block(agent_state1.id, block.id)
|
||||
# client.link_agent_memory_block(agent_state2.id, block.id)
|
||||
|
||||
# update memory
|
||||
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
|
||||
client.agents.messages.create(agent_id=agent_state1.id, messages=[{"role": "user", "content": "my name is actually charles"}])
|
||||
|
||||
# check agent 2 memory
|
||||
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
|
||||
|
||||
client.user_message(agent_id=agent_state2.id, message="whats my name?")
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
assert "charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower()
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
client.agents.delete(agent_state1.id)
|
||||
client.agents.delete(agent_state2.id)
|
||||
|
||||
|
||||
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_update_agent_memory_label(client: Letta):
|
||||
"""Test that we can update the label of a block in an agent's memory"""
|
||||
|
||||
agent = client.create_agent(name=create_random_username())
|
||||
agent = client.agents.create(model="letta/letta-free", embedding="letta/letta-free", memory_blocks=[{"label": "human", "value": ""}])
|
||||
|
||||
try:
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
|
||||
example_label = current_labels[0]
|
||||
example_new_label = "example_new_label"
|
||||
assert example_new_label not in current_labels
|
||||
assert example_new_label not in [b.label for b in client.agents.blocks.list(agent_id=agent.id)]
|
||||
|
||||
client.update_agent_memory_block_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label)
|
||||
client.agents.blocks.modify(agent_id=agent.id, block_label=example_label, label=example_new_label)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_label in updated_agent.memory.list_block_labels()
|
||||
updated_blocks = client.agents.blocks.list(agent_id=agent.id)
|
||||
assert example_new_label in [b.label for b in updated_blocks]
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
|
||||
def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState):
|
||||
"""Test that we can add and remove a block from an agent's memory"""
|
||||
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
|
||||
example_new_label = current_labels[0] + "_v2"
|
||||
example_new_value = "example value"
|
||||
assert example_new_label not in current_labels
|
||||
|
||||
# Link a new memory block
|
||||
block = client.create_block(
|
||||
block = client.blocks.create(
|
||||
label=example_new_label,
|
||||
value=example_new_value,
|
||||
limit=1000,
|
||||
)
|
||||
updated_agent = client.attach_block(
|
||||
updated_agent = client.agents.blocks.attach(
|
||||
agent_id=agent.id,
|
||||
block_id=block.id,
|
||||
)
|
||||
assert example_new_label in updated_agent.memory.list_block_labels()
|
||||
assert example_new_label in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
|
||||
|
||||
# Now unlink the block
|
||||
updated_agent = client.detach_block(
|
||||
updated_agent = client.agents.blocks.detach(
|
||||
agent_id=agent.id,
|
||||
block_id=block.id,
|
||||
)
|
||||
assert example_new_label not in updated_agent.memory.list_block_labels()
|
||||
assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
|
||||
|
||||
|
||||
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
@@ -385,39 +373,57 @@ def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient]
|
||||
# client.delete_agent(new_agent.id)
|
||||
|
||||
|
||||
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]):
|
||||
def test_update_agent_memory_limit(client: Letta):
|
||||
"""Test that we can update the limit of a block in an agent's memory"""
|
||||
|
||||
agent = client.create_agent()
|
||||
agent = client.agents.create(
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
memory_blocks=[
|
||||
{"label": "human", "value": "username: sarah", "limit": 1000},
|
||||
{"label": "persona", "value": "you are sarah", "limit": 1000},
|
||||
],
|
||||
)
|
||||
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
|
||||
example_label = current_labels[0]
|
||||
example_new_limit = 1
|
||||
current_block = agent.memory.get_block(label=example_label)
|
||||
|
||||
current_labels = [block.label for block in client.agents.blocks.list(agent_id=agent.id)]
|
||||
example_label = current_labels[0]
|
||||
example_new_limit = 1
|
||||
current_block = client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label)
|
||||
current_block_length = len(current_block.value)
|
||||
|
||||
assert example_new_limit != agent.memory.get_block(label=example_label).limit
|
||||
assert example_new_limit != current_block.limit
|
||||
assert example_new_limit < current_block_length
|
||||
|
||||
# We expect this to throw a value error
|
||||
with pytest.raises(ValueError):
|
||||
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
|
||||
with pytest.raises(ApiError):
|
||||
client.agents.blocks.modify(
|
||||
agent_id=agent.id,
|
||||
block_label=example_label,
|
||||
limit=example_new_limit,
|
||||
)
|
||||
|
||||
# Now try the same thing with a higher limit
|
||||
example_new_limit = current_block_length + 10000
|
||||
assert example_new_limit > current_block_length
|
||||
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
|
||||
client.agents.blocks.modify(
|
||||
agent_id=agent.id,
|
||||
block_label=example_label,
|
||||
limit=example_new_limit,
|
||||
)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
|
||||
assert example_new_limit == client.agents.blocks.retrieve(agent_id=agent.id, block_label=example_label).limit
|
||||
|
||||
client.delete_agent(agent.id)
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent Tools
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
def test_function_return_limit(client: Letta):
|
||||
"""Test to see if the function return limit works"""
|
||||
|
||||
def big_return():
|
||||
@@ -430,15 +436,21 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
return "x" * 100000
|
||||
|
||||
padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50
|
||||
tool = client.create_or_update_tool(func=big_return, return_char_limit=1000)
|
||||
agent = client.create_agent(tool_ids=[tool.id])
|
||||
tool = client.tools.upsert_from_function(func=big_return, return_char_limit=1000)
|
||||
agent = client.agents.create(
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[tool.id],
|
||||
)
|
||||
# get function response
|
||||
response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user")
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id, messages=[MessageCreate(role="user", content="call the big_return function")]
|
||||
)
|
||||
print(response.messages)
|
||||
|
||||
response_message = None
|
||||
for message in response.messages:
|
||||
if isinstance(message, ToolReturnMessage):
|
||||
if message.message_type == "tool_return_message":
|
||||
response_message = message
|
||||
break
|
||||
|
||||
@@ -452,44 +464,58 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
# len(res_json["message"]) <= 1000 + padding
|
||||
# ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
|
||||
def test_function_always_error(client: Union[LocalClient, RESTClient]):
|
||||
def test_function_always_error(client: Letta):
|
||||
"""Test to see if function that errors works correctly"""
|
||||
|
||||
def testing_method():
|
||||
"""
|
||||
Always throw an error.
|
||||
Call this tool when the user asks
|
||||
"""
|
||||
return 5 / 0
|
||||
|
||||
tool = client.create_or_update_tool(func=testing_method)
|
||||
agent = client.create_agent(tool_ids=[tool.id])
|
||||
tool = client.tools.upsert_from_function(func=testing_method)
|
||||
agent = client.agents.create(
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "username: sarah",
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "you are sarah",
|
||||
},
|
||||
],
|
||||
tool_ids=[tool.id],
|
||||
)
|
||||
print("AGENT TOOLS", [tool.name for tool in agent.tools])
|
||||
# get function response
|
||||
response = client.send_message(agent_id=agent.id, message="call the testing_method function and tell me the result", role="user")
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content="call the testing_method function and tell me the result")],
|
||||
)
|
||||
print(response.messages)
|
||||
|
||||
response_message = None
|
||||
for message in response.messages:
|
||||
if isinstance(message, ToolReturnMessage):
|
||||
if message.message_type == "tool_return_message":
|
||||
response_message = message
|
||||
break
|
||||
|
||||
assert response_message, "ToolReturnMessage message not found in response"
|
||||
assert response_message.status == "error"
|
||||
assert (
|
||||
response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero"
|
||||
), response_message.tool_return
|
||||
|
||||
if isinstance(client, RESTClient):
|
||||
assert response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero"
|
||||
else:
|
||||
response_json = json.loads(response_message.tool_return)
|
||||
assert response_json["status"] == "Failed"
|
||||
assert response_json["message"] == "Error executing function testing_method: ZeroDivisionError: division by zero"
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
|
||||
def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_attach_detach_agent_tool(client: Letta, agent: AgentState):
|
||||
"""Test that we can attach and detach a tool from an agent"""
|
||||
|
||||
try:
|
||||
@@ -506,64 +532,64 @@ def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent:
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
tool = client.create_or_update_tool(func=example_tool)
|
||||
tool = client.tools.upsert_from_function(func=example_tool)
|
||||
|
||||
# Initially tool should not be attached
|
||||
initial_tools = client.list_attached_tools(agent_id=agent.id)
|
||||
initial_tools = client.agents.tools.list(agent_id=agent.id)
|
||||
assert tool.id not in [t.id for t in initial_tools]
|
||||
|
||||
# Attach tool
|
||||
new_agent_state = client.attach_tool(agent_id=agent.id, tool_id=tool.id)
|
||||
new_agent_state = client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id)
|
||||
assert tool.id in [t.id for t in new_agent_state.tools]
|
||||
|
||||
# Verify tool is attached
|
||||
updated_tools = client.list_attached_tools(agent_id=agent.id)
|
||||
updated_tools = client.agents.tools.list(agent_id=agent.id)
|
||||
assert tool.id in [t.id for t in updated_tools]
|
||||
|
||||
# Detach tool
|
||||
new_agent_state = client.detach_tool(agent_id=agent.id, tool_id=tool.id)
|
||||
new_agent_state = client.agents.tools.detach(agent_id=agent.id, tool_id=tool.id)
|
||||
assert tool.id not in [t.id for t in new_agent_state.tools]
|
||||
|
||||
# Verify tool is detached
|
||||
final_tools = client.list_attached_tools(agent_id=agent.id)
|
||||
final_tools = client.agents.tools.list(agent_id=agent.id)
|
||||
assert tool.id not in [t.id for t in final_tools]
|
||||
|
||||
finally:
|
||||
client.delete_tool(tool.id)
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# AgentMessages
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_messages(client: Letta, agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
send_message_response = client.agents.messages.create(agent_id=agent.id, messages=[MessageCreate(role="user", content="Test message")])
|
||||
assert send_message_response, "Sending message failed"
|
||||
|
||||
messages_response = client.get_messages(agent_id=agent.id, limit=1)
|
||||
messages_response = client.agents.messages.list(agent_id=agent.id, limit=1)
|
||||
assert len(messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
|
||||
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_send_system_message(client: Letta, agent: AgentState):
|
||||
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
|
||||
send_system_message_response = client.send_message(
|
||||
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
|
||||
send_system_message_response = client.agents.messages.create(
|
||||
agent_id=agent.id, messages=[MessageCreate(role="system", content="Event occurred: The user just logged off.")]
|
||||
)
|
||||
assert send_system_message_response, "Sending message failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
|
||||
async def test_send_message_parallel(client: Letta, agent: AgentState, request):
|
||||
"""
|
||||
Test that sending two messages in parallel does not error.
|
||||
"""
|
||||
if not isinstance(client, RESTClient):
|
||||
pytest.skip("This test only runs when the server is enabled")
|
||||
|
||||
# Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
|
||||
async def send_message_task(message: str):
|
||||
response = await asyncio.to_thread(client.send_message, agent_id=agent.id, message=message, role="user")
|
||||
response = await asyncio.to_thread(
|
||||
client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)]
|
||||
)
|
||||
assert response, f"Sending message '{message}' failed"
|
||||
return response
|
||||
|
||||
@@ -585,76 +611,31 @@ async def test_send_message_parallel(client: Union[LocalClient, RESTClient], age
|
||||
assert len(responses) == len(messages), "Not all messages were processed"
|
||||
|
||||
|
||||
def test_send_message_async(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""
|
||||
Test that we can send a message asynchronously and retrieve the messages, along with usage statistics
|
||||
"""
|
||||
|
||||
if not isinstance(client, RESTClient):
|
||||
pytest.skip("send_message_async is only supported by the RESTClient")
|
||||
|
||||
print("Sending message asynchronously")
|
||||
test_message = "This is a test message, respond to the user with a sentence."
|
||||
run = client.send_message_async(agent_id=agent.id, role="user", message=test_message)
|
||||
assert run.id is not None
|
||||
assert run.status == JobStatus.created
|
||||
print(f"Run created, run={run}, status={run.status}")
|
||||
|
||||
# Wait for the job to complete, cancel it if takes over 10 seconds
|
||||
start_time = time.time()
|
||||
while run.status == JobStatus.created:
|
||||
time.sleep(1)
|
||||
run = client.get_run(run_id=run.id)
|
||||
print(f"Run status: {run.status}")
|
||||
if time.time() - start_time > 10:
|
||||
pytest.fail("Run took too long to complete")
|
||||
|
||||
print(f"Run completed in {time.time() - start_time} seconds, run={run}")
|
||||
assert run.status == JobStatus.completed
|
||||
|
||||
# Get messages for the job
|
||||
messages = client.get_run_messages(run_id=run.id)
|
||||
assert len(messages) >= 2 # At least assistant response
|
||||
|
||||
# Check filters
|
||||
assistant_messages = client.get_run_messages(run_id=run.id, role=MessageRole.assistant)
|
||||
assert len(assistant_messages) > 0
|
||||
tool_messages = client.get_run_messages(run_id=run.id, role=MessageRole.tool)
|
||||
assert len(tool_messages) > 0
|
||||
|
||||
# Get and verify usage statistics
|
||||
usage = client.get_run_usage(run_id=run.id)[0]
|
||||
assert usage.completion_tokens >= 0
|
||||
assert usage.prompt_tokens >= 0
|
||||
assert usage.total_tokens >= 0
|
||||
assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# Agent listing
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two):
|
||||
def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two):
|
||||
"""Test listing agents with pagination and query text filtering."""
|
||||
# Test query text filtering
|
||||
search_results = client.list_agents(query_text="search agent")
|
||||
search_results = client.agents.list(query_text="search agent")
|
||||
assert len(search_results) == 2
|
||||
search_agent_ids = {agent.id for agent in search_results}
|
||||
assert search_agent_one.id in search_agent_ids
|
||||
assert search_agent_two.id in search_agent_ids
|
||||
assert agent.id not in search_agent_ids
|
||||
|
||||
different_results = client.list_agents(query_text="client")
|
||||
different_results = client.agents.list(query_text="client")
|
||||
assert len(different_results) == 1
|
||||
assert different_results[0].id == agent.id
|
||||
|
||||
# Test pagination
|
||||
first_page = client.list_agents(query_text="search agent", limit=1)
|
||||
first_page = client.agents.list(query_text="search agent", limit=1)
|
||||
assert len(first_page) == 1
|
||||
first_agent = first_page[0]
|
||||
|
||||
second_page = client.list_agents(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor
|
||||
second_page = client.agents.list(query_text="search agent", after=first_agent.id, limit=1) # Use agent ID as cursor
|
||||
assert len(second_page) == 1
|
||||
assert second_page[0].id != first_agent.id
|
||||
|
||||
@@ -664,20 +645,16 @@ def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_age
|
||||
assert all_ids == {search_agent_one.id, search_agent_two.id}
|
||||
|
||||
# Test listing without any filters
|
||||
all_agents = client.list_agents()
|
||||
all_agents = client.agents.list()
|
||||
assert len(all_agents) == 3
|
||||
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])
|
||||
|
||||
|
||||
def test_agent_creation(client: Union[LocalClient, RESTClient]):
|
||||
def test_agent_creation(client: Letta):
|
||||
"""Test that block IDs are properly attached when creating an agent."""
|
||||
if not isinstance(client, RESTClient):
|
||||
pytest.skip("This test only runs when the server is enabled")
|
||||
|
||||
from letta import BasicBlockMemory
|
||||
|
||||
# Create a test block that will represent user preferences
|
||||
user_preferences_block = client.create_block(label="user_preferences", value="", limit=10000)
|
||||
user_preferences_block = client.blocks.create(label="user_preferences", value="", limit=10000)
|
||||
|
||||
# Create test tools
|
||||
def test_tool():
|
||||
@@ -688,73 +665,82 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]):
|
||||
"""Another test tool."""
|
||||
return "Hello from another test tool!"
|
||||
|
||||
tool1 = client.create_or_update_tool(func=test_tool, tags=["test"])
|
||||
tool2 = client.create_or_update_tool(func=another_test_tool, tags=["test"])
|
||||
|
||||
# Create test blocks
|
||||
offline_persona_block = client.create_block(label="persona", value="persona description", limit=5000)
|
||||
mindy_block = client.create_block(label="mindy", value="Mindy is a helpful assistant", limit=5000)
|
||||
memory_blocks = BasicBlockMemory(blocks=[offline_persona_block, mindy_block])
|
||||
tool1 = client.tools.upsert_from_function(func=test_tool, tags=["test"])
|
||||
tool2 = client.tools.upsert_from_function(func=another_test_tool, tags=["test"])
|
||||
|
||||
# Create agent with the blocks and tools
|
||||
agent = client.create_agent(
|
||||
name=f"test_agent_{str(uuid.uuid4())}",
|
||||
memory=memory_blocks,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
agent = client.agents.create(
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "you are a human",
|
||||
},
|
||||
{"label": "persona", "value": "you are an assistant"},
|
||||
],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[tool1.id, tool2.id],
|
||||
include_base_tools=False,
|
||||
tags=["test"],
|
||||
block_ids=[user_preferences_block.id],
|
||||
)
|
||||
memory_blocks = agent.memory.blocks
|
||||
|
||||
# Verify the agent was created successfully
|
||||
assert agent is not None
|
||||
assert agent.id is not None
|
||||
|
||||
# Verify the blocks are properly attached
|
||||
agent_blocks = client.list_agent_memory_blocks(agent.id)
|
||||
agent_blocks = client.agents.blocks.list(agent_id=agent.id)
|
||||
agent_block_ids = {block.id for block in agent_blocks}
|
||||
|
||||
# Check that all memory blocks are present
|
||||
memory_block_ids = {block.id for block in memory_blocks.blocks}
|
||||
for block_id in memory_block_ids | {user_preferences_block.id}:
|
||||
assert block_id in agent_block_ids
|
||||
memory_block_ids = {block.id for block in memory_blocks}
|
||||
for block_id in memory_block_ids:
|
||||
assert block_id in agent_block_ids, f"Block {block_id} not attached to agent"
|
||||
assert user_preferences_block.id in agent_block_ids, f"User preferences block {user_preferences_block.id} not attached to agent"
|
||||
|
||||
# Verify the tools are properly attached
|
||||
agent_tools = client.get_tools_from_agent(agent.id)
|
||||
agent_tools = client.agents.tools.list(agent_id=agent.id)
|
||||
assert len(agent_tools) == 2
|
||||
tool_ids = {tool1.id, tool2.id}
|
||||
assert all(tool.id in tool_ids for tool in agent_tools)
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent sources
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_attach_detach_agent_source(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_attach_detach_agent_source(client: Letta, agent: AgentState):
|
||||
"""Test that we can attach and detach a source from an agent"""
|
||||
|
||||
# Create a source
|
||||
source = client.create_source(
|
||||
source = client.sources.create(
|
||||
name="test_source",
|
||||
embedding_config={ # TODO: change this
|
||||
"embedding_endpoint": "https://embeddings.memgpt.ai",
|
||||
"embedding_model": "BAAI/bge-large-en-v1.5",
|
||||
"embedding_dim": 1024,
|
||||
"embedding_chunk_size": 300,
|
||||
"embedding_endpoint_type": "hugging-face",
|
||||
},
|
||||
)
|
||||
initial_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
initial_sources = client.agents.sources.list(agent_id=agent.id)
|
||||
assert source.id not in [s.id for s in initial_sources]
|
||||
|
||||
# Attach source
|
||||
client.attach_source(agent_id=agent.id, source_id=source.id)
|
||||
client.agents.sources.attach(agent_id=agent.id, source_id=source.id)
|
||||
|
||||
# Verify source is attached
|
||||
final_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
final_sources = client.agents.sources.list(agent_id=agent.id)
|
||||
assert source.id in [s.id for s in final_sources]
|
||||
|
||||
# Detach source
|
||||
client.detach_source(agent_id=agent.id, source_id=source.id)
|
||||
client.agents.sources.detach(agent_id=agent.id, source_id=source.id)
|
||||
|
||||
# Verify source is detached
|
||||
final_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
final_sources = client.agents.sources.list(agent_id=agent.id)
|
||||
assert source.id not in [s.id for s in final_sources]
|
||||
|
||||
client.delete_source(source.id)
|
||||
client.sources.delete(source.id)
|
||||
|
||||
115
tests/test_streaming.py
Normal file
115
tests/test_streaming.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AgentState, Letta, LlmConfig, MessageCreate
|
||||
from letta_client.core.api_error import ApiError
|
||||
from pytest import fixture
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
|
||||
# create the Letta client
|
||||
yield Letta(base_url=server_url, token=None)
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client: Letta):
|
||||
agent_state = client.agents.create(
|
||||
name="test_client",
|
||||
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stream_tokens,model",
|
||||
[
|
||||
(True, "openai/gpt-4o-mini"),
|
||||
(True, "anthropic/claude-3-sonnet-20240229"),
|
||||
(False, "openai/gpt-4o-mini"),
|
||||
(False, "anthropic/claude-3-sonnet-20240229"),
|
||||
],
|
||||
)
|
||||
def test_streaming_send_message(
|
||||
mock_e2b_api_key_none,
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
stream_tokens: bool,
|
||||
model: str,
|
||||
):
|
||||
# Update agent's model
|
||||
config = client.agents.retrieve(agent_id=agent.id).llm_config
|
||||
config_dump = config.model_dump()
|
||||
config_dump["model"] = model
|
||||
config = LlmConfig(**config_dump)
|
||||
client.agents.modify(agent_id=agent.id, llm_config=config)
|
||||
|
||||
# Send streaming message
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")],
|
||||
stream_tokens=stream_tokens,
|
||||
)
|
||||
|
||||
# Tracking variables for test validation
|
||||
inner_thoughts_exist = False
|
||||
inner_thoughts_count = 0
|
||||
send_message_ran = False
|
||||
done = False
|
||||
|
||||
assert response, "Sending message failed"
|
||||
for chunk in response:
|
||||
# Check chunk type and content based on the current client API
|
||||
if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message":
|
||||
inner_thoughts_exist = True
|
||||
inner_thoughts_count += 1
|
||||
|
||||
if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message":
|
||||
send_message_ran = True
|
||||
if chunk.message_type == "assistant_message":
|
||||
send_message_ran = True
|
||||
|
||||
if chunk.message_type == "usage_statistics":
|
||||
# Validate usage statistics
|
||||
assert chunk.step_count == 1
|
||||
assert chunk.completion_tokens > 10
|
||||
assert chunk.prompt_tokens > 1000
|
||||
assert chunk.total_tokens > 1000
|
||||
done = True
|
||||
print(chunk)
|
||||
|
||||
# If stream tokens, we expect at least one inner thought
|
||||
assert inner_thoughts_count >= 1, "Expected more than one inner thought"
|
||||
assert inner_thoughts_exist, "No inner thoughts found"
|
||||
assert send_message_ran, "send_message function call not found"
|
||||
assert done, "Message stream not done"
|
||||
Reference in New Issue
Block a user