fix(core): stabilize system prompt refresh and expand git-memory coverage (#9438)

* fix(core): stabilize system prompt refresh and expand git-memory coverage

Only rebuild system prompts on explicit refresh paths so normal turns preserve prefix-cache stability, including git/custom prompt layouts. Add integration coverage for memory filesystem tree structure and recompile/reset system-message updates via message-id retrieval.

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(core): recompile system prompt around compaction and stabilize source tests

Force system prompt refresh before/after compaction in LettaAgentV3 so repaired system+memory state is used and persisted across subsequent turns. Update source-system prompt tests to explicitly recompile before raw preview assertions instead of assuming automatic rebuild timing.

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

---------

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
Sarah Wooders
2026-02-11 16:21:39 -08:00
committed by Caren Thomas
parent 5b7dd15905
commit d7793a4474
8 changed files with 105 additions and 107 deletions

View File

@@ -123,35 +123,17 @@ class BaseAgent(ABC):
curr_system_message = in_context_messages[0] curr_system_message = in_context_messages[0]
curr_system_message_text = curr_system_message.content[0].text curr_system_message_text = curr_system_message.content[0].text
# extract the dynamic section that includes memory blocks, tool rules, and directories # generate memory string with current state for comparison
# this avoids timestamp comparison issues
# TODO: This is a separate position-based parser for the same system message format
# parsed by ContextWindowCalculator.extract_system_components(). Consider unifying
# to avoid divergence. See PR #9398 for context.
def extract_dynamic_section(text):
start_marker = "</base_instructions>"
end_marker = "<memory_metadata>"
start_idx = text.find(start_marker)
end_idx = text.find(end_marker)
if start_idx != -1 and end_idx != -1:
return text[start_idx:end_idx]
return text # fallback to full text if markers not found
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
# generate just the memory string with current state for comparison
curr_memory_str = agent_state.memory.compile( curr_memory_str = agent_state.memory.compile(
tool_usage_rules=tool_constraint_block, tool_usage_rules=tool_constraint_block,
sources=agent_state.sources, sources=agent_state.sources,
max_files_open=agent_state.max_files_open, max_files_open=agent_state.max_files_open,
llm_config=agent_state.llm_config, llm_config=agent_state.llm_config,
) )
new_dynamic_section = extract_dynamic_section(curr_memory_str)
# compare just the dynamic sections (memory blocks, tool rules, directories) system_prompt_changed = agent_state.system not in curr_system_message_text
if curr_dynamic_section == new_dynamic_section: memory_changed = curr_memory_str not in curr_system_message_text
if (not system_prompt_changed) and (not memory_changed):
logger.debug( logger.debug(
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
) )

View File

@@ -704,19 +704,18 @@ class LettaAgentV2(BaseAgentV2):
Returns: Returns:
Refreshed in-context messages. Refreshed in-context messages.
""" """
# Always attempt to rebuild the system prompt if the memory section changed. # Only rebuild when explicitly forced (e.g., after compaction).
# This method is careful to skip rebuilds when the memory section is unchanged. # Normal turns should not trigger system prompt recompilation.
if force_system_prompt_refresh:
try: try:
in_context_messages = await self._rebuild_memory( in_context_messages = await self._rebuild_memory(
in_context_messages, in_context_messages,
num_messages=None, num_messages=None,
num_archival_memories=None, num_archival_memories=None,
force=True,
) )
except Exception as e: except Exception as e:
# If callers requested a forced refresh, surface the error.
if force_system_prompt_refresh:
raise raise
self.logger.warning(f"Failed to refresh system prompt/memory: {e}")
# Always scrub inner thoughts regardless of system prompt refresh # Always scrub inner thoughts regardless of system prompt refresh
in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, self.agent_state.llm_config) in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, self.agent_state.llm_config)
@@ -728,6 +727,7 @@ class LettaAgentV2(BaseAgentV2):
in_context_messages: list[Message], in_context_messages: list[Message],
num_messages: int | None, num_messages: int | None,
num_archival_memories: int | None, num_archival_memories: int | None,
force: bool = False,
): ):
agent_state = await self.agent_manager.refresh_memory_async(agent_state=self.agent_state, actor=self.actor) agent_state = await self.agent_manager.refresh_memory_async(agent_state=self.agent_state, actor=self.actor)
@@ -748,51 +748,24 @@ class LettaAgentV2(BaseAgentV2):
else: else:
archive_tags = None archive_tags = None
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0] curr_system_message = in_context_messages[0]
curr_system_message_text = curr_system_message.content[0].text curr_system_message_text = curr_system_message.content[0].text
# Extract the memory section that includes <memory_blocks>, tool rules, and directories.
# This avoids timestamp comparison issues in <memory_metadata>, which is dynamic.
def extract_memory_section(text: str) -> str:
# Primary pattern: everything from <memory_blocks> up to <memory_metadata>
mem_start = text.find("<memory_blocks>")
meta_start = text.find("<memory_metadata>")
if mem_start != -1:
if meta_start != -1 and meta_start > mem_start:
return text[mem_start:meta_start]
return text[mem_start:]
# Fallback pattern used in some legacy prompts: between </base_instructions> and <memory_metadata>
base_end = text.find("</base_instructions>")
if base_end != -1:
if meta_start != -1 and meta_start > base_end:
return text[base_end + len("</base_instructions>") : meta_start]
return text[base_end + len("</base_instructions>") :]
# Last resort: return full text
return text
curr_memory_section = extract_memory_section(curr_system_message_text)
# refresh files # refresh files
agent_state = await self.agent_manager.refresh_file_blocks(agent_state=agent_state, actor=self.actor) agent_state = await self.agent_manager.refresh_file_blocks(agent_state=agent_state, actor=self.actor)
# generate just the memory string with current state for comparison # generate memory string with current state
curr_memory_str = agent_state.memory.compile( curr_memory_str = agent_state.memory.compile(
tool_usage_rules=tool_constraint_block, tool_usage_rules=tool_constraint_block,
sources=agent_state.sources, sources=agent_state.sources,
max_files_open=agent_state.max_files_open, max_files_open=agent_state.max_files_open,
llm_config=agent_state.llm_config, llm_config=agent_state.llm_config,
) )
new_memory_section = extract_memory_section(curr_memory_str)
# Compare just the memory sections (memory blocks, tool rules, directories). # Skip rebuild unless explicitly forced and unless system/memory content actually changed.
# Also ensure the configured system prompt is still present; if the system prompt
# changed (e.g. via UpdateAgent(system=...)), we must rebuild.
system_prompt_changed = agent_state.system not in curr_system_message_text system_prompt_changed = agent_state.system not in curr_system_message_text
memory_changed = curr_memory_str not in curr_system_message_text
if (not system_prompt_changed) and (curr_memory_section.strip() == new_memory_section.strip()): if (not force) and (not system_prompt_changed) and (not memory_changed):
self.logger.debug( self.logger.debug(
f"Memory, sources, and system prompt haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" f"Memory, sources, and system prompt haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
) )

