feat: Insert file blocks for agent on source attach (#2545)

This commit is contained in:
Matthew Zhou
2025-05-30 11:09:59 -07:00
committed by GitHub
parent 7f82ce9adf
commit 1058882ab8
5 changed files with 147 additions and 50 deletions

View File

@@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import CORE_MEMORY_SOURCE_CHAR_LIMIT, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
@@ -36,6 +36,7 @@ from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.telemetry_manager import NoopTelemetryManager
from letta.settings import settings
from letta.utils import safe_create_task
# These can be forward refs, but because FastAPI needs them at runtime they must be imported normally
@@ -301,7 +302,6 @@ async def detach_tool(
async def attach_source(
    agent_id: str,
    source_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Attach a source to an agent.

    For every file already uploaded to the source, a document block is
    inserted into the agent's context window. If the agent has sleeptime
    enabled, document ingestion is additionally kicked off as a background
    task.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    agent_state = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor)

    # Build one (text, filename) pair per file in the source so each file
    # gets its own labeled block in the agent's context window.
    files = await server.source_manager.list_files(source_id, actor)
    texts = []
    filenames = []
    for f in files:
        passages = await server.passage_manager.list_passages_by_file_id_async(file_id=f.id, actor=actor)
        passage_text = ""
        for p in passages:
            # Stop accumulating once past the core-memory limit; the block
            # value is truncated to exactly the limit downstream.
            if len(passage_text) <= CORE_MEMORY_SOURCE_CHAR_LIMIT:
                passage_text += p.text
        texts.append(passage_text)
        filenames.append(f.file_name)

    await server.insert_documents_into_context_window(agent_state=agent_state, texts=texts, filenames=filenames, actor=actor)

    if agent_state.enable_sleeptime:
        source = await server.source_manager.get_source_by_id(source_id=source_id)
        # Fire-and-forget: ingestion runs in the background and any failure
        # is logged by safe_create_task rather than failing this request.
        safe_create_task(
            server.sleeptime_document_ingest_async(agent_state, source, actor), logger=logger, label="sleeptime_document_ingest_async"
        )
    return agent_state
@router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent")

View File

@@ -168,6 +168,11 @@ async def upload_file_to_source(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
bytes = file.file.read()
try:
text = bytes.decode("utf-8")
except Exception:
text = ""
# create job
job = Job(
user_id=actor.id,
@@ -180,17 +185,15 @@ async def upload_file_to_source(
# sanitize filename
sanitized_filename = sanitize_filename(file.filename)
# Add blocks
await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=sanitized_filename, actor=actor)
# create background tasks
safe_create_task(
load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor),
logger=logger,
label="load_file_to_source_async",
)
safe_create_task(
insert_document_into_context_window_async(server, filename=sanitized_filename, source_id=source_id, actor=actor, bytes=bytes),
logger=logger,
label="insert_document_into_context_window_async",
)
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
@@ -278,8 +281,3 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
for agent in agents:
if agent.enable_sleeptime:
await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
async def insert_document_into_context_window_async(server: SyncServer, filename: str, source_id: str, actor: User, bytes: bytes):
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
await server.insert_document_into_context_window(source, bytes=bytes, filename=filename, actor=actor)

View File

@@ -1363,46 +1363,71 @@ class SyncServer(Server):
)
await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
async def insert_document_into_context_window(self, source: Source, bytes: bytes, filename: str, actor: User) -> None:
async def _upsert_document_block(self, agent_id: str, text: str, filename: str, actor: User) -> None:
    """
    Create or update a document block for an agent.

    This is the shared logic between single and multiple document insertion:
    the block is labeled with the file's name and holds a copy of the file's
    text truncated to CORE_MEMORY_SOURCE_CHAR_LIMIT.
    """
    truncated_text = text[:CORE_MEMORY_SOURCE_CHAR_LIMIT]
    try:
        # EAFP: if a block with this filename label is already attached to
        # the agent, just update its value in place.
        block = await self.agent_manager.get_block_with_label_async(
            agent_id=agent_id,
            block_label=filename,
            actor=actor,
        )
        await self.block_manager.update_block_async(
            block_id=block.id,
            block_update=BlockUpdate(value=truncated_text),
            actor=actor,
        )
    except NoResultFound:
        # No block with this label yet: create one and attach it.
        # NOTE(review): "(unknown)" in the description reads like a
        # placeholder — confirm whether it should interpolate the filename.
        block = await self.block_manager.create_or_update_block_async(
            block=Block(
                value=truncated_text,
                label=filename,
                description="Contains the parsed contents of external file (unknown)",
                limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
            ),
            actor=actor,
        )
        await self.agent_manager.attach_block_async(
            agent_id=agent_id,
            block_id=block.id,
            actor=actor,
        )
async def insert_document_into_context_windows(self, source_id: str, text: str, filename: str, actor: User) -> None:
    """
    Insert the uploaded document into the context window of all agents
    attached to the given source.
    """
    agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)

    # Return early when no agent is attached to this source.
    if not agent_states:
        return

    logger.info(f"Inserting document into context window for source: {source_id}")
    logger.info(f"Attached agents: {[a.id for a in agent_states]}")

    # Upsert one document block per attached agent, concurrently.
    await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states))
