fix: Fix context window compilation issues for files (#3272)

This commit is contained in:
Matthew Zhou
2025-07-10 11:29:36 -07:00
committed by GitHub
parent 49641ca7bb
commit 4e7750d17f
4 changed files with 125 additions and 10 deletions

View File

@@ -96,7 +96,7 @@ class BaseAgent(ABC):
"""
try:
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
agent_state = await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
tool_constraint_block = None
if tool_rules_solver is not None:
@@ -104,18 +104,37 @@ class BaseAgent(ABC):
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# extract the dynamic section that includes memory blocks, tool rules, and directories
# this avoids timestamp comparison issues
def extract_dynamic_section(text: str) -> str:
    """Return the part of ``text`` that lies between the two prompt markers.

    The returned slice starts at the first ``</base_instructions>`` (the
    marker itself is included) and stops just before the first
    ``<memory_metadata>`` that follows it (that marker is excluded).

    Args:
        text: The string to slice, e.g. a compiled system prompt.

    Returns:
        The marker-delimited slice, or ``text`` unchanged when either marker
        is missing or no end marker occurs after the start marker.
    """
    start_marker = "</base_instructions>"
    end_marker = "<memory_metadata>"

    start_idx = text.find(start_marker)
    if start_idx == -1:
        return text  # fallback to full text if markers not found

    # Search for the end marker only from the start marker onward. Locating
    # both markers independently would yield an empty slice whenever an end
    # marker happens to appear earlier in the text, and two different inputs
    # that both collapse to "" would then wrongly compare as equal.
    end_idx = text.find(end_marker, start_idx)
    if end_idx == -1:
        return text  # fallback to full text if markers not found

    return text[start_idx:end_idx]
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
# generate just the memory string with current state for comparison
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
new_dynamic_section = extract_dynamic_section(curr_memory_str)
# compare just the dynamic sections (memory blocks, tool rules, directories)
if curr_dynamic_section == new_dynamic_section:
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
# [DB Call] size of messages and archival memories
# todo: blocking for now
# size of messages and archival memories
if num_messages is None:
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
if num_archival_memories is None:

View File

@@ -169,7 +169,9 @@ class LettaAgent(BaseAgent):
) -> Union[LettaResponse, dict]:
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
agent_id=self.agent_id,
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
actor=self.actor,
)
result = await self._step(
agent_state=agent_state,
@@ -203,7 +205,9 @@ class LettaAgent(BaseAgent):
include_return_message_types: list[MessageType] | None = None,
):
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
agent_id=self.agent_id,
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
actor=self.actor,
)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
input_messages, agent_state, self.message_manager, self.actor
@@ -549,7 +553,9 @@ class LettaAgent(BaseAgent):
4. Processes the response
"""
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
agent_id=self.agent_id,
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
actor=self.actor,
)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
input_messages, agent_state, self.message_manager, self.actor

View File