View File

@@ -977,6 +977,10 @@ class LettaAgentV3(LettaAgentV2):
trigger="context_window_exceeded", trigger="context_window_exceeded",
) )
# Ensure system prompt is recompiled before summarization so compaction
# operates on the latest system+memory state (including recent repairs).
messages = await self._refresh_messages(messages, force_system_prompt_refresh=True)
summary_message, messages, summary_text = await self.compact( summary_message, messages, summary_text = await self.compact(
messages, messages,
trigger_threshold=self.agent_state.llm_config.context_window, trigger_threshold=self.agent_state.llm_config.context_window,
@@ -987,6 +991,15 @@ class LettaAgentV3(LettaAgentV2):
context_tokens_before=context_tokens_before, context_tokens_before=context_tokens_before,
messages_count_before=messages_count_before, messages_count_before=messages_count_before,
) )
# Recompile the persisted system prompt after compaction so subsequent
# turns load the repaired system+memory state from message_ids[0].
await self.agent_manager.rebuild_system_prompt_async(
agent_id=self.agent_state.id,
actor=self.actor,
force=True,
update_timestamp=True,
)
# Force system prompt rebuild after compaction to update memory blocks and timestamps # Force system prompt rebuild after compaction to update memory blocks and timestamps
messages = await self._refresh_messages(messages, force_system_prompt_refresh=True) messages = await self._refresh_messages(messages, force_system_prompt_refresh=True)
self.logger.info("Summarization succeeded, continuing to retry LLM request") self.logger.info("Summarization succeeded, continuing to retry LLM request")
@@ -1041,12 +1054,16 @@ class LettaAgentV3(LettaAgentV2):
# Track turn data for multi-turn RL training (SGLang native mode) # Track turn data for multi-turn RL training (SGLang native mode)
if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids: if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids:
self.turns.append(TurnTokenData( self.turns.append(
TurnTokenData(
role="assistant", role="assistant",
output_ids=llm_adapter.output_ids, output_ids=llm_adapter.output_ids,
output_token_logprobs=llm_adapter.output_token_logprobs, output_token_logprobs=llm_adapter.output_token_logprobs,
content=llm_adapter.chat_completions_response.choices[0].message.content if llm_adapter.chat_completions_response else None, content=llm_adapter.chat_completions_response.choices[0].message.content
)) if llm_adapter.chat_completions_response
else None,
)
)
# Handle the AI response with the extracted data (supports multiple tool calls) # Handle the AI response with the extracted data (supports multiple tool calls)
# Gather tool calls - check for multi-call API first, then fall back to single # Gather tool calls - check for multi-call API first, then fall back to single
@@ -1105,7 +1122,7 @@ class LettaAgentV3(LettaAgentV2):
# Aggregate all tool returns into content (func_response is the actual content) # Aggregate all tool returns into content (func_response is the actual content)
parts = [] parts = []
for tr in msg.tool_returns: for tr in msg.tool_returns:
if hasattr(tr, 'func_response') and tr.func_response: if hasattr(tr, "func_response") and tr.func_response:
if isinstance(tr.func_response, str): if isinstance(tr.func_response, str):
parts.append(tr.func_response) parts.append(tr.func_response)
else: else:
@@ -1116,11 +1133,13 @@ class LettaAgentV3(LettaAgentV2):
if hasattr(msg, "name"): if hasattr(msg, "name"):
tool_name = msg.name tool_name = msg.name
if tool_content: if tool_content:
self.turns.append(TurnTokenData( self.turns.append(
TurnTokenData(
role="tool", role="tool",
content=tool_content, content=tool_content,
tool_name=tool_name, tool_name=tool_name,
)) )
)
# step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics # step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics
# persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state # persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state
@@ -1177,6 +1196,10 @@ class LettaAgentV3(LettaAgentV2):
) )
try: try:
# Ensure system prompt is recompiled before summarization so compaction
# operates on the latest system+memory state (including recent repairs).
messages = await self._refresh_messages(messages, force_system_prompt_refresh=True)
summary_message, messages, summary_text = await self.compact( summary_message, messages, summary_text = await self.compact(
messages, messages,
trigger_threshold=self.agent_state.llm_config.context_window, trigger_threshold=self.agent_state.llm_config.context_window,
@@ -1187,6 +1210,15 @@ class LettaAgentV3(LettaAgentV2):
context_tokens_before=context_tokens_before, context_tokens_before=context_tokens_before,
messages_count_before=messages_count_before, messages_count_before=messages_count_before,
) )
# Recompile the persisted system prompt after compaction so subsequent
# turns load the repaired system+memory state from message_ids[0].
await self.agent_manager.rebuild_system_prompt_async(
agent_id=self.agent_state.id,
actor=self.actor,
force=True,
update_timestamp=True,
)
# Force system prompt rebuild after compaction to update memory blocks and timestamps # Force system prompt rebuild after compaction to update memory blocks and timestamps
messages = await self._refresh_messages(messages, force_system_prompt_refresh=True) messages = await self._refresh_messages(messages, force_system_prompt_refresh=True)
# TODO: persist + return the summary message # TODO: persist + return the summary message

