fix: Deprecate in memory function stores in agent.py (#2271)

This commit is contained in:
Matthew Zhou
2024-12-17 16:18:11 -08:00
committed by GitHub
parent 7fb8f16155
commit 9a0ffc84dd
7 changed files with 94 additions and 127 deletions

View File

@@ -18,7 +18,7 @@ from letta.constants import (
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS
STRUCTURED_OUTPUT_MODELS,
)
from letta.errors import LLMError
from letta.helpers import ToolRulesSolver
@@ -260,9 +260,6 @@ class Agent(BaseAgent):
self.user = user
# link tools
self.link_tools(agent_state.tools)
# initialize a tool rules solver
if agent_state.tool_rules:
# if there are tool rules, print out a warning
@@ -385,7 +382,9 @@ class Agent(BaseAgent):
def check_tool_rules(self):
if self.model not in STRUCTURED_OUTPUT_MODELS:
if len(self.tool_rules_solver.init_tool_rules) > 1:
raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
raise ValueError(
"Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule."
)
self.supports_structured_output = False
else:
self.supports_structured_output = True
@@ -424,11 +423,21 @@ class Agent(BaseAgent):
return True
return False
def execute_tool_and_persist_state(self, function_name, function_to_call, function_args):
def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool):
"""
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
"""
# TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
env = {}
env.update(globals())
exec(target_letta_tool.source_code, env)
callable_func = env[target_letta_tool.json_schema["name"]]
spec = inspect.getfullargspec(callable_func).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])
# TODO: add agent manager here
orig_memory_str = self.agent_state.memory.compile()
@@ -441,11 +450,11 @@ class Agent(BaseAgent):
if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS:
# base tools are allowed to access the `Agent` object and run on the database
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
function_response = callable_func(**function_args)
else:
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.created_by_id).run(
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(
agent_state=self.agent_state.__deepcopy__()
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
@@ -470,27 +479,6 @@ class Agent(BaseAgent):
def messages(self, value):
raise Exception("Modifying message list directly not allowed")
def link_tools(self, tools: List[Tool]):
"""Bind a tool object (schema + python function) to the agent object"""
# Store the functions schemas (this is passed as an argument to ChatCompletion)
self.functions = []
self.functions_python = {}
env = {}
env.update(globals())
for tool in tools:
try:
# WARNING: name may not be consistent?
# if tool.module: # execute the whole module
# exec(tool.module, env)
# else:
exec(tool.source_code, env)
self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]]
self.functions.append(tool.json_schema)
except Exception:
warnings.warn(f"WARNING: tool {tool.name} failed to link")
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
"""Load a list of messages from recall storage"""
@@ -599,8 +587,12 @@ class Agent(BaseAgent):
"""Get response from LLM API with robust retry mechanism."""
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
allowed_functions = (
self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
agent_state_tool_jsons
if not allowed_tool_names
else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
)
# For the first message, force the initial tool if one is specified
@@ -620,7 +612,7 @@ class Agent(BaseAgent):
messages=message_sequence,
user_id=self.agent_state.created_by_id,
functions=allowed_functions,
functions_python=self.functions_python,
# functions_python=self.functions_python, do we need this?
function_call=function_call,
first_message=first_message,
force_tool_call=force_tool_call,
@@ -729,10 +721,13 @@ class Agent(BaseAgent):
function_name = function_call.name
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
# Failure case 1: function name is wrong
try:
function_to_call = self.functions_python[function_name]
except KeyError:
# Failure case 1: function name is wrong (not in agent_state.tools)
target_letta_tool = None
for t in self.agent_state.tools:
if t.name == function_name:
target_letta_tool = t
if not target_letta_tool:
error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
@@ -800,14 +795,8 @@ class Agent(BaseAgent):
# this is because the function/tool role message is only created once the function/tool has executed/returned
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
try:
spec = inspect.getfullargspec(function_to_call).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])
# handle tool execution (sandbox) and state updates
function_response = self.execute_tool_and_persist_state(function_name, function_to_call, function_args)
function_response = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
# handle trunction
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
@@ -819,8 +808,7 @@ class Agent(BaseAgent):
truncate = True
# get the function response limit
tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0]
return_char_limit = tool_obj.return_char_limit
return_char_limit = target_letta_tool.return_char_limit
function_response_string = validate_function_response(
function_response, return_char_limit=return_char_limit, truncate=truncate
)
@@ -1564,9 +1552,10 @@ class Agent(BaseAgent):
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
# tokens taken up by function definitions
if self.functions:
available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in self.functions]
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=self.functions, model=self.model)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
if agent_state_tool_jsons:
available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in agent_state_tool_jsons]
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
else:
available_functions_definitions = []
num_tokens_available_functions_definitions = 0

