fix: drop tool_return_message when using assistant_message (#784)

This commit is contained in:
cthomas
2025-01-28 14:31:51 -08:00
committed by GitHub
parent 176617c4cb
commit 5316127be3
2 changed files with 62 additions and 13 deletions

View File

@@ -303,16 +303,12 @@ class JobManager:
request_config = self._get_run_request_config(run_id)
# Convert messages to LettaMessages
messages = [
msg
for m in messages
for msg in m.to_letta_message(
use_assistant_message=request_config["use_assistant_message"],
assistant_message_tool_name=request_config["assistant_message_tool_name"],
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
)
]
messages = PydanticMessage.to_letta_messages_from_list(
messages=messages,
use_assistant_message=request_config["use_assistant_message"],
assistant_message_tool_name=request_config["assistant_message_tool_name"],
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
)
return messages

View File

@@ -3018,13 +3018,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
PydanticMessage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role=MessageRole.user if i % 2 == 0 else MessageRole.assistant,
text=f"Test message {i}",
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
tool_calls=(
[{"type": "function", "id": f"call_{i}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
if i % 2 == 1
else None
),
tool_call_id=f"call_{i//2}" if i % 2 == 0 else None,
)
for i in range(4)
]
@@ -3049,6 +3050,58 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
assert msg.tool_call.name == "custom_tool"
def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent):
"""Test getting messages for a run with request config."""
# Create a run with custom request config
run = server.job_manager.create_job(
pydantic_job=PydanticRun(
user_id=default_user.id,
status=JobStatus.created,
request_config=LettaRequestConfig(
use_assistant_message=True, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg"
),
),
actor=default_user,
)
# Add some messages
messages = [
PydanticMessage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
tool_calls=(
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
if i % 2 == 1
else None
),
tool_call_id=f"call_{i//2}" if i % 2 == 0 else None,
)
for i in range(4)
]
for msg in messages:
created_msg = server.message_manager.create_message(msg, actor=default_user)
server.job_manager.add_message_to_job(job_id=run.id, message_id=created_msg.id, actor=default_user)
# Get messages and verify they're converted correctly
result = server.job_manager.get_run_messages(run_id=run.id, actor=default_user)
# Verify correct number of messages. Assistant messages should be parsed
assert len(result) == 4
# Verify assistant messages are parsed according to request config
assistant_messages = [msg for msg in result if msg.message_type == "assistant_message"]
reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"]
assert len(assistant_messages) == 2
assert len(reasoning_messages) == 2
for msg in assistant_messages:
assert msg.content == "test"
for msg in reasoning_messages:
assert "Test message" in msg.reasoning
# ======================================================================================================================
# JobManager Tests - Usage Statistics
# ======================================================================================================================