From 971433f9d5e5ec739b4815346e4752b76ee532ac Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 5 Jun 2025 17:02:28 -0700 Subject: [PATCH] feat: Add line metadata and warnings to file blocks (#2663) --- letta/orm/files_agents.py | 2 +- letta/services/agent_manager.py | 16 ++- .../file_processor/chunker/line_chunker.py | 34 ++++++ .../services/file_processor/file_processor.py | 8 +- letta/services/files_agents_manager.py | 37 +++++++ .../tool_executor/files_tool_executor.py | 13 +-- tests/data/lines_1_to_100.txt | 100 +++++++++++++++++ tests/test_sources.py | 104 ++++++++++++------ 8 files changed, 267 insertions(+), 47 deletions(-) create mode 100644 letta/services/file_processor/chunker/line_chunker.py create mode 100644 tests/data/lines_1_to_100.txt diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py index a03b8334..02856635 100644 --- a/letta/orm/files_agents.py +++ b/letta/orm/files_agents.py @@ -69,7 +69,7 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): # Truncate content and add warnings here when converting from FileAgent to Block if len(visible_content) > CORE_MEMORY_SOURCE_CHAR_LIMIT: - truncated_warning = f"\n{FILE_IS_TRUNCATED_WARNING}" + truncated_warning = f"...[TRUNCATED]\n{FILE_IS_TRUNCATED_WARNING}" visible_content = visible_content[: CORE_MEMORY_SOURCE_CHAR_LIMIT - len(truncated_warning)] visible_content += truncated_warning diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 6816e178..2f6c5612 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -63,6 +63,7 @@ from letta.server.db import db_registry from letta.services.block_manager import BlockManager from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter +from letta.services.files_agents_manager import FileAgentManager from 
letta.services.helpers.agent_manager_helper import ( _apply_filters, _apply_identity_filters, @@ -101,6 +102,7 @@ class AgentManager: self.message_manager = MessageManager() self.passage_manager = PassageManager() self.identity_manager = IdentityManager() + self.file_agent_manager = FileAgentManager() @staticmethod def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]: @@ -1659,12 +1661,18 @@ class AgentManager: @trace_method @enforce_types async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: + # TODO: This will NOT work for new blocks/file blocks added intra-step block_ids = [b.id for b in agent_state.memory.blocks] - if not block_ids: - return agent_state + file_block_names = [b.label for b in agent_state.memory.file_blocks] + + if block_ids: + blocks = await self.block_manager.get_all_blocks_by_ids_async(block_ids=[b.id for b in agent_state.memory.blocks], actor=actor) + agent_state.memory.blocks = [b for b in blocks if b is not None] + + if file_block_names: + file_blocks = await self.file_agent_manager.get_all_file_blocks_by_name(file_names=file_block_names, actor=actor) + agent_state.memory.file_blocks = [b for b in file_blocks if b is not None] - blocks = await self.block_manager.get_all_blocks_by_ids_async(block_ids=[b.id for b in agent_state.memory.blocks], actor=actor) - agent_state.memory.blocks = [b for b in blocks if b is not None] return agent_state # ====================================================================================================================== diff --git a/letta/services/file_processor/chunker/line_chunker.py b/letta/services/file_processor/chunker/line_chunker.py new file mode 100644 index 00000000..fcda9516 --- /dev/null +++ b/letta/services/file_processor/chunker/line_chunker.py @@ -0,0 +1,34 @@ +from typing import List, Optional + +from letta.log import get_logger + +logger = get_logger(__name__) + + 
+class LineChunker: + """Newline chunker""" + + def __init__(self): + pass + + # TODO: Make this more general beyond Mistral + def chunk_text(self, text: str, start: Optional[int] = None, end: Optional[int] = None) -> List[str]: + """Split lines""" + content_lines = [line.strip() for line in text.split("\n") if line.strip()] + total_lines = len(content_lines) + + if start and end: + content_lines = content_lines[start:end] + line_offset = start + else: + line_offset = 0 + + content_lines = [f"Line {i + line_offset}: {line}" for i, line in enumerate(content_lines)] + + # Add metadata about total lines + if start and end: + content_lines.insert(0, f"[Viewing lines {start} to {end} (out of {total_lines} lines)]") + else: + content_lines.insert(0, f"[Viewing file start (out of {total_lines} lines)]") + + return content_lines diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index a580ff7a..dde76546 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -11,6 +11,7 @@ from letta.schemas.job import Job, JobUpdate from letta.schemas.passage import Passage from letta.schemas.user import User from letta.server.server import SyncServer +from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder from letta.services.file_processor.parser.mistral_parser import MistralFileParser @@ -34,6 +35,7 @@ class FileProcessor: ): self.file_parser = file_parser self.text_chunker = text_chunker + self.line_chunker = LineChunker() self.embedder = embedder self.max_file_size = max_file_size self.source_manager = SourceManager() @@ -90,9 +92,13 @@ class FileProcessor: logger.info(f"Successfully processed {filename}: {len(all_passages)} passages") + # TODO: Rethink this line chunking 
mechanism + content_lines = self.line_chunker.chunk_text(text=raw_markdown_text) + visible_content = "\n".join(content_lines) + await server.insert_file_into_context_windows( source_id=source_id, - text=raw_markdown_text, + text=visible_content, file_id=file_metadata.id, file_name=file_metadata.file_name, actor=self.actor, diff --git a/letta/services/files_agents_manager.py b/letta/services/files_agents_manager.py index 5a03f108..cb174864 100644 --- a/letta/services/files_agents_manager.py +++ b/letta/services/files_agents_manager.py @@ -5,6 +5,7 @@ from sqlalchemy import and_, func, select, update from letta.orm.errors import NoResultFound from letta.orm.files_agents import FileAgent as FileAgentModel +from letta.schemas.block import Block as PydanticBlock from letta.schemas.file import FileAgent as PydanticFileAgent from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -142,6 +143,42 @@ class FileAgentManager: except NoResultFound: return None + @enforce_types + @trace_method + async def get_all_file_blocks_by_name( + self, + *, + file_names: List[str], + actor: PydanticUser, + ) -> List[PydanticBlock]: + """ + Retrieve the file blocks for multiple FileAgent associations by file name in a single query. 
+ + Args: + file_names: List of file names to retrieve + actor: The user making the request + + Returns: + List of PydanticBlock objects found (may be fewer than requested if some file names don't exist) + """ + if not file_names: + return [] + + async with db_registry.async_session() as session: + # Use IN clause for efficient bulk retrieval + query = select(FileAgentModel).where( + and_( + FileAgentModel.file_name.in_(file_names), + FileAgentModel.organization_id == actor.organization_id, + ) + ) + + # Execute query and get all results + rows = (await session.execute(query)).scalars().all() + + # Convert to Pydantic models + return [row.to_pydantic_block() for row in rows] + @enforce_types @trace_method async def get_file_agent_by_file_name(self, *, agent_id: str, file_name: str, actor: PydanticUser) -> Optional[PydanticFileAgent]: diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py index b88a496e..15d8f00a 100644 --- a/letta/services/tool_executor/files_tool_executor.py +++ b/letta/services/tool_executor/files_tool_executor.py @@ -7,6 +7,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.files_agents_manager import FileAgentManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager @@ -101,17 +102,15 @@ class LettaFileToolExecutor(ToolExecutor): file = await self.source_manager.get_file_by_id(file_id=file_id, actor=self.actor, include_content=True) # TODO: Inefficient, maybe we can pre-compute this - content_lines = [ - line.strip() for line in file.content.split("\n") if line.strip() # remove leading/trailing whitespace # skip empty lines - ] - - if start and end: - 
content_lines = content_lines[start:end] - + # TODO: This is also not the best way to split things - would be cool to have "content aware" splitting + # TODO: Split code differently from large text blurbs + content_lines = LineChunker().chunk_text(text=file.content, start=start, end=end) visible_content = "\n".join(content_lines) + await self.files_agents_manager.update_file_agent_by_id( agent_id=agent_state.id, file_id=file_id, actor=self.actor, is_open=True, visible_content=visible_content ) + return "Success" async def close_file(self, agent_state: AgentState, file_name: str) -> str: diff --git a/tests/data/lines_1_to_100.txt b/tests/data/lines_1_to_100.txt new file mode 100644 index 00000000..b9ed43de --- /dev/null +++ b/tests/data/lines_1_to_100.txt @@ -0,0 +1,100 @@ +Line 1 +Line 2 +Line 3 +Line 4 +Line 5 +Line 6 +Line 7 +Line 8 +Line 9 +Line 10 +Line 11 +Line 12 +Line 13 +Line 14 +Line 15 +Line 16 +Line 17 +Line 18 +Line 19 +Line 20 +Line 21 +Line 22 +Line 23 +Line 24 +Line 25 +Line 26 +Line 27 +Line 28 +Line 29 +Line 30 +Line 31 +Line 32 +Line 33 +Line 34 +Line 35 +Line 36 +Line 37 +Line 38 +Line 39 +Line 40 +Line 41 +Line 42 +Line 43 +Line 44 +Line 45 +Line 46 +Line 47 +Line 48 +Line 49 +Line 50 +Line 51 +Line 52 +Line 53 +Line 54 +Line 55 +Line 56 +Line 57 +Line 58 +Line 59 +Line 60 +Line 61 +Line 62 +Line 63 +Line 64 +Line 65 +Line 66 +Line 67 +Line 68 +Line 69 +Line 70 +Line 71 +Line 72 +Line 73 +Line 74 +Line 75 +Line 76 +Line 77 +Line 78 +Line 79 +Line 80 +Line 81 +Line 82 +Line 83 +Line 84 +Line 85 +Line 86 +Line 87 +Line 88 +Line 89 +Line 90 +Line 91 +Line 92 +Line 93 +Line 94 +Line 95 +Line 96 +Line 97 +Line 98 +Line 99 +Line 100 \ No newline at end of file diff --git a/tests/test_sources.py b/tests/test_sources.py index 3cd5e671..5fdd69d1 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -12,7 +12,6 @@ from letta_client.types import AgentState from letta.constants import FILES_TOOLS from letta.orm.enums import ToolType from 
letta.schemas.message import MessageCreate -from tests.helpers.utils import retry_until_success from tests.utils import wait_for_server # Constants @@ -269,50 +268,36 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) -@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - print("Creating new source...") source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") - print(f"Created source with ID: {source.id}") sources_list = client.sources.list() assert len(sources_list) == 1 - print(f"✓ Verified source creation - found {len(sources_list)} source(s)") # Attach source to agent - print(f"Attaching source {source.id} to agent {agent_state.id}...") client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) - print("✓ Source attached to agent") # Load files into the source file_path = "tests/data/long_test.txt" - print(f"Uploading file: {file_path}") # Upload the files with open(file_path, "rb") as f: job = client.sources.files.upload(source_id=source.id, file=f) - print(f"File upload job created with ID: {job.id}, initial status: {job.status}") - # Wait for the jobs to complete while job.status != "completed": print(f"Waiting for job {job.id} to complete... 
Current status: {job.status}") time.sleep(1) job = client.jobs.retrieve(job_id=job.id) - print(f"✓ Job completed successfully with status: {job.status}") - # Get uploaded files - print("Retrieving uploaded files...") files = client.sources.files.list(source_id=source.id, limit=1) assert len(files) == 1 assert files[0].source_id == source.id file = files[0] - print(f"✓ Found uploaded file: {file.file_name} (ID: {file.id})") # Check that file is opened initially - print("Checking initial agent state...") agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks print(f"Agent has {len(blocks)} file block(s)") @@ -321,7 +306,6 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat print(f"Initial file content length: {initial_content_length} characters") print(f"First 100 chars of content: {blocks[0].value[:100]}...") assert initial_content_length > 10, f"Expected file content > 10 chars, got {initial_content_length}" - print("✓ File appears to be initially loaded") # Ask agent to close the file print(f"Requesting agent to close file: {file.file_name}") @@ -333,13 +317,11 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat print(close_response.messages) # Check that file is closed - print("Verifying file is closed...") agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks closed_content_length = len(blocks[0].value) if blocks else 0 print(f"File content length after close: {closed_content_length} characters") assert closed_content_length == 0, f"Expected empty content after close, got {closed_content_length} chars" - print("✓ File successfully closed") # Ask agent to open the file for a specific range start, end = 0, 5 @@ -356,7 +338,6 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat print(open_response1.messages) # Check that file is opened - print("Verifying file is opened with 
first range...") agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks old_value = blocks[0].value @@ -364,11 +345,9 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat print(f"File content length after first open: {old_content_length} characters") print(f"First range content: '{old_value}'") assert old_content_length > 10, f"Expected content > 10 chars for range [{start}, {end}], got {old_content_length}" - print("✓ File successfully opened with first range") # Ask agent to open the file for a different range start, end = 5, 10 - print(f"Requesting agent to open file for different range [{start}, {end}]") open_response2 = client.agents.messages.create( agent_id=agent_state.id, messages=[ @@ -398,21 +377,15 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat print("✓ File successfully opened with different range - content differs as expected") -@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - print("Creating new source...") source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") - print(f"Created source with ID: {source.id}") sources_list = client.sources.list() assert len(sources_list) == 1 - print(f"✓ Verified source creation - found {len(sources_list)} source(s)") # Attach source to agent - print(f"Attaching source {source.id} to agent {agent_state.id}...") client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) - print("✓ Source attached to agent") # Load files into the source file_path = "tests/data/long_test.txt" @@ -430,18 +403,13 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: time.sleep(1) job = client.jobs.retrieve(job_id=job.id) - print(f"✓ Job completed successfully with status: {job.status}") - # Get uploaded files - 
print("Retrieving uploaded files...") files = client.sources.files.list(source_id=source.id, limit=1) assert len(files) == 1 assert files[0].source_id == source.id - file = files[0] - print(f"✓ Found uploaded file: {file.file_name} (ID: {file.id})") + files[0] # Check that file is opened initially - print("Checking initial agent state...") agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks print(f"Agent has {len(blocks)} file block(s)") @@ -453,7 +421,6 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: print("✓ File appears to be initially loaded") # Ask agent to use the search_files tool - print(f"Requesting agent to search_files") search_files_response = client.agents.messages.create( agent_id=agent_state.id, messages=[ @@ -472,3 +439,72 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: tool_returns = [msg for msg in search_files_response.messages if msg.message_type == "tool_return_message"] assert len(tool_returns) > 0, "No tool returns found" assert all(tr.status == "success" for tr in tool_returns), "Tool call failed" + + +def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState): + # Create a new source + source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + + sources_list = client.sources.list() + assert len(sources_list) == 1 + + # Attach source to agent + client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + + # Load files into the source + file_path = "tests/data/lines_1_to_100.txt" + + # Upload the files + with open(file_path, "rb") as f: + job = client.sources.files.upload(source_id=source.id, file=f) + + # Wait for the jobs to complete + while job.status != "completed": + print(f"Waiting for job {job.id} to complete... 
Current status: {job.status}") + time.sleep(1) + job = client.jobs.retrieve(job_id=job.id) + + # Get uploaded files + files = client.sources.files.list(source_id=source.id, limit=1) + assert len(files) == 1 + assert files[0].source_id == source.id + file = files[0] + + # Check that file is opened initially + agent_state = client.agents.retrieve(agent_id=agent_state.id) + blocks = agent_state.memory.file_blocks + assert len(blocks) == 1 + block = blocks[0] + assert block.value.startswith("[Viewing file start (out of 100 lines)]") + + # Open a specific range + start = 50 + end = 55 + open_response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", content=f"Use ONLY the open_file tool to open the file named {file.file_name} for view range [{start}, {end}]" + ) + ], + ) + print(f"Open request sent, got {len(open_response.messages)} message(s) in response") + print(open_response.messages) + + # Check that file is opened correctly + agent_state = client.agents.retrieve(agent_id=agent_state.id) + blocks = agent_state.memory.file_blocks + assert len(blocks) == 1 + block = blocks[0] + print(block.value) + assert ( + block.value + == """ + [Viewing lines 50 to 55 (out of 100 lines)] +Line 50: Line 51 +Line 51: Line 52 +Line 52: Line 53 +Line 53: Line 54 +Line 54: Line 55 + """.strip() + )