View File

@@ -195,7 +195,7 @@ def run_tool_from_source(
tool_source_type=request.source_type,
tool_args=request.args,
tool_name=request.name,
user_id=actor.id,
actor=actor,
)
except LettaToolCreateError as e:
# HTTP 400 == Bad Request

View File

@@ -853,10 +853,6 @@ class SyncServer(Server):
# then (2) setting the attributes ._messages and .state.message_ids
letta_agent.set_message_buffer(message_ids=request.message_ids)
# tools
if request.tool_ids:
letta_agent.link_tools(letta_agent.agent_state.tools)
letta_agent.update_state()
return agent_state
@@ -882,11 +878,6 @@ class SyncServer(Server):
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
# TODO: This is very redundant, and should probably be simplified
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
letta_agent.link_tools(agent_state.tools)
return agent_state
def remove_tool_from_agent(
@@ -900,10 +891,6 @@ class SyncServer(Server):
actor = self.user_manager.get_user_or_default(user_id=user_id)
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
letta_agent.link_tools(agent_state.tools)
return agent_state
# convert name->id
@@ -1309,9 +1296,7 @@ class SyncServer(Server):
if context_window_limit:
if context_window_limit > llm_config.context_window:
raise ValueError(
f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})"
)
raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
llm_config.context_window = context_window_limit
return llm_config
@@ -1366,7 +1351,7 @@ class SyncServer(Server):
def run_tool_from_source(
self,
user_id: str,
actor: User,
tool_args: str,
tool_source: str,
tool_source_type: Optional[str] = None,
@@ -1394,7 +1379,7 @@ class SyncServer(Server):
# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
return FunctionReturn(
id="null",
function_call_id="null",
@@ -1406,9 +1391,7 @@ class SyncServer(Server):
)
except Exception as e:
func_return = get_friendly_error_msg(
function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e)
)
func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e))
return FunctionReturn(
id="null",
function_call_id="null",

View File

@@ -16,9 +16,9 @@ from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
from letta.utils import get_friendly_error_msg
@@ -38,14 +38,10 @@ class ToolExecutionSandbox:
# We make this a long random string to avoid collisions with any variables in the user's code
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None):
def __init__(self, tool_name: str, args: dict, user: User, force_recreate=False, tool_object: Optional[Tool] = None):
self.tool_name = tool_name
self.args = args
# Get the user
# This user corresponds to the agent_state's user_id field
# agent_state is the state of the agent that invoked this run
self.user = UserManager().get_user_by_id(user_id=user_id)
self.user = user
# If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None:
@@ -184,7 +180,9 @@ class ToolExecutionSandbox:
except subprocess.CalledProcessError as e:
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e),
function_name=self.tool_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
return SandboxRunResult(
func_return=func_return,
@@ -202,9 +200,7 @@ class ToolExecutionSandbox:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
raise e
def run_local_dir_sandbox_runpy(
self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str
) -> SandboxRunResult:
def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
status = "success"
agent_state, stderr = None, None
@@ -225,9 +221,7 @@ class ToolExecutionSandbox:
func_return, agent_state = self.parse_best_effort(func_result)
except Exception as e:
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e)
)
func_return = get_friendly_error_msg(function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e))
traceback.print_exc(file=sys.stderr)
status = "error"
@@ -248,7 +242,7 @@ class ToolExecutionSandbox:
def parse_out_function_results_markers(self, text: str):
if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text:
return '', text
return "", text
marker_len = len(self.LOCAL_SANDBOX_RESULT_START_MARKER)
start_index = text.index(self.LOCAL_SANDBOX_RESULT_START_MARKER) + marker_len
end_index = text.index(self.LOCAL_SANDBOX_RESULT_END_MARKER)
@@ -293,6 +287,7 @@ class ToolExecutionSandbox:
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
code = self.generate_execution_script(agent_state=agent_state)
execution = sbx.run_code(code, envs=env_vars)
if execution.results:
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
elif execution.error:
@@ -303,7 +298,7 @@ class ToolExecutionSandbox:
execution.logs.stderr.append(execution.error.traceback)
else:
raise ValueError(f"Tool {self.tool_name} returned execution with None")
return SandboxRunResult(
func_return=func_return,
agent_state=agent_state,

View File

@@ -110,8 +110,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=agent._messages,
functions=agent.functions,
functions_python=agent.functions_python,
functions=[t.json_schema for t in agent.agent_state.tools],
)
# Basic check

