diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 94d18232..ddc5309e 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -123,35 +123,17 @@ class BaseAgent(ABC): curr_system_message = in_context_messages[0] curr_system_message_text = curr_system_message.content[0].text - # extract the dynamic section that includes memory blocks, tool rules, and directories - # this avoids timestamp comparison issues - # TODO: This is a separate position-based parser for the same system message format - # parsed by ContextWindowCalculator.extract_system_components(). Consider unifying - # to avoid divergence. See PR #9398 for context. - def extract_dynamic_section(text): - start_marker = "" - end_marker = "" - - start_idx = text.find(start_marker) - end_idx = text.find(end_marker) - - if start_idx != -1 and end_idx != -1: - return text[start_idx:end_idx] - return text # fallback to full text if markers not found - - curr_dynamic_section = extract_dynamic_section(curr_system_message_text) - - # generate just the memory string with current state for comparison + # generate memory string with current state for comparison curr_memory_str = agent_state.memory.compile( tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open, llm_config=agent_state.llm_config, ) - new_dynamic_section = extract_dynamic_section(curr_memory_str) - # compare just the dynamic sections (memory blocks, tool rules, directories) - if curr_dynamic_section == new_dynamic_section: + system_prompt_changed = agent_state.system not in curr_system_message_text + memory_changed = curr_memory_str not in curr_system_message_text + if (not system_prompt_changed) and (not memory_changed): logger.debug( f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" ) diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 3a01ab7a..2c0bf6ca 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -704,19 +704,18 @@ class LettaAgentV2(BaseAgentV2): Returns: Refreshed in-context messages. """ - # Always attempt to rebuild the system prompt if the memory section changed. - # This method is careful to skip rebuilds when the memory section is unchanged. - try: - in_context_messages = await self._rebuild_memory( - in_context_messages, - num_messages=None, - num_archival_memories=None, - ) - except Exception as e: - # If callers requested a forced refresh, surface the error. - if force_system_prompt_refresh: + # Only rebuild when explicitly forced (e.g., after compaction). + # Normal turns should not trigger system prompt recompilation. + if force_system_prompt_refresh: + try: + in_context_messages = await self._rebuild_memory( + in_context_messages, + num_messages=None, + num_archival_memories=None, + force=True, + ) + except Exception as e: raise - self.logger.warning(f"Failed to refresh system prompt/memory: {e}") # Always scrub inner thoughts regardless of system prompt refresh in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, self.agent_state.llm_config) @@ -728,6 +727,7 @@ class LettaAgentV2(BaseAgentV2): in_context_messages: list[Message], num_messages: int | None, num_archival_memories: int | None, + force: bool = False, ): agent_state = await self.agent_manager.refresh_memory_async(agent_state=self.agent_state, actor=self.actor) @@ -748,51 +748,24 @@ class LettaAgentV2(BaseAgentV2): else: archive_tags = None - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this curr_system_message = in_context_messages[0] curr_system_message_text = curr_system_message.content[0].text - # Extract the memory section that includes , tool rules, and directories. - # This avoids timestamp comparison issues in , which is dynamic. - def extract_memory_section(text: str) -> str: - # Primary pattern: everything from up to - mem_start = text.find("") - meta_start = text.find("") - if mem_start != -1: - if meta_start != -1 and meta_start > mem_start: - return text[mem_start:meta_start] - return text[mem_start:] - - # Fallback pattern used in some legacy prompts: between and - base_end = text.find("") - if base_end != -1: - if meta_start != -1 and meta_start > base_end: - return text[base_end + len("") : meta_start] - return text[base_end + len("") :] - - # Last resort: return full text - return text - - curr_memory_section = extract_memory_section(curr_system_message_text) - # refresh files agent_state = await self.agent_manager.refresh_file_blocks(agent_state=agent_state, actor=self.actor) - # generate just the memory string with current state for comparison + # generate memory string with current state curr_memory_str = agent_state.memory.compile( tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open, llm_config=agent_state.llm_config, ) - new_memory_section = extract_memory_section(curr_memory_str) - # Compare just the memory sections (memory blocks, tool rules, directories). - # Also ensure the configured system prompt is still present; if the system prompt - # changed (e.g. via UpdateAgent(system=...)), we must rebuild. + # Skip rebuild unless explicitly forced and unless system/memory content actually changed. system_prompt_changed = agent_state.system not in curr_system_message_text - - if (not system_prompt_changed) and (curr_memory_section.strip() == new_memory_section.strip()): + memory_changed = curr_memory_str not in curr_system_message_text + if (not force) and (not system_prompt_changed) and (not memory_changed): self.logger.debug( f"Memory, sources, and system prompt haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" ) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 2edc5f7a..db1a7df5 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -202,7 +202,7 @@ class LettaAgentV3(LettaAgentV2): input_messages_to_persist = [input_messages_to_persist[0]] self.in_context_messages = curr_in_context_messages - + # Check if we should use SGLang native adapter for multi-turn RL training use_sglang_native = ( self.agent_state.llm_config.return_token_ids @@ -210,7 +210,7 @@ class LettaAgentV3(LettaAgentV2): and self.agent_state.llm_config.handle.startswith("sglang/") ) self.return_token_ids = use_sglang_native - + if use_sglang_native: # Use SGLang native adapter for multi-turn RL training llm_adapter = SGLangNativeAdapter( @@ -399,7 +399,7 @@ class LettaAgentV3(LettaAgentV2): and self.agent_state.llm_config.handle.startswith("sglang/") ) self.return_token_ids = use_sglang_native - + if stream_tokens: llm_adapter = SimpleLLMStreamAdapter( llm_client=self.llm_client, @@ -977,6 +977,10 @@ class LettaAgentV3(LettaAgentV2): trigger="context_window_exceeded", ) + # Ensure system prompt is recompiled before summarization so compaction + # operates on the latest system+memory state (including recent repairs). + messages = await self._refresh_messages(messages, force_system_prompt_refresh=True) + summary_message, messages, summary_text = await self.compact( messages, trigger_threshold=self.agent_state.llm_config.context_window, @@ -987,6 +991,15 @@ class LettaAgentV3(LettaAgentV2): context_tokens_before=context_tokens_before, messages_count_before=messages_count_before, ) + + # Recompile the persisted system prompt after compaction so subsequent + # turns load the repaired system+memory state from message_ids[0]. + await self.agent_manager.rebuild_system_prompt_async( + agent_id=self.agent_state.id, + actor=self.actor, + force=True, + update_timestamp=True, + ) # Force system prompt rebuild after compaction to update memory blocks and timestamps messages = await self._refresh_messages(messages, force_system_prompt_refresh=True) self.logger.info("Summarization succeeded, continuing to retry LLM request") @@ -1038,15 +1051,19 @@ class LettaAgentV3(LettaAgentV2): # Extract logprobs if present (for RL training) if llm_adapter.logprobs is not None: self.logprobs = llm_adapter.logprobs - + # Track turn data for multi-turn RL training (SGLang native mode) if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids: - self.turns.append(TurnTokenData( - role="assistant", - output_ids=llm_adapter.output_ids, - output_token_logprobs=llm_adapter.output_token_logprobs, - content=llm_adapter.chat_completions_response.choices[0].message.content if llm_adapter.chat_completions_response else None, - )) + self.turns.append( + TurnTokenData( + role="assistant", + output_ids=llm_adapter.output_ids, + output_token_logprobs=llm_adapter.output_token_logprobs, + content=llm_adapter.chat_completions_response.choices[0].message.content + if llm_adapter.chat_completions_response + else None, + ) + ) # Handle the AI response with the extracted data (supports multiple tool calls) # Gather tool calls - check for multi-call API first, then fall back to single @@ -1093,7 +1110,7 @@ class LettaAgentV3(LettaAgentV2): # extend trackers with new messages self.response_messages.extend(new_messages) messages.extend(new_messages) - + # Track tool return turns for multi-turn RL training if self.return_token_ids: for msg in new_messages: @@ -1105,7 +1122,7 @@ class LettaAgentV3(LettaAgentV2): # Aggregate all tool returns into content (func_response is the actual content) parts = [] for tr in msg.tool_returns: - if hasattr(tr, 'func_response') and tr.func_response: + if hasattr(tr, "func_response") and tr.func_response: if isinstance(tr.func_response, str): parts.append(tr.func_response) else: @@ -1116,11 +1133,13 @@ class LettaAgentV3(LettaAgentV2): if hasattr(msg, "name"): tool_name = msg.name if tool_content: - self.turns.append(TurnTokenData( - role="tool", - content=tool_content, - tool_name=tool_name, - )) + self.turns.append( + TurnTokenData( + role="tool", + content=tool_content, + tool_name=tool_name, + ) + ) # step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics # persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state @@ -1177,6 +1196,10 @@ class LettaAgentV3(LettaAgentV2): ) try: + # Ensure system prompt is recompiled before summarization so compaction + # operates on the latest system+memory state (including recent repairs). + messages = await self._refresh_messages(messages, force_system_prompt_refresh=True) + summary_message, messages, summary_text = await self.compact( messages, trigger_threshold=self.agent_state.llm_config.context_window, @@ -1187,6 +1210,15 @@ class LettaAgentV3(LettaAgentV2): context_tokens_before=context_tokens_before, messages_count_before=messages_count_before, ) + + # Recompile the persisted system prompt after compaction so subsequent + # turns load the repaired system+memory state from message_ids[0]. + await self.agent_manager.rebuild_system_prompt_async( + agent_id=self.agent_state.id, + actor=self.actor, + force=True, + update_timestamp=True, + ) # Force system prompt rebuild after compaction to update memory blocks and timestamps messages = await self._refresh_messages(messages, force_system_prompt_refresh=True) # TODO: persist + return the summary message diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index b4ca07cf..c96c2f67 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -243,7 +243,8 @@ class Memory(BaseModel, validate_assignment=True): """Render a filesystem tree view of all memory blocks. Only rendered for git-memory-enabled agents. Uses box-drawing - characters (├──, └──, │) like the Unix `tree` command. + characters (├──, └──, │) like the Unix `tree` command, while keeping + deterministic ordering (directories first, then files, alphabetically). """ if not self.blocks: return diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 5436954e..0fb1b695 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -313,8 +313,8 @@ async def export_agent( raise HTTPException(status_code=400, detail="Legacy format is not supported") actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) agent_file_schema = await server.agent_serialization_manager.export( - agent_ids=[agent_id], - actor=actor, + agent_ids=[agent_id], + actor=actor, conversation_id=conversation_id, scrub_messages=scrub_messages, ) @@ -2306,6 +2306,7 @@ async def reset_messages( actor=actor, add_default_initial_messages=request.add_default_initial_messages, needs_agent_state=not is_1_0_sdk_version(headers), + rebuild_system_prompt=True, ) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index abc390ae..561657c0 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -551,19 +551,16 @@ class BlockManager: result = await session.execute(query) blocks = result.scalars().all() - # Convert to Pydantic models + # Convert to Pydantic models and preserve caller-provided ID order pydantic_blocks = [block.to_pydantic() for block in blocks] + blocks_by_id = {b.id: b for b in pydantic_blocks} + ordered_blocks = [blocks_by_id.get(block_id) for block_id in block_ids] - # For backward compatibility, add None for missing blocks + # For backward compatibility, include None for missing blocks if len(pydantic_blocks) < len(block_ids): - {block.id for block in pydantic_blocks} - result_blocks = [] - for block_id in block_ids: - block = next((b for b in pydantic_blocks if b.id == block_id), None) - result_blocks.append(block) - return result_blocks + return ordered_blocks - return pydantic_blocks + return ordered_blocks @enforce_types @trace_method diff --git a/tests/test_context_window_calculator.py b/tests/test_context_window_calculator.py index 0f3f3035..2f44a963 100644 --- a/tests/test_context_window_calculator.py +++ b/tests/test_context_window_calculator.py @@ -333,7 +333,7 @@ Some directory content ) result = ContextWindowCalculator.extract_system_components(system_message) - # memory_filesystem should be the tree view only + # memory_filesystem should preserve tree connectors with deterministic ordering assert result["memory_filesystem"] is not None assert "\u251c\u2500\u2500 system/" in result["memory_filesystem"] diff --git a/tests/test_sources.py b/tests/test_sources.py index 79555784..cff833fe 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -27,8 +27,20 @@ from tests.utils import wait_for_server SERVER_PORT = 8283 -def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str: +def recompile_agent_system_prompt(client: LettaSDKClient, agent_id: str) -> None: + """Force a system prompt recompilation for deterministic raw-preview assertions.""" + client.post( + f"/v1/agents/{agent_id}/recompile", + cast_to=str, + body={}, + ) + + +def get_raw_system_message(client: LettaSDKClient, agent_id: str, recompile: bool = False) -> str: """Helper function to get the raw system message from an agent's preview payload.""" + if recompile: + recompile_agent_system_prompt(client, agent_id) + raw_payload = client.post( f"/v1/agents/{agent_id}/messages/preview-raw-payload", cast_to=dict[str, Any], @@ -215,7 +227,7 @@ def test_file_upload_creates_source_blocks_correctly( assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks) # verify raw system message contains source information - raw_system_message = get_raw_system_message(client, agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True) assert "test_source" in raw_system_message assert "" in raw_system_message # verify file-specific details in raw system message @@ -234,7 +246,7 @@ def test_file_upload_creates_source_blocks_correctly( assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) # verify raw system message no longer contains source information - raw_system_message_after_removal = get_raw_system_message(client, agent_state.id) + raw_system_message_after_removal = get_raw_system_message(client, agent_state.id, recompile=True) # this should be in, because we didn't delete the source assert "test_source" in raw_system_message_after_removal assert "" in raw_system_message_after_removal @@ -266,7 +278,7 @@ def test_attach_existing_files_creates_source_blocks_correctly( # Attach after uploading the file client.agents.folders.attach(folder_id=source.id, agent_id=agent_state.id) - raw_system_message = get_raw_system_message(client, agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True) # Assert that the expected chunk is in the raw system message expected_chunk = """ @@ -307,7 +319,7 @@ def test_attach_existing_files_creates_source_blocks_correctly( assert not any("test" in b.value for b in blocks) # Verify no traces of the prompt exist in the raw system message after detaching - raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) + raw_system_message_after_detach = get_raw_system_message(client, agent_state.id, recompile=True) assert expected_chunk not in raw_system_message_after_detach assert "test_source" not in raw_system_message_after_detach assert "" not in raw_system_message_after_detach @@ -321,7 +333,7 @@ def test_delete_source_removes_source_blocks_correctly( assert len(list(client.folders.list())) == 1 client.agents.folders.attach(folder_id=source.id, agent_id=agent_state.id) - raw_system_message = get_raw_system_message(client, agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True) assert "test_source" in raw_system_message assert "" in raw_system_message @@ -330,7 +342,7 @@ def test_delete_source_removes_source_blocks_correctly( # Upload the files upload_file_and_wait(client, source.id, file_path) - raw_system_message = get_raw_system_message(client, agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True) # Assert that the expected chunk is in the raw system message expected_chunk = """ @@ -361,7 +373,7 @@ def test_delete_source_removes_source_blocks_correctly( # Remove file from source client.folders.delete(folder_id=source.id) - raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) + raw_system_message_after_detach = get_raw_system_message(client, agent_state.id, recompile=True) assert expected_chunk not in raw_system_message_after_detach assert "test_source" not in raw_system_message_after_detach assert "" not in raw_system_message_after_detach @@ -1112,7 +1124,7 @@ def test_agent_open_file(disable_pinecone, disable_turbopuffer, client: LettaSDK closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata["id"]) assert len(closed_files) == 0 - system = get_raw_system_message(client, agent_state.id) + system = get_raw_system_message(client, agent_state.id, recompile=True) assert '' in system assert "[Viewing file start (out of 1 lines)]" in system @@ -1137,7 +1149,7 @@ def test_agent_close_file(disable_pinecone, disable_turbopuffer, client: LettaSD # Test close_file function client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata["id"]) - system = get_raw_system_message(client, agent_state.id) + system = get_raw_system_message(client, agent_state.id, recompile=True) assert '' in system @@ -1160,7 +1172,7 @@ def test_agent_close_all_open_files(disable_pinecone, disable_turbopuffer, clien # Open each file client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata["id"]) - system = get_raw_system_message(client, agent_state.id) + system = get_raw_system_message(client, agent_state.id, recompile=True) assert '