@@ -1208,6 +1208,7 @@ async def preview_raw_payload(
),
)
# TODO: Support step_streaming
return await agent_loop.step(
input_messages=request.messages,
use_assistant_message=request.use_assistant_message,

View File

@@ -7,18 +7,37 @@ import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client import LettaRequest
from letta_client import MessageCreate as ClientMessageCreate
from letta_client.types import AgentState
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
from letta.orm.enums import ToolType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
from letta.settings import settings
from tests.utils import wait_for_server
# Constants
SERVER_PORT = 8283
def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str:
    """Fetch the content of the first message in an agent's raw preview payload.

    Sends a single throwaway user message ("Testing") through the client's
    ``preview_raw_payload`` endpoint and returns the ``content`` field of the
    first entry under ``messages``, which callers treat as the raw system
    message.

    Args:
        client: SDK client used to request the preview payload.
        agent_id: Identifier of the agent whose payload is previewed.

    Returns:
        The ``content`` of the first message in the returned payload.
    """
    # A minimal user turn; the request needs some input message to preview.
    probe_message = ClientMessageCreate(role="user", content="Testing")
    preview_request = LettaRequest(messages=[probe_message])

    payload = client.agents.messages.preview_raw_payload(
        agent_id=agent_id,
        request=preview_request,
    )

    first_message = payload["messages"][0]
    return first_message["content"]
@pytest.fixture(autouse=True)
def clear_sources(client: LettaSDKClient):
# Clear existing sources
@@ -172,6 +191,10 @@ def test_file_upload_creates_source_blocks_correctly(
expected_value: str,
expected_label_regex: str,
):
# skip pdf tests if mistral api key is missing
if file_path.endswith(".pdf") and not settings.mistral_api_key:
pytest.skip("mistral api key required for pdf processing")
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
@@ -195,6 +218,15 @@ def test_file_upload_creates_source_blocks_correctly(
assert any(b.value.startswith("[Viewing file start") for b in blocks)
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# verify raw system message contains source information
raw_system_message = get_raw_system_message(client, agent_state.id)
assert "test_source" in raw_system_message
assert "<directories>" in raw_system_message
# verify file-specific details in raw system message
file_name = files[0].file_name
assert f'name="test_source/{file_name}"' in raw_system_message
assert 'status="open"' in raw_system_message
# Remove file from source
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
@@ -205,6 +237,14 @@ def test_file_upload_creates_source_blocks_correctly(
assert not any(expected_value in b.value for b in blocks)
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# verify raw system message no longer contains the removed file's details
raw_system_message_after_removal = get_raw_system_message(client, agent_state.id)
# the source name should still be present, because we deleted only the file, not the source
assert "test_source" in raw_system_message_after_removal
assert "<directories>" in raw_system_message_after_removal
# verify file-specific details are also removed
assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
@@ -224,6 +264,25 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
# Attach after uploading the file
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
raw_system_message = get_raw_system_message(client, agent_state.id)
# Assert that the expected chunk is in the raw system message
expected_chunk = """<directories>
<directory name="test_source">
<file status="open" name="test_source/test.txt">
<metadata>
- read_only=true
- chars_current=46
- chars_limit=50000
</metadata>
<value>
[Viewing file start (out of 1 chunks)]
1: test
</value>
</file>
</directory>
</directories>"""
assert expected_chunk in raw_system_message
# Get the agent state, check blocks exist
agent_state = client.agents.retrieve(agent_id=agent_state.id)
@@ -241,20 +300,46 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
assert len(blocks) == 0
assert not any("test" in b.value for b in blocks)
# Verify no traces of the prompt exist in the raw system message after detaching
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id)
assert expected_chunk not in raw_system_message_after_detach
assert "test_source" not in raw_system_message_after_detach
assert "<directories>" not in raw_system_message_after_detach
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
assert len(client.sources.list()) == 1
# Attach
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
raw_system_message = get_raw_system_message(client, agent_state.id)
assert "test_source" in raw_system_message
assert "<directories>" in raw_system_message
# Load files into the source
file_path = "tests/data/test.txt"
# Upload the files
upload_file_and_wait(client, source.id, file_path)
raw_system_message = get_raw_system_message(client, agent_state.id)
# Assert that the expected chunk is in the raw system message
expected_chunk = """<directories>
<directory name="test_source">
<file status="open" name="test_source/test.txt">
<metadata>
- read_only=true
- chars_current=46
- chars_limit=50000
</metadata>
<value>
[Viewing file start (out of 1 chunks)]
1: test
</value>
</file>
</directory>
</directories>"""
assert expected_chunk in raw_system_message
# Get the agent state, check blocks exist
agent_state = client.agents.retrieve(agent_id=agent_state.id)
@@ -264,6 +349,10 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client:
# Delete the source itself (not just a file within it)
client.sources.delete(source_id=source.id)
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id)
assert expected_chunk not in raw_system_message_after_detach
assert "test_source" not in raw_system_message_after_detach
assert "<directories>" not in raw_system_message_after_detach
# Get the agent state, check blocks do NOT exist
agent_state = client.agents.retrieve(agent_id=agent_state.id)