feat: Various file fixes and improvements (#3125)

This commit is contained in:
Matthew Zhou
2025-07-01 15:21:52 -07:00
committed by GitHub
parent 3654fa8c26
commit 2263ffd07c
11 changed files with 99 additions and 60 deletions

View File

@@ -98,9 +98,13 @@ class BaseAgent(ABC):
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
tool_constraint_block = None
if tool_rules_solver is not None:
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
logger.debug(

View File

@@ -10,15 +10,20 @@ if TYPE_CHECKING:
async def open_files(agent_state: "AgentState", file_requests: List[FileOpenRequest], close_all_others: bool = False) -> str:
"""Open one or more files and load their contents into files section in core memory. Maximum of 5 files can be opened simultaneously.
Use this when you want to:
- Inspect or reference file contents during reasoning
- View specific portions of large files (e.g. functions or definitions)
- Replace currently open files with a new set for focused context (via `close_all_others=True`)
Examples:
Open single file (entire content):
file_requests = [FileOpenRequest(file_name="config.py")]
Open single file belonging to a directory named `project_utils` (entire content):
file_requests = [FileOpenRequest(file_name="project_utils/config.py")]
Open multiple files with different view ranges:
file_requests = [
FileOpenRequest(file_name="config.py", offset=1, length=50), # Lines 1-50
FileOpenRequest(file_name="main.py", offset=100, length=100), # Lines 100-199
FileOpenRequest(file_name="utils.py") # Entire file
FileOpenRequest(file_name="project_utils/config.py", offset=1, length=50), # Lines 1-50
FileOpenRequest(file_name="project_utils/main.py", offset=100, length=100), # Lines 100-199
FileOpenRequest(file_name="project_utils/utils.py") # Entire file
]
Close all other files and open new ones:
@@ -43,6 +48,11 @@ async def grep_files(
"""
Grep tool to search files across data sources using a keyword or regex pattern.
Use this when you want to:
- Quickly find occurrences of a variable, function, or keyword
- Locate log messages, error codes, or TODOs across files
- Understand surrounding code by including `context_lines`
Args:
pattern (str): Keyword or regex pattern to search within file contents.
include (Optional[str]): Optional keyword or regex pattern to filter filenames to include in the search.
@@ -57,7 +67,12 @@ async def grep_files(
async def search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
"""
Get list of most relevant files across all data sources using embedding search.
Get list of most relevant chunks from any file using embedding search.
Use this when you want to:
- Find related content that may not match exact keywords (e.g., conceptually similar sections)
- Look up high-level descriptions, documentation, or config patterns
- Perform fuzzy search when grep isn't sufficient
Args:
query (str): The search query.

View File

@@ -14,5 +14,5 @@ class FileOpenRequest(BaseModel):
default=None, description="Optional starting line number (1-indexed). If not specified, starts from beginning of file."
)
length: Optional[int] = Field(
default=None, description="Optional number of lines to view from offset. If not specified, views to end of file."
default=None, description="Optional number of lines to view from offset (inclusive). If not specified, views to end of file."
)

View File

@@ -82,7 +82,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
cascade="all, delete-orphan",
)
async def to_pydantic_async(self, include_content: bool = False) -> PydanticFileMetadata:
async def to_pydantic_async(self, include_content: bool = False, strip_directory_prefix: bool = False) -> PydanticFileMetadata:
"""
Async version of `to_pydantic` that supports optional relationship loading
without requiring `expire_on_commit=False`.
@@ -95,11 +95,15 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
else:
content_text = None
file_name = self.file_name
if strip_directory_prefix:
file_name = "/".join(file_name.split("/")[1:])
return PydanticFileMetadata(
id=self.id,
organization_id=self.organization_id,
source_id=self.source_id,
file_name=self.file_name,
file_name=file_name,
original_file_name=self.original_file_name,
file_path=self.file_path,
file_type=self.file_type,

View File

@@ -326,7 +326,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% if file_blocks %}"
"{% for block in file_blocks %}"
"{% if block.metadata['source_id'] == source.id %}"
"{% if block.metadata and block.metadata.get('source_id') == source.id %}"
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\">\n"
"<{{ block.label }}>\n"
"<description>\n"
@@ -393,7 +393,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% if file_blocks %}"
"{% for block in file_blocks %}"
"{% if block.metadata['source_id'] == source.id %}"
"{% if block.metadata and block.metadata.get('source_id') == source.id %}"
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\" name=\"{{{{ block.label }}}}\">\n"
"{% if block.description %}"
"<description>\n"
@@ -459,7 +459,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% if file_blocks %}"
"{% for block in file_blocks %}"
"{% if block.metadata['source_id'] == source.id %}"
"{% if block.metadata and block.metadata.get('source_id') == source.id %}"
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\" name=\"{{{{ block.label }}}}\">\n"
"{% if block.description %}"
"<description>\n"

View File

@@ -135,8 +135,13 @@ class Memory(BaseModel, validate_assignment=True):
def compile(self, tool_usage_rules=None, sources=None) -> str:
"""Generate a string representation of the memory in-context using the Jinja2 template"""
template = Template(self.prompt_template)
return template.render(blocks=self.blocks, file_blocks=self.file_blocks, tool_usage_rules=tool_usage_rules, sources=sources)
try:
template = Template(self.prompt_template)
return template.render(blocks=self.blocks, file_blocks=self.file_blocks, tool_usage_rules=tool_usage_rules, sources=sources)
except TemplateSyntaxError as e:
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
def list_block_labels(self) -> List[str]:
"""Return a list of the block names held inside the memory object"""

View File

@@ -237,7 +237,7 @@ async def upload_file_to_source(
# Store original filename and generate unique filename
original_filename = sanitize_filename(file.filename) # Basic sanitization only
unique_filename = await server.file_manager.generate_unique_filename(
original_filename=original_filename, source_id=source_id, organization_id=actor.organization_id
original_filename=original_filename, source=source, organization_id=actor.organization_id
)
# create file metadata
@@ -308,6 +308,7 @@ async def list_source_files(
after=after,
actor=actor,
include_content=include_content,
strip_directory_prefix=True, # TODO: Reconsider this. This is purely for aesthetics.
)
@@ -330,7 +331,9 @@ async def get_file_metadata(
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
# Get file metadata using the file manager
file_metadata = await server.file_manager.get_file_by_id(file_id=file_id, actor=actor, include_content=include_content)
file_metadata = await server.file_manager.get_file_by_id(
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
)
if not file_metadata:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")

View File

@@ -1435,7 +1435,7 @@ class AgentManager:
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be somewhat out of date
curr_memory_str = agent_state.memory.compile()
curr_memory_str = agent_state.memory.compile(sources=agent_state.sources)
if curr_memory_str in curr_system_message_openai["content"] and not force:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
@@ -1512,7 +1512,9 @@ class AgentManager:
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be somewhat out of date
curr_memory_str = agent_state.memory.compile()
curr_memory_str = agent_state.memory.compile(
sources=agent_state.sources, tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts()
)
if curr_memory_str in curr_system_message_openai["content"] and not force:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
@@ -1693,9 +1695,13 @@ class AgentManager:
Returns:
modified (bool): whether the memory was updated
"""
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
if new_memory.compile() not in system_message.content[0].text:
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
if (
new_memory.compile(sources=agent_state.sources, tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts())
not in system_message.content[0].text
):
# update the blocks (LRW) in the DB
for label in agent_state.memory.list_block_labels():
updated_value = new_memory.get_block(label).value

View File

@@ -15,6 +15,7 @@ from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.source import Source as PydanticSource
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
@@ -60,11 +61,7 @@ class FileManager:
@enforce_types
@trace_method
async def get_file_by_id(
self,
file_id: str,
actor: Optional[PydanticUser] = None,
*,
include_content: bool = False,
self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False
) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID.
@@ -98,7 +95,7 @@ class FileManager:
actor=actor,
)
return await file_orm.to_pydantic_async(include_content=include_content)
return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
except NoResultFound:
return None
@@ -195,7 +192,13 @@ class FileManager:
@enforce_types
@trace_method
async def list_files(
self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, include_content: bool = False
self,
source_id: str,
actor: PydanticUser,
after: Optional[str] = None,
limit: Optional[int] = 50,
include_content: bool = False,
strip_directory_prefix: bool = False,
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
async with db_registry.async_session() as session:
@@ -209,7 +212,10 @@ class FileManager:
source_id=source_id,
query_options=options,
)
return [await file.to_pydantic_async(include_content=include_content) for file in files]
return [
await file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
for file in files
]
@enforce_types
@trace_method
@@ -222,7 +228,7 @@ class FileManager:
@enforce_types
@trace_method
async def generate_unique_filename(self, original_filename: str, source_id: str, organization_id: str) -> str:
async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str:
"""
Generate a unique filename by checking for duplicates and adding a numeric suffix if needed.
Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt).
@@ -247,7 +253,7 @@ class FileManager:
# Count existing files with the same original_file_name in this source
query = select(func.count(FileMetadataModel.id)).where(
FileMetadataModel.original_file_name == original_filename,
FileMetadataModel.source_id == source_id,
FileMetadataModel.source_id == source.id,
FileMetadataModel.organization_id == organization_id,
FileMetadataModel.is_deleted == False,
)
@@ -255,8 +261,8 @@ class FileManager:
count = result.scalar() or 0
if count == 0:
# No duplicates, return original filename
return original_filename
# No duplicates, return original filename with source.name
return f"{source.name}/{original_filename}"
else:
# Add numeric suffix
return f"{base} ({count}){ext}"
return f"{source.name}/{base}_({count}){ext}"

View File

@@ -4835,8 +4835,8 @@ async def test_create_source(server: SyncServer, default_user, event_loop):
@pytest.mark.asyncio
async def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user):
"""Test creating a new source."""
async def test_create_sources_with_same_name_raises_error(server: SyncServer, default_user):
"""Test that creating sources with the same name raises an IntegrityError due to unique constraint."""
name = "Test Source"
source_pydantic = PydanticSource(
name=name,
@@ -4845,16 +4845,16 @@ async def test_create_sources_with_same_name_does_not_error(server: SyncServer,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Attempting to create another source with the same name should raise an IntegrityError
source_pydantic = PydanticSource(
name=name,
description="This is a different test source.",
metadata={"type": "legal"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
same_source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
assert source.name == same_source.name
assert source.id != same_source.id
with pytest.raises(UniqueConstraintViolationError):
await server.source_manager.create_source(source=source_pydantic, actor=default_user)
@pytest.mark.asyncio

View File

@@ -150,17 +150,17 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
@pytest.mark.parametrize(
"file_path, expected_value, expected_label_regex",
[
("tests/data/test.txt", "test", r"test\.txt"),
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper\.pdf"),
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning\.jsonl"),
("tests/data/test.md", "h2 Heading", r"test\.md"),
("tests/data/test.json", "glossary", r"test\.json"),
("tests/data/react_component.jsx", "UserProfile", r"react_component\.jsx"),
("tests/data/task_manager.java", "TaskManager", r"task_manager\.java"),
("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures\.cpp"),
("tests/data/api_server.go", "UserService", r"api_server\.go"),
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis\.py"),
("tests/data/test.csv", "Smart Fridge Plus", r"test\.csv"),
("tests/data/test.txt", "test", r"test_source/test\.txt"),
("tests/data/memgpt_paper.pdf", "MemGPT", r"test_source/memgpt_paper\.pdf"),
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"test_source/toy_chat_fine_tuning\.jsonl"),
("tests/data/test.md", "h2 Heading", r"test_source/test\.md"),
("tests/data/test.json", "glossary", r"test_source/test\.json"),
("tests/data/react_component.jsx", "UserProfile", r"test_source/react_component\.jsx"),
("tests/data/task_manager.java", "TaskManager", r"test_source/task_manager\.java"),
("tests/data/data_structures.cpp", "BinarySearchTree", r"test_source/data_structures\.cpp"),
("tests/data/api_server.go", "UserService", r"test_source/api_server\.go"),
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"test_source/data_analysis\.py"),
("tests/data/test.csv", "Smart Fridge Plus", r"test_source/test\.csv"),
],
)
def test_file_upload_creates_source_blocks_correctly(
@@ -229,7 +229,6 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
assert len(blocks) == 1
assert any("test" in b.value for b in blocks)
assert any(b.value.startswith("[Viewing file start") for b in blocks)
assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks)
# Detach the source
client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id)
@@ -239,7 +238,6 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
blocks = agent_state.memory.file_blocks
assert len(blocks) == 0
assert not any("test" in b.value for b in blocks)
assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
@@ -261,7 +259,6 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
blocks = agent_state.memory.file_blocks
assert len(blocks) == 1
assert any("test" in b.value for b in blocks)
assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks)
# Remove file from source
client.sources.delete(source_id=source.id)
@@ -271,7 +268,6 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
blocks = agent_state.memory.file_blocks
assert len(blocks) == 0
assert not any("test" in b.value for b in blocks)
assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState):
@@ -314,7 +310,7 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat
messages=[
MessageCreate(
role="user",
content=f"Use ONLY the open_files tool to open the file named {file.file_name} with offset {offset} and length {length}",
content=f"Use ONLY the open_files tool to open the file named test_source/{file.file_name} with offset {offset} and length {length}",
)
],
)
@@ -556,7 +552,6 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le
blocks = temp_agent_state.memory.file_blocks
assert len(blocks) == 1
assert any(b.value.startswith("[Viewing file start (out of 554 chunks)]") for b in blocks)
assert any(re.fullmatch(r"long_test\.txt", b.label) for b in blocks)
# Verify file tools were automatically attached
file_tools = {tool.name for tool in temp_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
@@ -600,7 +595,7 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
messages=[
MessageCreate(
role="user",
content=f"Use ONLY the open_files tool to open the file named {file.file_name} with offset {offset} and length {length}",
content=f"Use ONLY the open_files tool to open the file named test_source/{file.file_name} with offset {offset} and length {length}",
)
],
)
@@ -651,7 +646,7 @@ def test_duplicate_file_renaming(client: LettaSDKClient):
files.sort(key=lambda f: f.created_at)
# Verify filenames follow the count-based pattern
expected_filenames = ["test.txt", "test (1).txt", "test (2).txt"]
expected_filenames = ["test.txt", "test_(1).txt", "test_(2).txt"]
actual_filenames = [f.file_name for f in files]
assert actual_filenames == expected_filenames, f"Expected {expected_filenames}, got {actual_filenames}"
@@ -692,6 +687,7 @@ def test_open_files_schema_descriptions(client: LettaSDKClient):
assert "# Lines 100-199" in description
assert "# Entire file" in description
assert "close_all_others=True" in description
assert "View specific portions of large files (e.g. functions or definitions)" in description
# Check parameters structure
assert "parameters" in schema
@@ -742,6 +738,6 @@ def test_open_files_schema_descriptions(client: LettaSDKClient):
# Check length field
assert "length" in file_request_properties
length_prop = file_request_properties["length"]
expected_length_desc = "Optional number of lines to view from offset. If not specified, views to end of file."
expected_length_desc = "Optional number of lines to view from offset (inclusive). If not specified, views to end of file."
assert length_prop["description"] == expected_length_desc
assert length_prop["type"] == "integer"