fix: Fix message_id ordering in agent serialization (#1458)
This commit is contained in:
@@ -110,7 +110,9 @@ class AnthropicClient(LLMClientBase):
|
||||
]
|
||||
|
||||
# Move 'system' to the top level
|
||||
# assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
|
||||
if data["messages"][0]["role"] != "system":
|
||||
raise RuntimeError(f"First message is not a system message, instead has role {data["messages"][0]["role"]}")
|
||||
|
||||
data["system"] = data["messages"][0]["content"]
|
||||
data["messages"] = data["messages"][1:]
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ class LLMClientBase:
|
||||
Otherwise returns a ChatCompletionResponse.
|
||||
"""
|
||||
request_data = self.build_request_data(messages, tools, tool_call)
|
||||
response_data = {}
|
||||
|
||||
try:
|
||||
log_event(name="llm_request_sent", attributes=request_data)
|
||||
|
||||
@@ -59,29 +59,26 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
"""
|
||||
- Removes `message_ids`
|
||||
- Adds versioning
|
||||
- Marks messages as in-context
|
||||
- Marks messages as in-context, preserving the order of the original `message_ids`
|
||||
- Removes individual message `id` fields
|
||||
"""
|
||||
data = super().sanitize_ids(data, **kwargs)
|
||||
data[self.FIELD_VERSION] = letta.__version__
|
||||
|
||||
message_ids = list(data.pop(self.FIELD_MESSAGE_IDS, []))
|
||||
original_message_ids = data.pop(self.FIELD_MESSAGE_IDS, [])
|
||||
messages = data.get(self.FIELD_MESSAGES, [])
|
||||
|
||||
# NOTE: currently we don't support out-of-context messages since it has a bunch of system message spams
|
||||
# TODO: support out-of-context messages
|
||||
messages = []
|
||||
# Build a mapping from message id to its first occurrence index and remove the id in one pass
|
||||
id_to_index = {}
|
||||
for idx, message in enumerate(messages):
|
||||
msg_id = message.pop(self.FIELD_ID, None)
|
||||
if msg_id is not None and msg_id not in id_to_index:
|
||||
id_to_index[msg_id] = idx
|
||||
|
||||
# loop through message in the *same* order is the in-context message IDs
|
||||
data[self.FIELD_IN_CONTEXT_INDICES] = []
|
||||
for i, message in enumerate(data.get(self.FIELD_MESSAGES, [])):
|
||||
# if id matches in-context message ID, add to `messages`
|
||||
if message[self.FIELD_ID] in message_ids:
|
||||
data[self.FIELD_IN_CONTEXT_INDICES].append(i)
|
||||
messages.append(message)
|
||||
# Build in-context indices in the same order as the original message_ids
|
||||
in_context_indices = [id_to_index[msg_id] for msg_id in original_message_ids if msg_id in id_to_index]
|
||||
|
||||
# remove ids
|
||||
for message in messages:
|
||||
message.pop(self.FIELD_ID, None) # Remove the id field
|
||||
data[self.FIELD_IN_CONTEXT_INDICES] = in_context_indices
|
||||
data[self.FIELD_MESSAGES] = messages
|
||||
|
||||
return data
|
||||
|
||||
@@ -337,9 +337,7 @@ def test_multi_agent_broadcast_client(client: Letta, weather_tool):
|
||||
}
|
||||
],
|
||||
)
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
for message in response.messages:
|
||||
print(message)
|
||||
|
||||
|
||||
@@ -604,3 +604,41 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
agent_id=copied_agent_id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
|
||||
)
|
||||
|
||||
|
||||
# TODO: Add this back
|
||||
# @pytest.mark.parametrize("test_af_filename", ["deep_research_agent.af"])
|
||||
# def test_agent_file_upload_flow(fastapi_client, server, default_user, other_user, test_af_filename):
|
||||
# """
|
||||
# Test the full E2E serialization and deserialization flow using FastAPI endpoints.
|
||||
# """
|
||||
# file_path = Path(__file__).parent / "test_agent_files" / test_af_filename
|
||||
# with open(file_path, "r") as f:
|
||||
# data = json.load(f)
|
||||
#
|
||||
# # Ensure response matches expected schema
|
||||
# agent_schema = AgentSchema.model_validate(data) # Validate as Pydantic model
|
||||
# agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON
|
||||
#
|
||||
# import ipdb;ipdb.set_trace()
|
||||
#
|
||||
# # Step 2: Upload the serialized agent as a copy
|
||||
# agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
|
||||
# files = {"file": ("agent.json", agent_bytes, "application/json")}
|
||||
# upload_response = fastapi_client.post(
|
||||
# "/v1/agents/import",
|
||||
# headers={"user_id": other_user.id},
|
||||
# params={"append_copy_suffix": True, "override_existing_tools": False, "project_id": None},
|
||||
# files=files,
|
||||
# )
|
||||
# assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
|
||||
#
|
||||
# copied_agent = upload_response.json()
|
||||
# copied_agent_id = copied_agent["id"]
|
||||
#
|
||||
# # Step 3: Ensure copied agent receives messages correctly
|
||||
# server.send_messages(
|
||||
# actor=other_user,
|
||||
# agent_id=copied_agent_id,
|
||||
# messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user