View File

@@ -243,7 +243,8 @@ class Memory(BaseModel, validate_assignment=True):
"""Render a filesystem tree view of all memory blocks. """Render a filesystem tree view of all memory blocks.
Only rendered for git-memory-enabled agents. Uses box-drawing Only rendered for git-memory-enabled agents. Uses box-drawing
characters (├──, └──, │) like the Unix `tree` command. characters (├──, └──, │) like the Unix `tree` command, while keeping
deterministic ordering (directories first, then files, alphabetically).
""" """
if not self.blocks: if not self.blocks:
return return

View File

@@ -2306,6 +2306,7 @@ async def reset_messages(
actor=actor, actor=actor,
add_default_initial_messages=request.add_default_initial_messages, add_default_initial_messages=request.add_default_initial_messages,
needs_agent_state=not is_1_0_sdk_version(headers), needs_agent_state=not is_1_0_sdk_version(headers),
rebuild_system_prompt=True,
) )

View File

@@ -551,19 +551,16 @@ class BlockManager:
result = await session.execute(query) result = await session.execute(query)
blocks = result.scalars().all() blocks = result.scalars().all()
# Convert to Pydantic models # Convert to Pydantic models and preserve caller-provided ID order
pydantic_blocks = [block.to_pydantic() for block in blocks] pydantic_blocks = [block.to_pydantic() for block in blocks]
blocks_by_id = {b.id: b for b in pydantic_blocks}
ordered_blocks = [blocks_by_id.get(block_id) for block_id in block_ids]
# For backward compatibility, add None for missing blocks # For backward compatibility, include None for missing blocks
if len(pydantic_blocks) < len(block_ids): if len(pydantic_blocks) < len(block_ids):
{block.id for block in pydantic_blocks} return ordered_blocks
result_blocks = []
for block_id in block_ids:
block = next((b for b in pydantic_blocks if b.id == block_id), None)
result_blocks.append(block)
return result_blocks
return pydantic_blocks return ordered_blocks
@enforce_types @enforce_types
@trace_method @trace_method

