From 2845c86f5f5a2f386f3e17fbae2f41c51da3da33 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 10 Jul 2025 11:29:36 -0700 Subject: [PATCH] fix: Fix context window compilation issues for files (#3272) --- letta/agents/base_agent.py | 31 ++++++-- letta/agents/letta_agent.py | 12 ++- letta/server/rest_api/routers/v1/agents.py | 1 + tests/test_sources.py | 91 +++++++++++++++++++++- 4 files changed, 125 insertions(+), 10 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index b8613201..a556ab73 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -96,7 +96,7 @@ class BaseAgent(ABC): """ try: # [DB Call] loading blocks (modifies: agent_state.memory.blocks) - await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor) + agent_state = await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor) tool_constraint_block = None if tool_rules_solver is not None: @@ -104,18 +104,37 @@ class BaseAgent(ABC): # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources) curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: + + # extract the dynamic section that includes memory blocks, tool rules, and directories + # this avoids timestamp comparison issues + def extract_dynamic_section(text): + start_marker = "" + end_marker = "" + + start_idx = text.find(start_marker) + end_idx = text.find(end_marker) + + if start_idx != -1 and end_idx != -1: + return text[start_idx:end_idx] + return text # fallback to full text if markers not found + + curr_dynamic_section = extract_dynamic_section(curr_system_message_text) + + # generate just the memory string with current state for comparison + curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources) + new_dynamic_section = extract_dynamic_section(curr_memory_str) + + # compare just the dynamic sections (memory blocks, tool rules, directories) + if curr_dynamic_section == new_dynamic_section: logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" ) return in_context_messages memory_edit_timestamp = get_utc_time() - # [DB Call] size of messages and archival memories - # todo: blocking for now + # size of messages and archival memories if num_messages is None: num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) if num_archival_memories is None: diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 40c1deb1..d4d60d9a 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -169,7 +169,9 @@ class LettaAgent(BaseAgent): ) -> Union[LettaResponse, dict]: # TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions agent_state = await self.agent_manager.get_agent_by_id_async( - agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor + agent_id=self.agent_id, + include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"], + actor=self.actor, ) result = await self._step( agent_state=agent_state, @@ -203,7 +205,9 @@ class LettaAgent(BaseAgent): include_return_message_types: list[MessageType] | None = None, ): agent_state = await self.agent_manager.get_agent_by_id_async( - agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor + agent_id=self.agent_id, + include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"], + actor=self.actor, ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async( input_messages, agent_state, self.message_manager, self.actor @@ -549,7 +553,9 @@ class LettaAgent(BaseAgent): 4. Processes the response """ agent_state = await self.agent_manager.get_agent_by_id_async( - agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor + agent_id=self.agent_id, + include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"], + actor=self.actor, ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async( input_messages, agent_state, self.message_manager, self.actor diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index adf0cd66..fee2de03 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1208,6 +1208,7 @@ async def preview_raw_payload( ), ) + # TODO: Support step_streaming return await agent_loop.step( input_messages=request.messages, use_assistant_message=request.use_assistant_message, diff --git a/tests/test_sources.py b/tests/test_sources.py index 6a1a4af2..758fc368 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -7,18 +7,37 @@ import pytest from dotenv import load_dotenv from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient +from letta_client import LettaRequest +from letta_client import MessageCreate as ClientMessageCreate from letta_client.types import AgentState from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS from letta.orm.enums import ToolType from letta.schemas.message import MessageCreate from letta.schemas.user import User +from letta.settings import settings from tests.utils import wait_for_server # Constants SERVER_PORT = 8283 +def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str: + """Helper function to get the raw system message from an agent's preview payload.""" + raw_payload = client.agents.messages.preview_raw_payload( + agent_id=agent_id, + request=LettaRequest( + messages=[ + ClientMessageCreate( + role="user", + content="Testing", + ) + ], + ), + ) + return raw_payload["messages"][0]["content"] + + @pytest.fixture(autouse=True) def clear_sources(client: LettaSDKClient): # Clear existing sources @@ -172,6 +191,10 @@ def test_file_upload_creates_source_blocks_correctly( expected_value: str, expected_label_regex: str, ): + # skip pdf tests if mistral api key is missing + if file_path.endswith(".pdf") and not settings.mistral_api_key: + pytest.skip("mistral api key required for pdf processing") + # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 @@ -195,6 +218,15 @@ def test_file_upload_creates_source_blocks_correctly( assert any(b.value.startswith("[Viewing file start") for b in blocks) assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # verify raw system message contains source information + raw_system_message = get_raw_system_message(client, agent_state.id) + assert "test_source" in raw_system_message + assert "" in raw_system_message + # verify file-specific details in raw system message + file_name = files[0].file_name + assert f'name="test_source/{file_name}"' in raw_system_message + assert 'status="open"' in raw_system_message + # Remove file from source client.sources.files.delete(source_id=source.id, file_id=files[0].id) @@ -205,6 +237,14 @@ def test_file_upload_creates_source_blocks_correctly( assert not any(expected_value in b.value for b in blocks) assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # verify raw system message no longer contains source information + raw_system_message_after_removal = get_raw_system_message(client, agent_state.id) + # this should be in, because we didn't delete the source + assert "test_source" in raw_system_message_after_removal + assert "" in raw_system_message_after_removal + # verify file-specific details are also removed + assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal + def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): # Create a new source @@ -224,6 +264,25 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, # Attach after uploading the file client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id) + + # Assert that the expected chunk is in the raw system message + expected_chunk = """ + + + +- read_only=true +- chars_current=46 +- chars_limit=50000 + + +[Viewing file start (out of 1 chunks)] +1: test + + + +""" + assert expected_chunk in raw_system_message # Get the agent state, check blocks exist agent_state = client.agents.retrieve(agent_id=agent_state.id) @@ -241,20 +300,46 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, assert len(blocks) == 0 assert not any("test" in b.value for b in blocks) + # Verify no traces of the prompt exist in the raw system message after detaching + raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) + assert expected_chunk not in raw_system_message_after_detach + assert "test_source" not in raw_system_message_after_detach + assert "" not in raw_system_message_after_detach + def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 - # Attach client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id) + assert "test_source" in raw_system_message + assert "" in raw_system_message # Load files into the source file_path = "tests/data/test.txt" # Upload the files upload_file_and_wait(client, source.id, file_path) + raw_system_message = get_raw_system_message(client, agent_state.id) + # Assert that the expected chunk is in the raw system message + expected_chunk = """ + + + +- read_only=true +- chars_current=46 +- chars_limit=50000 + + +[Viewing file start (out of 1 chunks)] +1: test + + + +""" + assert expected_chunk in raw_system_message # Get the agent state, check blocks exist agent_state = client.agents.retrieve(agent_id=agent_state.id) @@ -264,6 +349,10 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: # Remove file from source client.sources.delete(source_id=source.id) + raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) + assert expected_chunk not in raw_system_message_after_detach + assert "test_source" not in raw_system_message_after_detach + assert "" not in raw_system_message_after_detach # Get the agent state, check blocks do NOT exist agent_state = client.agents.retrieve(agent_id=agent_state.id)