async def process_agent(agent_state):
try:
block = await self.agent_manager.get_block_with_label_async(
agent_id=agent_state.id,
block_label=filename,
actor=actor,
)
await self.block_manager.update_block_async(
block_id=block.id,
block_update=BlockUpdate(value=passages),
actor=actor,
)
except NoResultFound:
block = await self.block_manager.create_or_update_block_async(
block=Block(
value=passages,
label=filename,
description="Contains recursive summarizations of the conversation so far",
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
),
actor=actor,
)
await self.agent_manager.attach_block_async(
agent_id=agent_state.id,
block_id=block.id,
actor=actor,
)
async def insert_documents_into_context_window(
    self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User
) -> None:
    """
    Insert the uploaded documents into the context window of an agent
    attached to the given source.

    Raises:
        ValueError: if `texts` and `filenames` differ in length.
    """
    logger.info(f"Inserting documents into context window for agent_state: {agent_state.id}")

    # texts[i] pairs with filenames[i]; a length mismatch is a caller bug.
    if len(texts) != len(filenames):
        raise ValueError(f"Mismatch between number of texts ({len(texts)}) and filenames ({len(filenames)})")

    # Upsert one document block per (text, filename) pair, concurrently.
    await asyncio.gather(
        *(self._upsert_document_block(agent_state.id, text, filename, actor) for text, filename in zip(texts, filenames))
    )
async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False

View File

@@ -5,6 +5,7 @@ from typing import List, Optional
from async_lru import alru_cache
from openai import AsyncOpenAI, OpenAI
from sqlalchemy import select
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model, parse_and_chunk_text
@@ -448,3 +449,16 @@ class PassageManager:
BYTES_PER_EMBEDDING_DIM = 4
GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM
return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
@enforce_types
@trace_method
async def list_passages_by_file_id_async(self, file_id: str, actor: PydanticUser) -> List[PydanticPassage]:
    """Fetch every source passage for *file_id* within the actor's organization."""
    # Build the statement up front; scoping by organization_id prevents
    # cross-organization reads.
    query = (
        select(SourcePassage)
        .where(SourcePassage.file_id == file_id)
        .where(SourcePassage.organization_id == actor.organization_id)
    )
    async with db_registry.async_session() as session:
        rows = (await session.execute(query)).scalars().all()
    return [row.to_pydantic() for row in rows]

View File

@@ -38,7 +38,7 @@ def client() -> LettaSDKClient:
yield client
@pytest.fixture(scope="module")
@pytest.fixture
def agent_state(client: LettaSDKClient):
agent_state = client.agents.create(
memory_blocks=[
@@ -90,7 +90,48 @@ def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, age
assert len(files) == 1
assert files[0].source_id == source.id
# Get the agent state
# Get the agent state, check blocks exist
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert "test" in [b.value for b in blocks]
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
# Clear existing sources
for source in client.sources.list():
client.sources.delete(source_id=source.id)
# Clear existing jobs
for job in client.jobs.list():
client.jobs.delete(job_id=job.id)
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
assert len(client.sources.list()) == 1
# Load files into the source
file_path = "tests/data/test.txt"
# Upload the files
with open(file_path, "rb") as f:
job = client.sources.files.upload(source_id=source.id, file=f)
# Wait for the jobs to complete
while job.status != "completed":
time.sleep(1)
job = client.jobs.retrieve(job_id=job.id)
print("Waiting for jobs to complete...", job.status)
# Get the first file with pagination
files = client.sources.files.list(source_id=source.id, limit=1)
assert len(files) == 1
assert files[0].source_id == source.id
# Attach after uploading the file
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
# Get the agent state, check blocks exist
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert "test" in [b.value for b in blocks]