From e495fb2ef2dcf1562aaacde5ae8c423204999f21 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 1 Jul 2025 15:21:52 -0700 Subject: [PATCH] feat: Various file fixes and improvements (#3125) --- letta/agents/base_agent.py | 6 +++- letta/functions/function_sets/files.py | 27 ++++++++++++---- letta/functions/types.py | 2 +- letta/orm/file.py | 8 +++-- letta/schemas/agent.py | 6 ++-- letta/schemas/memory.py | 9 ++++-- letta/server/rest_api/routers/v1/sources.py | 7 ++-- letta/services/agent_manager.py | 14 +++++--- letta/services/file_manager.py | 32 ++++++++++-------- tests/test_managers.py | 12 +++---- tests/test_sources.py | 36 +++++++++------------ 11 files changed, 99 insertions(+), 60 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index a752474a..b8613201 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -98,9 +98,13 @@ class BaseAgent(ABC): # [DB Call] loading blocks (modifies: agent_state.memory.blocks) await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor) + tool_constraint_block = None + if tool_rules_solver is not None: + tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts() + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() + curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources) curr_system_message_text = curr_system_message.content[0].text if curr_memory_str in curr_system_message_text: logger.debug( diff --git a/letta/functions/function_sets/files.py b/letta/functions/function_sets/files.py index 50b88eed..9e62807e 100644 --- a/letta/functions/function_sets/files.py +++ b/letta/functions/function_sets/files.py @@ -10,15 +10,20 @@ if TYPE_CHECKING: async def open_files(agent_state: "AgentState", file_requests: List[FileOpenRequest], 
close_all_others: bool = False) -> str: """Open one or more files and load their contents into files section in core memory. Maximum of 5 files can be opened simultaneously. + Use this when you want to: + - Inspect or reference file contents during reasoning + - View specific portions of large files (e.g. functions or definitions) + - Replace currently open files with a new set for focused context (via `close_all_others=True`) + Examples: - Open single file (entire content): - file_requests = [FileOpenRequest(file_name="config.py")] + Open single file belonging to a directory named `project_utils` (entire content): + file_requests = [FileOpenRequest(file_name="project_utils/config.py")] Open multiple files with different view ranges: file_requests = [ - FileOpenRequest(file_name="config.py", offset=1, length=50), # Lines 1-50 - FileOpenRequest(file_name="main.py", offset=100, length=100), # Lines 100-199 - FileOpenRequest(file_name="utils.py") # Entire file + FileOpenRequest(file_name="project_utils/config.py", offset=1, length=50), # Lines 1-50 + FileOpenRequest(file_name="project_utils/main.py", offset=100, length=100), # Lines 100-199 + FileOpenRequest(file_name="project_utils/utils.py") # Entire file ] Close all other files and open new ones: @@ -43,6 +48,11 @@ async def grep_files( """ Grep tool to search files across data sources using a keyword or regex pattern. + Use this when you want to: + - Quickly find occurrences of a variable, function, or keyword + - Locate log messages, error codes, or TODOs across files + - Understand surrounding code by including `context_lines` + Args: pattern (str): Keyword or regex pattern to search within file contents. include (Optional[str]): Optional keyword or regex pattern to filter filenames to include in the search. 
@@ -57,7 +67,12 @@ async def grep_files( async def search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]: """ - Get list of most relevant files across all data sources using embedding search. + Get list of most relevant chunks from any file using embedding search. + + Use this when you want to: + - Find related content that may not match exact keywords (e.g., conceptually similar sections) + - Look up high-level descriptions, documentation, or config patterns + - Perform fuzzy search when grep isn't sufficient Args: query (str): The search query. diff --git a/letta/functions/types.py b/letta/functions/types.py index 85978f71..af464657 100644 --- a/letta/functions/types.py +++ b/letta/functions/types.py @@ -14,5 +14,5 @@ class FileOpenRequest(BaseModel): default=None, description="Optional starting line number (1-indexed). If not specified, starts from beginning of file." ) length: Optional[int] = Field( - default=None, description="Optional number of lines to view from offset. If not specified, views to end of file." + default=None, description="Optional number of lines to view from offset (inclusive). If not specified, views to end of file." ) diff --git a/letta/orm/file.py b/letta/orm/file.py index f2c83a6f..0552763d 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -82,7 +82,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): cascade="all, delete-orphan", ) - async def to_pydantic_async(self, include_content: bool = False) -> PydanticFileMetadata: + async def to_pydantic_async(self, include_content: bool = False, strip_directory_prefix: bool = False) -> PydanticFileMetadata: """ Async version of `to_pydantic` that supports optional relationship loading without requiring `expire_on_commit=False`. 
@@ -95,11 +95,15 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): else: content_text = None + file_name = self.file_name + if strip_directory_prefix: + file_name = "/".join(file_name.split("/")[1:]) + return PydanticFileMetadata( id=self.id, organization_id=self.organization_id, source_id=self.source_id, - file_name=self.file_name, + file_name=file_name, original_file_name=self.original_file_name, file_path=self.file_path, file_type=self.file_type, diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 83885249..8bceac98 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -326,7 +326,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% if file_blocks %}" "{% for block in file_blocks %}" - "{% if block.metadata['source_id'] == source.id %}" + "{% if block.metadata and block.metadata.get('source_id') == source.id %}" f"\n" "<{{ block.label }}>\n" "\n" @@ -393,7 +393,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% if file_blocks %}" "{% for block in file_blocks %}" - "{% if block.metadata['source_id'] == source.id %}" + "{% if block.metadata and block.metadata.get('source_id') == source.id %}" f"\n" "{% if block.description %}" "\n" @@ -459,7 +459,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% if file_blocks %}" "{% for block in file_blocks %}" - "{% if block.metadata['source_id'] == source.id %}" + "{% if block.metadata and block.metadata.get('source_id') == source.id %}" f"\n" "{% if block.description %}" "\n" diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 18105990..97658393 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -135,8 +135,13 @@ class Memory(BaseModel, validate_assignment=True): def compile(self, tool_usage_rules=None, sources=None) -> str: """Generate a string representation of the memory 
in-context using the Jinja2 template""" - template = Template(self.prompt_template) - return template.render(blocks=self.blocks, file_blocks=self.file_blocks, tool_usage_rules=tool_usage_rules, sources=sources) + try: + template = Template(self.prompt_template) + return template.render(blocks=self.blocks, file_blocks=self.file_blocks, tool_usage_rules=tool_usage_rules, sources=sources) + except TemplateSyntaxError as e: + raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}") + except Exception as e: + raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") def list_block_labels(self) -> List[str]: """Return a list of the block names held inside the memory object""" diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 68270735..b6e54b6b 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -237,7 +237,7 @@ async def upload_file_to_source( # Store original filename and generate unique filename original_filename = sanitize_filename(file.filename) # Basic sanitization only unique_filename = await server.file_manager.generate_unique_filename( - original_filename=original_filename, source_id=source_id, organization_id=actor.organization_id + original_filename=original_filename, source=source, organization_id=actor.organization_id ) # create file metadata @@ -308,6 +308,7 @@ async def list_source_files( after=after, actor=actor, include_content=include_content, + strip_directory_prefix=True, # TODO: Reconsider this. This is purely for aesthetics. 
) @@ -330,7 +331,9 @@ async def get_file_metadata( raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.") # Get file metadata using the file manager - file_metadata = await server.file_manager.get_file_by_id(file_id=file_id, actor=actor, include_content=include_content) + file_metadata = await server.file_manager.get_file_by_id( + file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True + ) if not file_metadata: raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3102428e..ac41f37d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1435,7 +1435,7 @@ class AgentManager: # note: we only update the system prompt if the core memory is changed # this means that the archival/recall memory statistics may be someout out of date - curr_memory_str = agent_state.memory.compile() + curr_memory_str = agent_state.memory.compile(sources=agent_state.sources) if curr_memory_str in curr_system_message_openai["content"] and not force: # NOTE: could this cause issues if a block is removed? (substring match would still work) logger.debug( @@ -1512,7 +1512,9 @@ class AgentManager: # note: we only update the system prompt if the core memory is changed # this means that the archival/recall memory statistics may be someout out of date - curr_memory_str = agent_state.memory.compile() + curr_memory_str = agent_state.memory.compile( + sources=agent_state.sources, tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts() + ) if curr_memory_str in curr_system_message_openai["content"] and not force: # NOTE: could this cause issues if a block is removed? 
(substring match would still work) logger.debug( @@ -1693,9 +1695,13 @@ class AgentManager: Returns: modified (bool): whether the memory was updated """ - agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"]) system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor) - if new_memory.compile() not in system_message.content[0].text: + temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) + if ( + new_memory.compile(sources=agent_state.sources, tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts()) + not in system_message.content[0].text + ): # update the blocks (LRW) in the DB for label in agent_state.memory.list_block_labels(): updated_value = new_memory.get_block(label).value diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index 17c6d3bd..fbfa44e8 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -15,6 +15,7 @@ from letta.orm.sqlalchemy_base import AccessType from letta.otel.tracing import trace_method from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata as PydanticFileMetadata +from letta.schemas.source import Source as PydanticSource from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types @@ -60,11 +61,7 @@ class FileManager: @enforce_types @trace_method async def get_file_by_id( - self, - file_id: str, - actor: Optional[PydanticUser] = None, - *, - include_content: bool = False, + self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False ) -> Optional[PydanticFileMetadata]: """Retrieve a file by its ID. 
@@ -98,7 +95,7 @@ class FileManager: actor=actor, ) - return await file_orm.to_pydantic_async(include_content=include_content) + return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) except NoResultFound: return None @@ -195,7 +192,13 @@ class FileManager: @enforce_types @trace_method async def list_files( - self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, include_content: bool = False + self, + source_id: str, + actor: PydanticUser, + after: Optional[str] = None, + limit: Optional[int] = 50, + include_content: bool = False, + strip_directory_prefix: bool = False, ) -> List[PydanticFileMetadata]: """List all files with optional pagination.""" async with db_registry.async_session() as session: @@ -209,7 +212,10 @@ class FileManager: source_id=source_id, query_options=options, ) - return [await file.to_pydantic_async(include_content=include_content) for file in files] + return [ + await file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) + for file in files + ] @enforce_types @trace_method @@ -222,7 +228,7 @@ class FileManager: @enforce_types @trace_method - async def generate_unique_filename(self, original_filename: str, source_id: str, organization_id: str) -> str: + async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str: """ Generate a unique filename by checking for duplicates and adding a numeric suffix if needed. Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt). 
@@ -247,7 +253,7 @@ class FileManager: # Count existing files with the same original_file_name in this source query = select(func.count(FileMetadataModel.id)).where( FileMetadataModel.original_file_name == original_filename, - FileMetadataModel.source_id == source_id, + FileMetadataModel.source_id == source.id, FileMetadataModel.organization_id == organization_id, FileMetadataModel.is_deleted == False, ) @@ -255,8 +261,8 @@ class FileManager: count = result.scalar() or 0 if count == 0: - # No duplicates, return original filename - return original_filename + # No duplicates, return original filename with source.name + return f"{source.name}/{original_filename}" else: # Add numeric suffix - return f"{base} ({count}){ext}" + return f"{source.name}/{base}_({count}){ext}" diff --git a/tests/test_managers.py b/tests/test_managers.py index d574bf1b..cdd9c0a4 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4835,8 +4835,8 @@ async def test_create_source(server: SyncServer, default_user, event_loop): @pytest.mark.asyncio -async def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user): - """Test creating a new source.""" +async def test_create_sources_with_same_name_raises_error(server: SyncServer, default_user): + """Test that creating sources with the same name raises a UniqueConstraintViolationError due to the unique constraint.""" name = "Test Source" source_pydantic = PydanticSource( name=name, @@ -4845,16 +4845,16 @@ async def test_create_sources_with_same_name_does_not_error(server: SyncServer, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) + + # Attempting to create another source with the same name should raise a UniqueConstraintViolationError source_pydantic = PydanticSource( name=name, description="This is a different test source.", metadata={"type": "legal"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - same_source = await 
server.source_manager.create_source(source=source_pydantic, actor=default_user) - - assert source.name == same_source.name - assert source.id != same_source.id + with pytest.raises(UniqueConstraintViolationError): + await server.source_manager.create_source(source=source_pydantic, actor=default_user) @pytest.mark.asyncio diff --git a/tests/test_sources.py b/tests/test_sources.py index 45a7f2d8..c1aa9972 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -150,17 +150,17 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient): @pytest.mark.parametrize( "file_path, expected_value, expected_label_regex", [ - ("tests/data/test.txt", "test", r"test\.txt"), - ("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper\.pdf"), - ("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning\.jsonl"), - ("tests/data/test.md", "h2 Heading", r"test\.md"), - ("tests/data/test.json", "glossary", r"test\.json"), - ("tests/data/react_component.jsx", "UserProfile", r"react_component\.jsx"), - ("tests/data/task_manager.java", "TaskManager", r"task_manager\.java"), - ("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures\.cpp"), - ("tests/data/api_server.go", "UserService", r"api_server\.go"), - ("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis\.py"), - ("tests/data/test.csv", "Smart Fridge Plus", r"test\.csv"), + ("tests/data/test.txt", "test", r"test_source/test\.txt"), + ("tests/data/memgpt_paper.pdf", "MemGPT", r"test_source/memgpt_paper\.pdf"), + ("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"test_source/toy_chat_fine_tuning\.jsonl"), + ("tests/data/test.md", "h2 Heading", r"test_source/test\.md"), + ("tests/data/test.json", "glossary", r"test_source/test\.json"), + ("tests/data/react_component.jsx", "UserProfile", r"test_source/react_component\.jsx"), + ("tests/data/task_manager.java", "TaskManager", r"test_source/task_manager\.java"), + ("tests/data/data_structures.cpp", 
"BinarySearchTree", r"test_source/data_structures\.cpp"), + ("tests/data/api_server.go", "UserService", r"test_source/api_server\.go"), + ("tests/data/data_analysis.py", "StatisticalAnalyzer", r"test_source/data_analysis\.py"), + ("tests/data/test.csv", "Smart Fridge Plus", r"test_source/test\.csv"), ], ) def test_file_upload_creates_source_blocks_correctly( @@ -229,7 +229,6 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC assert len(blocks) == 1 assert any("test" in b.value for b in blocks) assert any(b.value.startswith("[Viewing file start") for b in blocks) - assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks) # Detach the source client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id) @@ -239,7 +238,6 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC blocks = agent_state.memory.file_blocks assert len(blocks) == 0 assert not any("test" in b.value for b in blocks) - assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): @@ -261,7 +259,6 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a blocks = agent_state.memory.file_blocks assert len(blocks) == 1 assert any("test" in b.value for b in blocks) - assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks) # Remove file from source client.sources.delete(source_id=source.id) @@ -271,7 +268,6 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a blocks = agent_state.memory.file_blocks assert len(blocks) == 0 assert not any("test" in b.value for b in blocks) - assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState): @@ -314,7 +310,7 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, 
agent_stat messages=[ MessageCreate( role="user", - content=f"Use ONLY the open_files tool to open the file named {file.file_name} with offset {offset} and length {length}", + content=f"Use ONLY the open_files tool to open the file named test_source/{file.file_name} with offset {offset} and length {length}", ) ], ) @@ -556,7 +552,6 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le blocks = temp_agent_state.memory.file_blocks assert len(blocks) == 1 assert any(b.value.startswith("[Viewing file start (out of 554 chunks)]") for b in blocks) - assert any(re.fullmatch(r"long_test\.txt", b.label) for b in blocks) # Verify file tools were automatically attached file_tools = {tool.name for tool in temp_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} @@ -600,7 +595,7 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta messages=[ MessageCreate( role="user", - content=f"Use ONLY the open_files tool to open the file named {file.file_name} with offset {offset} and length {length}", + content=f"Use ONLY the open_files tool to open the file named test_source/{file.file_name} with offset {offset} and length {length}", ) ], ) @@ -651,7 +646,7 @@ def test_duplicate_file_renaming(client: LettaSDKClient): files.sort(key=lambda f: f.created_at) # Verify filenames follow the count-based pattern - expected_filenames = ["test.txt", "test (1).txt", "test (2).txt"] + expected_filenames = ["test.txt", "test_(1).txt", "test_(2).txt"] actual_filenames = [f.file_name for f in files] assert actual_filenames == expected_filenames, f"Expected {expected_filenames}, got {actual_filenames}" @@ -692,6 +687,7 @@ def test_open_files_schema_descriptions(client: LettaSDKClient): assert "# Lines 100-199" in description assert "# Entire file" in description assert "close_all_others=True" in description + assert "View specific portions of large files (e.g. 
functions or definitions)" in description # Check parameters structure assert "parameters" in schema @@ -742,6 +738,6 @@ def test_open_files_schema_descriptions(client: LettaSDKClient): # Check length field assert "length" in file_request_properties length_prop = file_request_properties["length"] - expected_length_desc = "Optional number of lines to view from offset. If not specified, views to end of file." + expected_length_desc = "Optional number of lines to view from offset (inclusive). If not specified, views to end of file." assert length_prop["description"] == expected_length_desc assert length_prop["type"] == "integer"