View File

@@ -283,12 +283,12 @@ def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_us
# Mock and assert correct pathway was invoked
with patch.object(ToolExecutionSandbox, "run_local_dir_sandbox") as mock_run_local_dir_sandbox:
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user)
sandbox.run()
mock_run_local_dir_sandbox.assert_called_once()
# Run again to get actual response
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user)
result = sandbox.run()
assert result.func_return == args["x"] + args["y"]
@@ -297,7 +297,7 @@ def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_us
def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory_tool, test_user, agent_state):
args = {}
# Run again to get actual response
sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
assert result.agent_state.memory.get_block("persona").value == ""
@@ -306,7 +306,7 @@ def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory_to
@pytest.mark.local_sandbox
def test_local_sandbox_with_list_rv(mock_e2b_api_key_none, list_tool, test_user):
sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user)
result = sandbox.run()
assert len(result.func_return) == 5
@@ -331,7 +331,7 @@ def test_local_sandbox_env(mock_e2b_api_key_none, get_env_tool, test_user):
args = {}
# Run the custom sandbox
sandbox = ToolExecutionSandbox(get_env_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(get_env_tool.name, args, user=test_user)
result = sandbox.run()
assert long_random_string in result.func_return
@@ -349,7 +349,7 @@ def test_local_sandbox_e2e_composio_star_github(mock_e2b_api_key_none, check_com
actor=test_user,
)
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user_id=test_user.id).run()
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run()
assert result.func_return["details"] == "Action executed successfully"
@@ -359,7 +359,7 @@ def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sand
args = {"percentage": 10}
# Run again to get actual response
sandbox = ToolExecutionSandbox(external_codebase_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(external_codebase_tool.name, args, user=test_user)
result = sandbox.run()
# Assert that the function return is correct
@@ -371,14 +371,14 @@ def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sand
def test_local_sandbox_with_venv_and_warnings_does_not_error(
mock_e2b_api_key_none, custom_test_sandbox_config, get_warning_tool, test_user
):
sandbox = ToolExecutionSandbox(get_warning_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(get_warning_tool.name, {}, user=test_user)
result = sandbox.run()
assert result.func_return == "Hello World"
@pytest.mark.e2b_sandbox
def test_local_sandbox_with_venv_errors(mock_e2b_api_key_none, custom_test_sandbox_config, always_err_tool, test_user):
sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user=test_user)
# run the sandbox
result = sandbox.run()
@@ -397,12 +397,12 @@ def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user)
# Mock and assert correct pathway was invoked
with patch.object(ToolExecutionSandbox, "run_e2b_sandbox") as mock_run_local_dir_sandbox:
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user)
sandbox.run()
mock_run_local_dir_sandbox.assert_called_once()
# Run again to get actual response
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user)
result = sandbox.run()
assert int(result.func_return) == args["x"] + args["y"]
@@ -420,14 +420,14 @@ def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user):
SandboxEnvironmentVariableCreate(key=key, value=long_random_string), sandbox_config_id=config.id, actor=test_user
)
sandbox = ToolExecutionSandbox(cowsay_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(cowsay_tool.name, {}, user=test_user)
result = sandbox.run()
assert long_random_string in result.stdout[0]
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_reuses_same_sandbox(check_e2b_key_is_set, list_tool, test_user):
sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user)
# Run the function once
result = sandbox.run()
@@ -442,7 +442,7 @@ def test_e2b_sandbox_reuses_same_sandbox(check_e2b_key_is_set, list_tool, test_u
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state):
sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, {}, user=test_user)
# run the sandbox
result = sandbox.run(agent_state=agent_state)
@@ -458,7 +458,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e
config = manager.create_or_update_sandbox_config(config_create, test_user)
# Run the custom sandbox once, assert nothing returns because missing env variable
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user_id=test_user.id, force_recreate=True)
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user, force_recreate=True)
result = sandbox.run()
# response should be None
assert result.func_return is None
@@ -471,7 +471,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e
)
# Assert that the environment variable gets injected correctly, even when the sandbox is NOT refreshed
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user)
result = sandbox.run()
assert long_random_string in result.func_return
@@ -487,7 +487,7 @@ def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set,
config = manager.create_or_update_sandbox_config(config_create, test_user)
# Run the custom sandbox once, assert a failure gets returned because missing environment variable
sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user)
result = sandbox.run()
assert len(result.func_return) == 5
old_config_fingerprint = result.sandbox_config_fingerprint
@@ -497,7 +497,7 @@ def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set,
config = manager.update_sandbox_config(config.id, config_update, test_user)
# Run again
result = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id).run()
result = ToolExecutionSandbox(list_tool.name, {}, user=test_user).run()
new_config_fingerprint = result.sandbox_config_fingerprint
assert config.fingerprint() == new_config_fingerprint
@@ -507,7 +507,7 @@ def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set,
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user):
sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id)
sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user)
result = sandbox.run()
assert len(result.func_return) == 5
@@ -524,7 +524,7 @@ def test_e2b_e2e_composio_star_github(check_e2b_key_is_set, check_composio_key_s
actor=test_user,
)
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user_id=test_user.id).run()
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run()
assert result.func_return["details"] == "Action executed successfully"
@@ -541,7 +541,7 @@ class TestCoreMemoryTools:
"""Test successful replacement of content in core memory - local sandbox."""
new_name = "Charles"
args = {"label": "human", "old_content": "Chad", "new_content": new_name}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
@@ -552,7 +552,7 @@ class TestCoreMemoryTools:
"""Test successful appending of content to core memory - local sandbox."""
append_text = "\nLikes coffee"
args = {"label": "human", "content": append_text}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert append_text in result.agent_state.memory.get_block("human").value
@@ -563,7 +563,7 @@ class TestCoreMemoryTools:
"""Test error handling when trying to replace non-existent content - local sandbox."""
nonexistent_name = "Alexander Wang"
args = {"label": "human", "old_content": nonexistent_name, "new_content": "Charles"}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert len(result.stderr) != 0
@@ -575,7 +575,7 @@ class TestCoreMemoryTools:
"""Test successful replacement of content in core memory - e2b sandbox."""
new_name = "Charles"
args = {"label": "human", "old_content": "Chad", "new_content": new_name}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert new_name in result.agent_state.memory.get_block("human").value
@@ -586,7 +586,7 @@ class TestCoreMemoryTools:
"""Test successful appending of content to core memory - e2b sandbox."""
append_text = "\nLikes coffee"
args = {"label": "human", "content": append_text}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert append_text in result.agent_state.memory.get_block("human").value
@@ -597,7 +597,7 @@ class TestCoreMemoryTools:
"""Test error handling when trying to replace non-existent content - e2b sandbox."""
nonexistent_name = "Alexander Wang"
args = {"label": "human", "old_content": nonexistent_name, "new_content": "Charles"}
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user)
result = sandbox.run(agent_state=agent_state)
assert len(result.stderr) != 0

