diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index bab45f9c..c0f7a39b 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -110,7 +110,9 @@ class AnthropicClient(LLMClientBase): ] # Move 'system' to the top level - # assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}" + if data["messages"][0]["role"] != "system": + raise RuntimeError(f"First message is not a system message, instead has role {data['messages'][0]['role']}") + data["system"] = data["messages"][0]["content"] data["messages"] = data["messages"][1:] diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index cab01dce..66cbf244 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -42,7 +42,6 @@ class LLMClientBase: Otherwise returns a ChatCompletionResponse. """ request_data = self.build_request_data(messages, tools, tool_call) - response_data = {} try: log_event(name="llm_request_sent", attributes=request_data) diff --git a/letta/serialize_schemas/marshmallow_agent.py b/letta/serialize_schemas/marshmallow_agent.py index f38f1921..a6fb330f 100644 --- a/letta/serialize_schemas/marshmallow_agent.py +++ b/letta/serialize_schemas/marshmallow_agent.py @@ -59,29 +59,26 @@ class MarshmallowAgentSchema(BaseSchema): """ - Removes `message_ids` - Adds versioning - - Marks messages as in-context + - Marks messages as in-context, preserving the order of the original `message_ids` - Removes individual message `id` fields """ data = super().sanitize_ids(data, **kwargs) data[self.FIELD_VERSION] = letta.__version__ - message_ids = list(data.pop(self.FIELD_MESSAGE_IDS, [])) + original_message_ids = data.pop(self.FIELD_MESSAGE_IDS, []) + messages = data.get(self.FIELD_MESSAGES, []) - # NOTE: currently we don't support out-of-context messages since it has a bunch of system message spams - # TODO: support out-of-context messages - 
messages = [] + # Build a mapping from message id to its first occurrence index and remove the id in one pass + id_to_index = {} + for idx, message in enumerate(messages): + msg_id = message.pop(self.FIELD_ID, None) + if msg_id is not None and msg_id not in id_to_index: + id_to_index[msg_id] = idx - # loop through message in the *same* order is the in-context message IDs - data[self.FIELD_IN_CONTEXT_INDICES] = [] - for i, message in enumerate(data.get(self.FIELD_MESSAGES, [])): - # if id matches in-context message ID, add to `messages` - if message[self.FIELD_ID] in message_ids: - data[self.FIELD_IN_CONTEXT_INDICES].append(i) - messages.append(message) + # Build in-context indices in the same order as the original message_ids + in_context_indices = [id_to_index[msg_id] for msg_id in original_message_ids if msg_id in id_to_index] - # remove ids - for message in messages: - message.pop(self.FIELD_ID, None) # Remove the id field + data[self.FIELD_IN_CONTEXT_INDICES] = in_context_indices data[self.FIELD_MESSAGES] = messages return data diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py index cbc5ab74..23489895 100644 --- a/tests/integration_test_experimental.py +++ b/tests/integration_test_experimental.py @@ -337,9 +337,7 @@ def test_multi_agent_broadcast_client(client: Letta, weather_tool): } ], ) - import ipdb - ipdb.set_trace() for message in response.messages: print(message) diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index d73bffb1..ef455ee9 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -604,3 +604,41 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent agent_id=copied_agent_id, messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")], ) + + +# TODO: Add this back +# @pytest.mark.parametrize("test_af_filename", ["deep_research_agent.af"]) +# def test_agent_file_upload_flow(fastapi_client, 
server, default_user, other_user, test_af_filename): +# """ +# Test the full E2E serialization and deserialization flow using FastAPI endpoints. +# """ +# file_path = Path(__file__).parent / "test_agent_files" / test_af_filename +# with open(file_path, "r") as f: +# data = json.load(f) +# +# # Ensure response matches expected schema +# agent_schema = AgentSchema.model_validate(data) # Validate as Pydantic model +# agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON +# +# import ipdb;ipdb.set_trace() +# +# # Step 2: Upload the serialized agent as a copy +# agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8")) +# files = {"file": ("agent.json", agent_bytes, "application/json")} +# upload_response = fastapi_client.post( +# "/v1/agents/import", +# headers={"user_id": other_user.id}, +# params={"append_copy_suffix": True, "override_existing_tools": False, "project_id": None}, +# files=files, +# ) +# assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}" +# +# copied_agent = upload_response.json() +# copied_agent_id = copied_agent["id"] +# +# # Step 3: Ensure copied agent receives messages correctly +# server.send_messages( +# actor=other_user, +# agent_id=copied_agent_id, +# messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")], +# )