fix: drop tool_return_message when using assistant_message (#784)
This commit is contained in:
@@ -303,16 +303,12 @@ class JobManager:
|
||||
|
||||
request_config = self._get_run_request_config(run_id)
|
||||
|
||||
# Convert messages to LettaMessages
|
||||
messages = [
|
||||
msg
|
||||
for m in messages
|
||||
for msg in m.to_letta_message(
|
||||
use_assistant_message=request_config["use_assistant_message"],
|
||||
assistant_message_tool_name=request_config["assistant_message_tool_name"],
|
||||
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
|
||||
)
|
||||
]
|
||||
messages = PydanticMessage.to_letta_messages_from_list(
|
||||
messages=messages,
|
||||
use_assistant_message=request_config["use_assistant_message"],
|
||||
assistant_message_tool_name=request_config["assistant_message_tool_name"],
|
||||
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@@ -3018,13 +3018,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}",
|
||||
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
|
||||
tool_calls=(
|
||||
[{"type": "function", "id": f"call_{i}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
if i % 2 == 1
|
||||
else None
|
||||
),
|
||||
tool_call_id=f"call_{i//2}" if i % 2 == 0 else None,
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
@@ -3049,6 +3050,58 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
assert msg.tool_call.name == "custom_tool"
|
||||
|
||||
|
||||
def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent):
|
||||
"""Test getting messages for a run with request config."""
|
||||
# Create a run with custom request config
|
||||
run = server.job_manager.create_job(
|
||||
pydantic_job=PydanticRun(
|
||||
user_id=default_user.id,
|
||||
status=JobStatus.created,
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=True, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg"
|
||||
),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add some messages
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
|
||||
tool_calls=(
|
||||
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
if i % 2 == 1
|
||||
else None
|
||||
),
|
||||
tool_call_id=f"call_{i//2}" if i % 2 == 0 else None,
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
|
||||
for msg in messages:
|
||||
created_msg = server.message_manager.create_message(msg, actor=default_user)
|
||||
server.job_manager.add_message_to_job(job_id=run.id, message_id=created_msg.id, actor=default_user)
|
||||
|
||||
# Get messages and verify they're converted correctly
|
||||
result = server.job_manager.get_run_messages(run_id=run.id, actor=default_user)
|
||||
|
||||
# Verify correct number of messages. Assistant messages should be parsed
|
||||
assert len(result) == 4
|
||||
|
||||
# Verify assistant messages are parsed according to request config
|
||||
assistant_messages = [msg for msg in result if msg.message_type == "assistant_message"]
|
||||
reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"]
|
||||
assert len(assistant_messages) == 2
|
||||
assert len(reasoning_messages) == 2
|
||||
for msg in assistant_messages:
|
||||
assert msg.content == "test"
|
||||
for msg in reasoning_messages:
|
||||
assert "Test message" in msg.reasoning
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# JobManager Tests - Usage Statistics
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user