From 5316127be33f6a09cc34e10b8d9a622b7e90c50d Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 28 Jan 2025 14:31:51 -0800 Subject: [PATCH] fix: drop tool_return_message when using assistant_message (#784) --- letta/services/job_manager.py | 16 ++++------ tests/test_managers.py | 59 +++++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index b774a9a1..474cc640 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -303,16 +303,12 @@ class JobManager: request_config = self._get_run_request_config(run_id) - # Convert messages to LettaMessages - messages = [ - msg - for m in messages - for msg in m.to_letta_message( - use_assistant_message=request_config["use_assistant_message"], - assistant_message_tool_name=request_config["assistant_message_tool_name"], - assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"], - ) - ] + messages = PydanticMessage.to_letta_messages_from_list( + messages=messages, + use_assistant_message=request_config["use_assistant_message"], + assistant_message_tool_name=request_config["assistant_message_tool_name"], + assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"], + ) return messages diff --git a/tests/test_managers.py b/tests/test_managers.py index 3f5d8a0e..b72e55d8 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3018,13 +3018,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ PydanticMessage( organization_id=default_user.organization_id, agent_id=sarah_agent.id, - role=MessageRole.user if i % 2 == 0 else MessageRole.assistant, - text=f"Test message {i}", + role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, + text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}', tool_calls=( - [{"type": "function", "id": f"call_{i}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] + [{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] if i % 2 == 1 else None ), + tool_call_id=f"call_{i//2}" if i % 2 == 0 else None, ) for i in range(4) ] @@ -3049,6 +3050,58 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ assert msg.tool_call.name == "custom_tool" +def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent): + """Test getting messages for a run with request config.""" + # Create a run with custom request config + run = server.job_manager.create_job( + pydantic_job=PydanticRun( + user_id=default_user.id, + status=JobStatus.created, + request_config=LettaRequestConfig( + use_assistant_message=True, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" + ), + ), + actor=default_user, + ) + + # Add some messages + messages = [ + PydanticMessage( + organization_id=default_user.organization_id, + agent_id=sarah_agent.id, + role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, + text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}', + tool_calls=( + [{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] + if i % 2 == 1 + else None + ), + tool_call_id=f"call_{i//2}" if i % 2 == 0 else None, + ) + for i in range(4) + ] + + for msg in messages: + created_msg = server.message_manager.create_message(msg, actor=default_user) + server.job_manager.add_message_to_job(job_id=run.id, message_id=created_msg.id, actor=default_user) + + # Get messages and verify they're converted correctly + result = server.job_manager.get_run_messages(run_id=run.id, actor=default_user) + + # Verify correct number of messages. Assistant messages should be parsed + assert len(result) == 4 + + # Verify assistant messages are parsed according to request config + assistant_messages = [msg for msg in result if msg.message_type == "assistant_message"] + reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"] + assert len(assistant_messages) == 2 + assert len(reasoning_messages) == 2 + for msg in assistant_messages: + assert msg.content == "test" + for msg in reasoning_messages: + assert "Test message" in msg.reasoning + + # ====================================================================================================================== # JobManager Tests - Usage Statistics # ======================================================================================================================