feat: Adapt lifecycle management of file <-> agent association (#2591)

This commit is contained in:
Matthew Zhou
2025-06-02 17:34:12 -07:00
committed by GitHub
parent 6821765687
commit 1c14d25fe6
7 changed files with 102 additions and 110 deletions

View File

@@ -0,0 +1,29 @@
"""Add unique constraint to file_id and agent_id on file_agent
Revision ID: 614c4e53b66e
Revises: 0b496eae90de
Create Date: 2025-06-02 17:03:58.879839
"""
from typing import Sequence, Union
from alembic import op
# Revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the revision it
# revises; both match the "Revision ID" / "Revises" fields in the module docstring.
revision: str = "614c4e53b66e"
down_revision: Union[str, None] = "0b496eae90de"
# This migration declares no branch labels and no dependencies on other revisions.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the ``uq_files_agents_file_agent`` unique constraint.

    The constraint covers the (``file_id``, ``agent_id``) column pair of the
    ``files_agents`` table.
    """
    table_name = "files_agents"
    paired_columns = ["file_id", "agent_id"]
    op.create_unique_constraint("uq_files_agents_file_agent", table_name, paired_columns)
def downgrade() -> None:
    """Drop the ``uq_files_agents_file_agent`` unique constraint from ``files_agents``."""
    constraint_kind = "unique"
    op.drop_constraint("uq_files_agents_file_agent", "files_agents", type_=constraint_kind)

View File

@@ -2,7 +2,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, func
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.mixins import OrganizationMixin
@@ -22,7 +22,10 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
"""
__tablename__ = "files_agents"
__table_args__ = (Index("ix_files_agents_file_id_agent_id", "file_id", "agent_id"),)
__table_args__ = (
Index("ix_files_agents_file_id_agent_id", "file_id", "agent_id"),
UniqueConstraint("file_id", "agent_id", name="uq_files_agents_file_agent"),
)
__pydantic_model__ = PydanticFileAgent
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase

View File

@@ -313,7 +313,7 @@ async def attach_source(
files = await server.source_manager.list_files(source_id, actor)
texts = []
filenames = []
file_ids = []
for f in files:
passages = await server.passage_manager.list_passages_by_file_id_async(file_id=f.id, actor=actor)
passage_text = ""
@@ -322,9 +322,9 @@ async def attach_source(
passage_text += p.text
texts.append(passage_text)
filenames.append(f.file_name)
file_ids.append(f.id)
await server.insert_documents_into_context_window(agent_state=agent_state, texts=texts, filenames=filenames, actor=actor)
await server.insert_files_into_context_window(agent_state=agent_state, texts=texts, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
source = await server.source_manager.get_source_by_id(source_id=source_id)
@@ -348,8 +348,8 @@ async def detach_source(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
agent_state = await server.agent_manager.detach_source_async(agent_id=agent_id, source_id=source_id, actor=actor)
files = await server.source_manager.list_files(source_id, actor)
filenames = [f.file_name for f in files]
await server.remove_documents_from_context_window(agent_state=agent_state, filenames=filenames, actor=actor)
file_ids = [f.id for f in files]
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
try:

View File

@@ -150,10 +150,10 @@ async def delete_source(
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
files = await server.source_manager.list_files(source_id, actor)
filenames = [f.file_name for f in files]
file_ids = [f.id for f in files]
for agent_state in agent_states:
await server.remove_documents_from_context_window(agent_state=agent_state, filenames=filenames, actor=actor)
await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
if agent_state.enable_sleeptime:
try:
@@ -212,11 +212,6 @@ async def upload_file_to_source(
# sanitize filename
file.filename = sanitize_filename(file.filename)
try:
text = content.decode("utf-8")
except Exception:
text = "[Currently parsing...]"
# create job
job = Job(
user_id=actor.id,
@@ -225,8 +220,8 @@ async def upload_file_to_source(
)
job = await server.job_manager.create_job_async(job, actor=actor)
# Add blocks (sometimes without content, for UX purposes)
agent_states = await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=file.filename, actor=actor)
# TODO: Do we need to pull in the full agent_states? Can probably simplify here right?
agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
# NEW: Cloud based file processing
if settings.mistral_api_key and model_settings.openai_api_key:
@@ -301,8 +296,7 @@ async def delete_file_from_source(
deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor)
# Remove blocks
await server.remove_document_from_context_windows(source_id=source_id, filename=deleted_file.file_name, actor=actor)
await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
if deleted_file is None:

View File

@@ -1368,55 +1368,28 @@ class SyncServer(Server):
)
await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor)
async def _upsert_document_block(self, agent_id: str, text: str, filename: str, actor: User) -> None:
async def _upsert_file_to_agent(self, agent_id: str, text: str, file_id: str, actor: User) -> None:
"""
Internal method to create or update a document block for an agent.
Internal method to create or update a file <-> agent association
"""
truncated_text = text[:CORE_MEMORY_SOURCE_CHAR_LIMIT]
await self.file_agent_manager.attach_file(agent_id=agent_id, file_id=file_id, actor=actor, visible_content=truncated_text)
try:
block = await self.agent_manager.get_block_with_label_async(
agent_id=agent_id,
block_label=filename,
actor=actor,
)
await self.block_manager.update_block_async(
block_id=block.id,
block_update=BlockUpdate(value=truncated_text),
actor=actor,
)
except NoResultFound:
block = await self.block_manager.create_or_update_block_async(
block=Block(
value=truncated_text,
label=filename,
description=f"Contains the parsed contents of external file {filename}",
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
),
actor=actor,
)
await self.agent_manager.attach_block_async(
agent_id=agent_id,
block_id=block.id,
actor=actor,
)
async def _remove_document_block(self, agent_id: str, filename: str, actor: User) -> None:
async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None:
"""
Internal method to remove a file <-> agent association
"""
try:
block = await self.agent_manager.get_block_with_label_async(
await self.file_agent_manager.detach_file(
agent_id=agent_id,
block_label=filename,
file_id=file_id,
actor=actor,
)
await self.block_manager.delete_block_async(block_id=block.id, actor=actor)
except NoResultFound:
logger.info(f"Document block with label {filename} already removed, skipping...")
logger.info(f"File {file_id} already removed from agent {agent_id}, skipping...")
async def insert_document_into_context_windows(
self, source_id: str, text: str, filename: str, actor: User, agent_states: Optional[List[AgentState]] = None
async def insert_file_into_context_windows(
self, source_id: str, text: str, file_id: str, actor: User, agent_states: Optional[List[AgentState]] = None
) -> List[AgentState]:
"""
Insert the uploaded document into the context window of all agents
@@ -1431,51 +1404,48 @@ class SyncServer(Server):
logger.info(f"Inserting document into context window for source: {source_id}")
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states))
await asyncio.gather(*(self._upsert_file_to_agent(agent_state.id, text, file_id, actor) for agent_state in agent_states))
return agent_states
async def insert_documents_into_context_window(
self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User
) -> None:
async def insert_files_into_context_window(self, agent_state: AgentState, texts: List[str], file_ids: List[str], actor: User) -> None:
"""
Insert the uploaded documents into the context window of an agent
attached to the given source.
"""
logger.info(f"Inserting documents into context window for agent_state: {agent_state.id}")
if len(texts) != len(filenames):
raise ValueError(f"Mismatch between number of texts ({len(texts)}) and filenames ({len(filenames)})")
if len(texts) != len(file_ids):
raise ValueError(f"Mismatch between number of texts ({len(texts)}) and file ids ({len(file_ids)})")
await asyncio.gather(
*(self._upsert_document_block(agent_state.id, text, filename, actor) for text, filename in zip(texts, filenames))
)
await asyncio.gather(*(self._upsert_file_to_agent(agent_state.id, text, file_id, actor) for text, file_id in zip(texts, file_ids)))
async def remove_document_from_context_windows(self, source_id: str, filename: str, actor: User) -> None:
async def remove_file_from_context_windows(self, source_id: str, file_id: str, actor: User) -> None:
"""
Remove the document from the context window of all agents
attached to the given source.
"""
# TODO: We probably do NOT need to get the entire agent state, we can just get the IDs
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
# Return early
if not agent_states:
return
logger.info(f"Removing document from context window for source: {source_id}")
logger.info(f"Removing file from context window for source: {source_id}")
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
await asyncio.gather(*(self._remove_document_block(agent_state.id, filename, actor) for agent_state in agent_states))
await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for agent_state in agent_states))
async def remove_documents_from_context_window(self, agent_state: AgentState, filenames: List[str], actor: User) -> None:
async def remove_files_from_context_window(self, agent_state: AgentState, file_ids: List[str], actor: User) -> None:
"""
Remove multiple documents from the context window of an agent
attached to the given source.
"""
logger.info(f"Removing documents from context window for agent_state: {agent_state.id}")
logger.info(f"Documents to remove: {filenames}")
logger.info(f"Removing files from context window for agent_state: {agent_state.id}")
logger.info(f"Files to remove: {file_ids}")
await asyncio.gather(*(self._remove_document_block(agent_state.id, filename, actor) for filename in filenames))
await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for file_id in file_ids))
async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False

View File

@@ -86,10 +86,10 @@ class FileProcessor:
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
await server.insert_document_into_context_windows(
await server.insert_file_into_context_windows(
source_id=source_id,
text="".join([ocr_response.pages[i].markdown for i in range(min(3, len(ocr_response.pages)))]),
filename=file.filename,
file_id=file_metadata.id,
actor=self.actor,
agent_states=agent_states,
)

View File

@@ -1,5 +1,4 @@
import os
import re
import threading
import time
@@ -56,12 +55,6 @@ def agent_state(client: LettaSDKClient):
client.agents.delete(agent_id=agent_state.id)
import re
import time
import pytest
@pytest.mark.parametrize(
"file_path, expected_value, expected_label_regex",
[
@@ -106,20 +99,23 @@ def test_file_upload_creates_source_blocks_correctly(
assert len(files) == 1
assert files[0].source_id == source.id
# Check that blocks were created
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert any(expected_value in b.value for b in blocks)
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# Check that the proper file associations were created
# files_agents = await server.file_agent_manager.list_files_for_agent(agent_id=agent_state.id, actor=actor)
# Remove file from source
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
# Confirm blocks were removed
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 1
assert not any(expected_value in b.value for b in blocks)
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# # Check that blocks were created
# blocks = client.agents.blocks.list(agent_id=agent_state.id)
# assert len(blocks) == 2
# assert any(expected_value in b.value for b in blocks)
# assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
#
# # Remove file from source
# client.sources.files.delete(source_id=source.id, file_id=files[0].id)
#
# # Confirm blocks were removed
# blocks = client.agents.blocks.list(agent_id=agent_state.id)
# assert len(blocks) == 1
# assert not any(expected_value in b.value for b in blocks)
# assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
@@ -156,20 +152,20 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
# Attach after uploading the file
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
# Get the agent state, check blocks exist
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert "test" in [b.value for b in blocks]
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
# # Get the agent state, check blocks exist
# blocks = client.agents.blocks.list(agent_id=agent_state.id)
# assert len(blocks) == 2
# assert "test" in [b.value for b in blocks]
# assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
# Detach the source
client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id)
# Get the agent state, check blocks do NOT exist
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 1
assert "test" not in [b.value for b in blocks]
assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
# # Get the agent state, check blocks do NOT exist
# blocks = client.agents.blocks.list(agent_id=agent_state.id)
# assert len(blocks) == 1
# assert "test" not in [b.value for b in blocks]
# assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
@@ -202,16 +198,16 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
print("Waiting for jobs to complete...", job.status)
# Get the agent state, check blocks exist
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert "test" in [b.value for b in blocks]
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
# blocks = client.agents.blocks.list(agent_id=agent_state.id)
# assert len(blocks) == 2
# assert "test" in [b.value for b in blocks]
# assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
# Delete the source
client.sources.delete(source_id=source.id)
# Get the agent state, check blocks do NOT exist
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 1
assert "test" not in [b.value for b in blocks]
assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
# blocks = client.agents.blocks.list(agent_id=agent_state.id)
# assert len(blocks) == 1
# assert "test" not in [b.value for b in blocks]
# assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)