View File

@@ -288,17 +288,18 @@ def org_id(server):
@pytest.fixture(scope="module")
def user_id(server, org_id):
# create user
def user(server, org_id):
user = server.user_manager.create_default_user()
print(f"Created user\n{user.id}")
yield user.id
# cleanup
yield user
server.user_manager.delete_user_by_id(user.id)
@pytest.fixture(scope="module")
def user_id(server, user):
# create user
yield user.id
@pytest.fixture(scope="module")
def base_tools(server, user_id):
actor = server.user_manager.get_user_or_default(user_id)
@@ -789,11 +790,11 @@ def ingest(message: str):
'''
def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
"""Test that the server can run tools"""
result = server.run_tool_from_source(
user_id=user_id,
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"message": "Hello, world!"}),
@@ -806,7 +807,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
assert not result.stderr
result = server.run_tool_from_source(
user_id=user_id,
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
@@ -819,7 +820,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
assert not result.stderr
result = server.run_tool_from_source(
user_id=user_id,
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"bad_arg": "oh no"}),
@@ -835,7 +836,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
# Test that we can still pull the tool out by default (pulls that last tool in the source)
result = server.run_tool_from_source(
user_id=user_id,
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
@@ -850,7 +851,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
# Test that we can pull the tool out by name
result = server.run_tool_from_source(
user_id=user_id,
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
@@ -865,7 +866,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
# Test that we can pull a different tool out by name
result = server.run_tool_from_source(
user_id=user_id,
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({}),