fix: Fix message_id ordering in agent serialization (#1458)

This commit is contained in:
Matthew Zhou
2025-03-28 15:13:33 -07:00
committed by GitHub
parent a1ad3be919
commit 60ffc9e8ec
5 changed files with 53 additions and 19 deletions

View File

@@ -110,7 +110,9 @@ class AnthropicClient(LLMClientBase):
]
# Move 'system' to the top level
# assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
if data["messages"][0]["role"] != "system":
raise RuntimeError(f"First message is not a system message, instead has role {data["messages"][0]["role"]}")
data["system"] = data["messages"][0]["content"]
data["messages"] = data["messages"][1:]

View File

@@ -42,7 +42,6 @@ class LLMClientBase:
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(messages, tools, tool_call)
response_data = {}
try:
log_event(name="llm_request_sent", attributes=request_data)

View File

@@ -59,29 +59,26 @@ class MarshmallowAgentSchema(BaseSchema):
"""
- Removes `message_ids`
- Adds versioning
- Marks messages as in-context
- Marks messages as in-context, preserving the order of the original `message_ids`
- Removes individual message `id` fields
"""
data = super().sanitize_ids(data, **kwargs)
data[self.FIELD_VERSION] = letta.__version__
message_ids = list(data.pop(self.FIELD_MESSAGE_IDS, []))
original_message_ids = data.pop(self.FIELD_MESSAGE_IDS, [])
messages = data.get(self.FIELD_MESSAGES, [])
# NOTE: currently we don't support out-of-context messages, since they contain a lot of system message spam
# TODO: support out-of-context messages
messages = []
# Build a mapping from message id to its first occurrence index and remove the id in one pass
id_to_index = {}
for idx, message in enumerate(messages):
msg_id = message.pop(self.FIELD_ID, None)
if msg_id is not None and msg_id not in id_to_index:
id_to_index[msg_id] = idx
# loop through the messages in the *same* order as the in-context message IDs
data[self.FIELD_IN_CONTEXT_INDICES] = []
for i, message in enumerate(data.get(self.FIELD_MESSAGES, [])):
# if id matches in-context message ID, add to `messages`
if message[self.FIELD_ID] in message_ids:
data[self.FIELD_IN_CONTEXT_INDICES].append(i)
messages.append(message)
# Build in-context indices in the same order as the original message_ids
in_context_indices = [id_to_index[msg_id] for msg_id in original_message_ids if msg_id in id_to_index]
# remove ids
for message in messages:
message.pop(self.FIELD_ID, None) # Remove the id field
data[self.FIELD_IN_CONTEXT_INDICES] = in_context_indices
data[self.FIELD_MESSAGES] = messages
return data

View File

@@ -337,9 +337,7 @@ def test_multi_agent_broadcast_client(client: Letta, weather_tool):
}
],
)
import ipdb
ipdb.set_trace()
for message in response.messages:
print(message)

View File

@@ -604,3 +604,41 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
agent_id=copied_agent_id,
messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
)
# TODO: Add this back
# @pytest.mark.parametrize("test_af_filename", ["deep_research_agent.af"])
# def test_agent_file_upload_flow(fastapi_client, server, default_user, other_user, test_af_filename):
# """
# Test the full E2E serialization and deserialization flow using FastAPI endpoints.
# """
# file_path = Path(__file__).parent / "test_agent_files" / test_af_filename
# with open(file_path, "r") as f:
# data = json.load(f)
#
# # Ensure response matches expected schema
# agent_schema = AgentSchema.model_validate(data) # Validate as Pydantic model
# agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON
#
# import ipdb;ipdb.set_trace()
#
# # Step 2: Upload the serialized agent as a copy
# agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
# files = {"file": ("agent.json", agent_bytes, "application/json")}
# upload_response = fastapi_client.post(
# "/v1/agents/import",
# headers={"user_id": other_user.id},
# params={"append_copy_suffix": True, "override_existing_tools": False, "project_id": None},
# files=files,
# )
# assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
#
# copied_agent = upload_response.json()
# copied_agent_id = copied_agent["id"]
#
# # Step 3: Ensure copied agent receives messages correctly
# server.send_messages(
# actor=other_user,
# agent_id=copied_agent_id,
# messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
# )