From 68792caec28aec5408560c272270cd2d21a9055c Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 10 Dec 2024 13:24:05 -0800 Subject: [PATCH] feat: add logs to response for tool run (#2205) Co-authored-by: Caren Thomas --- letta/schemas/letta_message.py | 6 +++++- letta/schemas/sandbox_config.py | 1 + letta/server/server.py | 4 ++++ letta/services/tool_execution_sandbox.py | 23 ++++++++++++++++------- tests/test_server.py | 12 ++++++++++++ 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index b3f7bf90..3b2dc734 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, List, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator @@ -150,12 +150,16 @@ class FunctionReturn(LettaMessage): id (str): The ID of the message date (datetime): The date the message was created in ISO format function_call_id (str): A unique identifier for the function call that generated this message + stdout (Optional[List[str]]): Captured stdout (e.g. 
prints, logs) from the function invocation + stderr (Optional[List[str]]): Captured stderr from the function invocation """ message_type: Literal["function_return"] = "function_return" function_return: str status: Literal["success", "error"] function_call_id: str + stdout: Optional[List[str]] = None + stderr: Optional[List[str]] = None # Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index 97e5a8ef..246ba8a3 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -19,6 +19,7 @@ class SandboxRunResult(BaseModel): func_return: Optional[Any] = Field(None, description="The function return object") agent_state: Optional[AgentState] = Field(None, description="The agent state") stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation") + stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation") sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox") diff --git a/letta/server/server.py b/letta/server/server.py index 08701587..609e8eab 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1829,6 +1829,8 @@ class SyncServer(Server): date=get_utc_time(), status="success", function_return=function_response, + stdout=sandbox_run_result.stdout, + stderr=sandbox_run_result.stderr, ) except Exception as e: # same as agent.py @@ -1844,6 +1846,8 @@ class SyncServer(Server): date=get_utc_time(), status="error", function_return=error_msg, + stdout=[''], + stderr=[traceback.format_exc()], ) # Composio wrappers diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index b45a0b00..e1698dfe 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -127,11 +127,12 @@ 
class ToolExecutionSandbox: # Save the old stdout old_stdout = sys.stdout + old_stderr = sys.stderr try: if local_configs.use_venv: return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path) else: - return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path, old_stdout) + return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path, old_stdout, old_stderr) except Exception as e: logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}") @@ -139,6 +140,7 @@ class ToolExecutionSandbox: finally: # Clean up the temp file and restore stdout sys.stdout = old_stdout + sys.stderr = old_stderr os.remove(temp_file_path) def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult: @@ -202,8 +204,9 @@ class ToolExecutionSandbox: func_return, agent_state = self.parse_best_effort(func_result) return SandboxRunResult( func_return=func_return, - agent_state=agent_state, - stdout=[stdout], + agent_state=agent_state, + stdout=[stdout], + stderr=[result.stderr], sandbox_config_fingerprint=sbx_config.fingerprint(), ) except subprocess.TimeoutExpired: @@ -216,11 +219,13 @@ class ToolExecutionSandbox: raise e def run_local_dir_sandbox_runpy( - self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str, old_stdout: TextIO + self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str, old_stdout: TextIO, old_stderr: TextIO ) -> SandboxRunResult: - # Redirect stdout to capture script output + # Redirect stdout and stderr to capture script output captured_stdout = io.StringIO() + captured_stderr = io.StringIO() sys.stdout = captured_stdout + sys.stderr = captured_stderr # Execute the temp file with self.temporary_env_vars(env_vars): @@ -230,14 +235,17 @@ class ToolExecutionSandbox: func_result = 
result.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME) func_return, agent_state = self.parse_best_effort(func_result) - # Restore stdout and collect captured output + # Restore stdout and stderr and collect captured output sys.stdout = old_stdout + sys.stderr = old_stderr stdout_output = captured_stdout.getvalue() + stderr_output = captured_stderr.getvalue() return SandboxRunResult( func_return=func_return, agent_state=agent_state, stdout=[stdout_output], + stderr=[stderr_output], sandbox_config_fingerprint=sbx_config.fingerprint(), ) @@ -297,7 +305,8 @@ class ToolExecutionSandbox: return SandboxRunResult( func_return=func_return, agent_state=agent_state, - stdout=execution.logs.stdout + execution.logs.stderr, + stdout=execution.logs.stdout, + stderr=execution.logs.stderr, sandbox_config_fingerprint=sbx_config.fingerprint(), ) diff --git a/tests/test_server.py b/tests/test_server.py index 8d85cc1c..32e5da69 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -541,6 +541,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): print(result) assert result.status == "success" assert result.function_return == "Ingested message Hello, world!", result.function_return + assert result.stdout == [''] + assert result.stderr == [''] result = server.run_tool_from_source( user_id=user_id, @@ -552,6 +554,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): print(result) assert result.status == "success" assert result.function_return == "Ingested message Well well well", result.function_return + assert result.stdout == [''] + assert result.stderr == [''] result = server.run_tool_from_source( user_id=user_id, @@ -564,6 +568,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): assert result.status == "error" assert "Error" in result.function_return, result.function_return assert "missing 1 required positional argument" in result.function_return, result.function_return + assert result.stdout == [''] + assert 
result.stderr != [''], "missing 1 required positional argument" in result.stderr[0] # Test that we can still pull the tool out by default (pulls that last tool in the source) result = server.run_tool_from_source( @@ -576,6 +582,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): print(result) assert result.status == "success" assert result.function_return == "Ingested message Well well well", result.function_return + assert result.stdout != [''], "I'm a distractor" in result.stdout[0] + assert result.stderr == [''] # Test that we can pull the tool out by name result = server.run_tool_from_source( @@ -588,6 +596,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): print(result) assert result.status == "success" assert result.function_return == "Ingested message Well well well", result.function_return + assert result.stdout != [''], "I'm a distractor" in result.stdout[0] + assert result.stderr == [''] # Test that we can pull a different tool out by name result = server.run_tool_from_source( @@ -600,6 +610,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): print(result) assert result.status == "success" assert result.function_return == str(None), result.function_return + assert result.stdout != [''], "I'm a distractor" in result.stdout[0] + assert result.stderr == [''] def test_composio_client_simple(server):