From 1058882ab87e8952f9c6f4265ca8f9ab03164082 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 30 May 2025 11:09:59 -0700 Subject: [PATCH] feat: Insert file blocks for agent on source attach (#2545) --- letta/server/rest_api/routers/v1/agents.py | 31 +++++-- letta/server/rest_api/routers/v1/sources.py | 18 ++--- letta/server/server.py | 89 +++++++++++++-------- letta/services/passage_manager.py | 14 ++++ tests/test_sources.py | 45 ++++++++++- 5 files changed, 147 insertions(+), 50 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 1a4be325..38aafef5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import CORE_MEMORY_SOURCE_CHAR_LIMIT, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.log import get_logger @@ -36,6 +36,7 @@ from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer from letta.services.telemetry_manager import NoopTelemetryManager from letta.settings import settings +from letta.utils import safe_create_task # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally @@ -301,7 +302,6 @@ async def detach_tool( async def attach_source( agent_id: str, source_id: str, - background_tasks: BackgroundTasks, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -309,11 +309,30 @@ async def attach_source( Attach a source to an agent. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - agent = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor) - if agent.enable_sleeptime: + agent_state = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor) + + files = await server.source_manager.list_files(source_id, actor) + texts = [] + filenames = [] + for f in files: + passages = await server.passage_manager.list_passages_by_file_id_async(file_id=f.id, actor=actor) + passage_text = "" + for p in passages: + if len(passage_text) <= CORE_MEMORY_SOURCE_CHAR_LIMIT: + passage_text += p.text + + texts.append(passage_text) + filenames.append(f.file_name) + + await server.insert_documents_into_context_window(agent_state=agent_state, texts=texts, filenames=filenames, actor=actor) + + if agent_state.enable_sleeptime: source = await server.source_manager.get_source_by_id(source_id=source_id) - background_tasks.add_task(server.sleeptime_document_ingest_async, agent, source, actor) - return agent + safe_create_task( + server.sleeptime_document_ingest_async(agent_state, source, actor), logger=logger, label="sleeptime_document_ingest_async" + ) + + return agent_state @router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index e74c77d9..9f90388d 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -168,6 +168,11 @@ async def upload_file_to_source( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.") bytes = file.file.read() + try: + text = bytes.decode("utf-8") + except Exception: + text = "" + # create job job = Job( user_id=actor.id, @@ -180,17 +185,15 @@ async def upload_file_to_source( # sanitize filename sanitized_filename = sanitize_filename(file.filename) + # Add blocks + await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=sanitized_filename, actor=actor) + # create background tasks safe_create_task( load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor), logger=logger, label="load_file_to_source_async", ) - safe_create_task( - insert_document_into_context_window_async(server, filename=sanitized_filename, source_id=source_id, actor=actor, bytes=bytes), - logger=logger, - label="insert_document_into_context_window_async", - ) safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async") job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor) @@ -278,8 +281,3 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac for agent in agents: if agent.enable_sleeptime: await server.sleeptime_document_ingest_async(agent, source, actor, clear_history) - - -async def insert_document_into_context_window_async(server: SyncServer, filename: str, source_id: str, actor: User, bytes: bytes): - source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) - await server.insert_document_into_context_window(source, bytes=bytes, filename=filename, actor=actor) diff --git a/letta/server/server.py b/letta/server/server.py index 56a2ada9..520d2e8f 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1363,46 +1363,71 @@ class SyncServer(Server): ) await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor) - async def insert_document_into_context_window(self, source: Source, bytes: bytes, filename: str, actor: User) -> None: + async def _upsert_document_block(self, agent_id: str, text: str, filename: str, actor: User) -> None: + """ + Internal method to create or update a document block for an agent. + This is the shared logic between single and multiple document insertion. + """ + truncated_text = text[:CORE_MEMORY_SOURCE_CHAR_LIMIT] + + try: + block = await self.agent_manager.get_block_with_label_async( + agent_id=agent_id, + block_label=filename, + actor=actor, + ) + await self.block_manager.update_block_async( + block_id=block.id, + block_update=BlockUpdate(value=truncated_text), + actor=actor, + ) + except NoResultFound: + block = await self.block_manager.create_or_update_block_async( + block=Block( + value=truncated_text, + label=filename, + description=f"Contains the parsed contents of external file {filename}", + limit=CORE_MEMORY_SOURCE_CHAR_LIMIT, + ), + actor=actor, + ) + await self.agent_manager.attach_block_async( + agent_id=agent_id, + block_id=block.id, + actor=actor, + ) + + async def insert_document_into_context_windows(self, source_id: str, text: str, filename: str, actor: User) -> None: """ Insert the uploaded document into the context window of all agents attached to the given source. """ - agent_states = await self.source_manager.list_attached_agents(source_id=source.id, actor=actor) - logger.info(f"Inserting document into context window for source: {source}") + agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor) + + # Return early + if not agent_states: + return + + logger.info(f"Inserting document into context window for source: {source_id}") logger.info(f"Attached agents: {[a.id for a in agent_states]}") - passages = bytes.decode("utf-8")[:CORE_MEMORY_SOURCE_CHAR_LIMIT] + await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states)) - async def process_agent(agent_state): - try: - block = await self.agent_manager.get_block_with_label_async( - agent_id=agent_state.id, - block_label=filename, - actor=actor, - ) - await self.block_manager.update_block_async( - block_id=block.id, - block_update=BlockUpdate(value=passages), - actor=actor, - ) - except NoResultFound: - block = await self.block_manager.create_or_update_block_async( - block=Block( - value=passages, - label=filename, - description="Contains recursive summarizations of the conversation so far", - limit=CORE_MEMORY_SOURCE_CHAR_LIMIT, - ), - actor=actor, - ) - await self.agent_manager.attach_block_async( - agent_id=agent_state.id, - block_id=block.id, - actor=actor, - ) + async def insert_documents_into_context_window( + self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User + ) -> None: + """ + Insert the uploaded documents into the context window of an agent + attached to the given source. + """ + logger.info(f"Inserting documents into context window for agent_state: {agent_state.id}") - await asyncio.gather(*(process_agent(agent) for agent in agent_states)) + if len(texts) != len(filenames): + raise ValueError(f"Mismatch between number of texts ({len(texts)}) and filenames ({len(filenames)})") + + await asyncio.gather( + *(self._upsert_document_block(agent_state.id, text, filename, actor) for text, filename in zip(texts, filenames)) + ) async def create_document_sleeptime_agent_async( self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 4b233ddb..99f7d6c4 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -5,6 +5,7 @@ from typing import List, Optional from async_lru import alru_cache from openai import AsyncOpenAI, OpenAI +from sqlalchemy import select from letta.constants import MAX_EMBEDDING_DIM from letta.embeddings import embedding_model, parse_and_chunk_text @@ -448,3 +449,16 @@ class PassageManager: BYTES_PER_EMBEDDING_DIM = 4 GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING + + @enforce_types + @trace_method + async def list_passages_by_file_id_async(self, file_id: str, actor: PydanticUser) -> List[PydanticPassage]: + """ + List all source passages associated with a given file_id. + """ + async with db_registry.async_session() as session: + result = await session.execute( + select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id) + ) + passages = result.scalars().all() + return [p.to_pydantic() for p in passages] diff --git a/tests/test_sources.py b/tests/test_sources.py index 156b4492..4e9d4e96 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -38,7 +38,7 @@ def client() -> LettaSDKClient: yield client -@pytest.fixture(scope="module") +@pytest.fixture def agent_state(client: LettaSDKClient): agent_state = client.agents.create( memory_blocks=[ @@ -90,7 +90,48 @@ def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, age assert len(files) == 1 assert files[0].source_id == source.id - # Get the agent state + # Get the agent state, check blocks exist + blocks = client.agents.blocks.list(agent_id=agent_state.id) + assert len(blocks) == 2 + assert "test" in [b.value for b in blocks] + assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + + +def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): + # Clear existing sources + for source in client.sources.list(): + client.sources.delete(source_id=source.id) + + # Clear existing jobs + for job in client.jobs.list(): + client.jobs.delete(job_id=job.id) + + # Create a new source + source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + assert len(client.sources.list()) == 1 + + # Load files into the source + file_path = "tests/data/test.txt" + + # Upload the files + with open(file_path, "rb") as f: + job = client.sources.files.upload(source_id=source.id, file=f) + + # Wait for the jobs to complete + while job.status != "completed": + time.sleep(1) + job = client.jobs.retrieve(job_id=job.id) + print("Waiting for jobs to complete...", job.status) + + # Get the first file with pagination + files = client.sources.files.list(source_id=source.id, limit=1) + assert len(files) == 1 + assert files[0].source_id == source.id + + # Attach after uploading the file + client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + + # Get the agent state, check blocks exist blocks = client.agents.blocks.list(agent_id=agent_state.id) assert len(blocks) == 2 assert "test" in [b.value for b in blocks]