View File

@@ -333,7 +333,7 @@ Some directory content
) )
result = ContextWindowCalculator.extract_system_components(system_message) result = ContextWindowCalculator.extract_system_components(system_message)
# memory_filesystem should be the tree view only # memory_filesystem should preserve tree connectors with deterministic ordering
assert result["memory_filesystem"] is not None assert result["memory_filesystem"] is not None
assert "\u251c\u2500\u2500 system/" in result["memory_filesystem"] assert "\u251c\u2500\u2500 system/" in result["memory_filesystem"]

View File

@@ -27,8 +27,20 @@ from tests.utils import wait_for_server
SERVER_PORT = 8283 SERVER_PORT = 8283
def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str: def recompile_agent_system_prompt(client: LettaSDKClient, agent_id: str) -> None:
"""Force a system prompt recompilation for deterministic raw-preview assertions."""
client.post(
f"/v1/agents/{agent_id}/recompile",
cast_to=str,
body={},
)
def get_raw_system_message(client: LettaSDKClient, agent_id: str, recompile: bool = False) -> str:
"""Helper function to get the raw system message from an agent's preview payload.""" """Helper function to get the raw system message from an agent's preview payload."""
if recompile:
recompile_agent_system_prompt(client, agent_id)
raw_payload = client.post( raw_payload = client.post(
f"/v1/agents/{agent_id}/messages/preview-raw-payload", f"/v1/agents/{agent_id}/messages/preview-raw-payload",
cast_to=dict[str, Any], cast_to=dict[str, Any],
@@ -215,7 +227,7 @@ def test_file_upload_creates_source_blocks_correctly(
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks) assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# verify raw system message contains source information # verify raw system message contains source information
raw_system_message = get_raw_system_message(client, agent_state.id) raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True)
assert "test_source" in raw_system_message assert "test_source" in raw_system_message
assert "<directories>" in raw_system_message assert "<directories>" in raw_system_message
# verify file-specific details in raw system message # verify file-specific details in raw system message
@@ -234,7 +246,7 @@ def test_file_upload_creates_source_blocks_correctly(
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# verify raw system message no longer contains source information # verify raw system message no longer contains source information
raw_system_message_after_removal = get_raw_system_message(client, agent_state.id) raw_system_message_after_removal = get_raw_system_message(client, agent_state.id, recompile=True)
# this should be in, because we didn't delete the source # this should be in, because we didn't delete the source
assert "test_source" in raw_system_message_after_removal assert "test_source" in raw_system_message_after_removal
assert "<directories>" in raw_system_message_after_removal assert "<directories>" in raw_system_message_after_removal
@@ -266,7 +278,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(
# Attach after uploading the file # Attach after uploading the file
client.agents.folders.attach(folder_id=source.id, agent_id=agent_state.id) client.agents.folders.attach(folder_id=source.id, agent_id=agent_state.id)
raw_system_message = get_raw_system_message(client, agent_state.id) raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True)
# Assert that the expected chunk is in the raw system message # Assert that the expected chunk is in the raw system message
expected_chunk = """<directories> expected_chunk = """<directories>
@@ -307,7 +319,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(
assert not any("test" in b.value for b in blocks) assert not any("test" in b.value for b in blocks)
# Verify no traces of the prompt exist in the raw system message after detaching # Verify no traces of the prompt exist in the raw system message after detaching
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) raw_system_message_after_detach = get_raw_system_message(client, agent_state.id, recompile=True)
assert expected_chunk not in raw_system_message_after_detach assert expected_chunk not in raw_system_message_after_detach
assert "test_source" not in raw_system_message_after_detach assert "test_source" not in raw_system_message_after_detach
assert "<directories>" not in raw_system_message_after_detach assert "<directories>" not in raw_system_message_after_detach
@@ -321,7 +333,7 @@ def test_delete_source_removes_source_blocks_correctly(
assert len(list(client.folders.list())) == 1 assert len(list(client.folders.list())) == 1
client.agents.folders.attach(folder_id=source.id, agent_id=agent_state.id) client.agents.folders.attach(folder_id=source.id, agent_id=agent_state.id)
raw_system_message = get_raw_system_message(client, agent_state.id) raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True)
assert "test_source" in raw_system_message assert "test_source" in raw_system_message
assert "<directories>" in raw_system_message assert "<directories>" in raw_system_message
@@ -330,7 +342,7 @@ def test_delete_source_removes_source_blocks_correctly(
# Upload the files # Upload the files
upload_file_and_wait(client, source.id, file_path) upload_file_and_wait(client, source.id, file_path)
raw_system_message = get_raw_system_message(client, agent_state.id) raw_system_message = get_raw_system_message(client, agent_state.id, recompile=True)
# Assert that the expected chunk is in the raw system message # Assert that the expected chunk is in the raw system message
expected_chunk = """<directories> expected_chunk = """<directories>
<file_limits> <file_limits>
@@ -361,7 +373,7 @@ def test_delete_source_removes_source_blocks_correctly(
# Remove file from source # Remove file from source
client.folders.delete(folder_id=source.id) client.folders.delete(folder_id=source.id)
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) raw_system_message_after_detach = get_raw_system_message(client, agent_state.id, recompile=True)
assert expected_chunk not in raw_system_message_after_detach assert expected_chunk not in raw_system_message_after_detach
assert "test_source" not in raw_system_message_after_detach assert "test_source" not in raw_system_message_after_detach
assert "<directories>" not in raw_system_message_after_detach assert "<directories>" not in raw_system_message_after_detach
@@ -1112,7 +1124,7 @@ def test_agent_open_file(disable_pinecone, disable_turbopuffer, client: LettaSDK
closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata["id"]) closed_files = client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata["id"])
assert len(closed_files) == 0 assert len(closed_files) == 0
system = get_raw_system_message(client, agent_state.id) system = get_raw_system_message(client, agent_state.id, recompile=True)
assert '<file status="open" name="test_source/test.txt">' in system assert '<file status="open" name="test_source/test.txt">' in system
assert "[Viewing file start (out of 1 lines)]" in system assert "[Viewing file start (out of 1 lines)]" in system
@@ -1137,7 +1149,7 @@ def test_agent_close_file(disable_pinecone, disable_turbopuffer, client: LettaSD
# Test close_file function # Test close_file function
client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata["id"]) client.agents.files.close(agent_id=agent_state.id, file_id=file_metadata["id"])
system = get_raw_system_message(client, agent_state.id) system = get_raw_system_message(client, agent_state.id, recompile=True)
assert '<file status="closed" name="test_source/test.txt">' in system assert '<file status="closed" name="test_source/test.txt">' in system
@@ -1160,7 +1172,7 @@ def test_agent_close_all_open_files(disable_pinecone, disable_turbopuffer, clien
# Open each file # Open each file
client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata["id"]) client.agents.files.open(agent_id=agent_state.id, file_id=file_metadata["id"])
system = get_raw_system_message(client, agent_state.id) system = get_raw_system_message(client, agent_state.id, recompile=True)
assert '<file status="open"' in system assert '<file status="open"' in system
# Test close_all_open_files function # Test close_all_open_files function
@@ -1170,7 +1182,7 @@ def test_agent_close_all_open_files(disable_pinecone, disable_turbopuffer, clien
assert isinstance(result, list), f"Expected list, got {type(result)}" assert isinstance(result, list), f"Expected list, got {type(result)}"
assert all(isinstance(item, str) for item in result), "All items in result should be strings" assert all(isinstance(item, str) for item in result), "All items in result should be strings"
system = get_raw_system_message(client, agent_state.id) system = get_raw_system_message(client, agent_state.id, recompile=True)
assert '<file status="open"' not in system assert '<file status="open"' not in system