fix: Fix context window compilation issues for files (#3272)
This commit is contained in:
@@ -96,7 +96,7 @@ class BaseAgent(ABC):
|
||||
"""
|
||||
try:
|
||||
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
|
||||
await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
|
||||
agent_state = await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
|
||||
|
||||
tool_constraint_block = None
|
||||
if tool_rules_solver is not None:
|
||||
@@ -104,18 +104,37 @@ class BaseAgent(ABC):
|
||||
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
if curr_memory_str in curr_system_message_text:
|
||||
|
||||
# extract the dynamic section that includes memory blocks, tool rules, and directories
|
||||
# this avoids timestamp comparison issues
|
||||
def extract_dynamic_section(text):
|
||||
start_marker = "</base_instructions>"
|
||||
end_marker = "<memory_metadata>"
|
||||
|
||||
start_idx = text.find(start_marker)
|
||||
end_idx = text.find(end_marker)
|
||||
|
||||
if start_idx != -1 and end_idx != -1:
|
||||
return text[start_idx:end_idx]
|
||||
return text # fallback to full text if markers not found
|
||||
|
||||
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
|
||||
|
||||
# generate just the memory string with current state for comparison
|
||||
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
|
||||
new_dynamic_section = extract_dynamic_section(curr_memory_str)
|
||||
|
||||
# compare just the dynamic sections (memory blocks, tool rules, directories)
|
||||
if curr_dynamic_section == new_dynamic_section:
|
||||
logger.debug(
|
||||
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return in_context_messages
|
||||
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
|
||||
# [DB Call] size of messages and archival memories
|
||||
# todo: blocking for now
|
||||
# size of messages and archival memories
|
||||
if num_messages is None:
|
||||
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if num_archival_memories is None:
|
||||
|
||||
@@ -169,7 +169,9 @@ class LettaAgent(BaseAgent):
|
||||
) -> Union[LettaResponse, dict]:
|
||||
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
agent_id=self.agent_id,
|
||||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||||
actor=self.actor,
|
||||
)
|
||||
result = await self._step(
|
||||
agent_state=agent_state,
|
||||
@@ -203,7 +205,9 @@ class LettaAgent(BaseAgent):
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
agent_id=self.agent_id,
|
||||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||||
actor=self.actor,
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
@@ -549,7 +553,9 @@ class LettaAgent(BaseAgent):
|
||||
4. Processes the response
|
||||
"""
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
agent_id=self.agent_id,
|
||||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||||
actor=self.actor,
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
|
||||
@@ -1208,6 +1208,7 @@ async def preview_raw_payload(
|
||||
),
|
||||
)
|
||||
|
||||
# TODO: Support step_streaming
|
||||
return await agent_loop.step(
|
||||
input_messages=request.messages,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
|
||||
@@ -7,18 +7,37 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest
|
||||
from letta_client import MessageCreate as ClientMessageCreate
|
||||
from letta_client.types import AgentState
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
|
||||
def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str:
|
||||
"""Helper function to get the raw system message from an agent's preview payload."""
|
||||
raw_payload = client.agents.messages.preview_raw_payload(
|
||||
agent_id=agent_id,
|
||||
request=LettaRequest(
|
||||
messages=[
|
||||
ClientMessageCreate(
|
||||
role="user",
|
||||
content="Testing",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
return raw_payload["messages"][0]["content"]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_sources(client: LettaSDKClient):
|
||||
# Clear existing sources
|
||||
@@ -172,6 +191,10 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
expected_value: str,
|
||||
expected_label_regex: str,
|
||||
):
|
||||
# skip pdf tests if mistral api key is missing
|
||||
if file_path.endswith(".pdf") and not settings.mistral_api_key:
|
||||
pytest.skip("mistral api key required for pdf processing")
|
||||
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
@@ -195,6 +218,15 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
assert any(b.value.startswith("[Viewing file start") for b in blocks)
|
||||
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# verify raw system message contains source information
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
assert "test_source" in raw_system_message
|
||||
assert "<directories>" in raw_system_message
|
||||
# verify file-specific details in raw system message
|
||||
file_name = files[0].file_name
|
||||
assert f'name="test_source/{file_name}"' in raw_system_message
|
||||
assert 'status="open"' in raw_system_message
|
||||
|
||||
# Remove file from source
|
||||
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
|
||||
|
||||
@@ -205,6 +237,14 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
assert not any(expected_value in b.value for b in blocks)
|
||||
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# verify raw system message no longer contains source information
|
||||
raw_system_message_after_removal = get_raw_system_message(client, agent_state.id)
|
||||
# this should be in, because we didn't delete the source
|
||||
assert "test_source" in raw_system_message_after_removal
|
||||
assert "<directories>" in raw_system_message_after_removal
|
||||
# verify file-specific details are also removed
|
||||
assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
@@ -224,6 +264,25 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
|
||||
|
||||
# Attach after uploading the file
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
|
||||
# Assert that the expected chunk is in the raw system message
|
||||
expected_chunk = """<directories>
|
||||
<directory name="test_source">
|
||||
<file status="open" name="test_source/test.txt">
|
||||
<metadata>
|
||||
- read_only=true
|
||||
- chars_current=46
|
||||
- chars_limit=50000
|
||||
</metadata>
|
||||
<value>
|
||||
[Viewing file start (out of 1 chunks)]
|
||||
1: test
|
||||
</value>
|
||||
</file>
|
||||
</directory>
|
||||
</directories>"""
|
||||
assert expected_chunk in raw_system_message
|
||||
|
||||
# Get the agent state, check blocks exist
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
@@ -241,20 +300,46 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
|
||||
assert len(blocks) == 0
|
||||
assert not any("test" in b.value for b in blocks)
|
||||
|
||||
# Verify no traces of the prompt exist in the raw system message after detaching
|
||||
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id)
|
||||
assert expected_chunk not in raw_system_message_after_detach
|
||||
assert "test_source" not in raw_system_message_after_detach
|
||||
assert "<directories>" not in raw_system_message_after_detach
|
||||
|
||||
|
||||
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
|
||||
# Attach
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
assert "test_source" in raw_system_message
|
||||
assert "<directories>" in raw_system_message
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/test.txt"
|
||||
|
||||
# Upload the files
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
# Assert that the expected chunk is in the raw system message
|
||||
expected_chunk = """<directories>
|
||||
<directory name="test_source">
|
||||
<file status="open" name="test_source/test.txt">
|
||||
<metadata>
|
||||
- read_only=true
|
||||
- chars_current=46
|
||||
- chars_limit=50000
|
||||
</metadata>
|
||||
<value>
|
||||
[Viewing file start (out of 1 chunks)]
|
||||
1: test
|
||||
</value>
|
||||
</file>
|
||||
</directory>
|
||||
</directories>"""
|
||||
assert expected_chunk in raw_system_message
|
||||
|
||||
# Get the agent state, check blocks exist
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
@@ -264,6 +349,10 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client:
|
||||
|
||||
# Remove file from source
|
||||
client.sources.delete(source_id=source.id)
|
||||
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id)
|
||||
assert expected_chunk not in raw_system_message_after_detach
|
||||
assert "test_source" not in raw_system_message_after_detach
|
||||
assert "<directories>" not in raw_system_message_after_detach
|
||||
|
||||
# Get the agent state, check blocks do NOT exist
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
|
||||
Reference in New Issue
Block a user