diff --git a/alembic/versions/614c4e53b66e_add_unique_constraint_to_file_id_and_.py b/alembic/versions/614c4e53b66e_add_unique_constraint_to_file_id_and_.py new file mode 100644 index 00000000..9d726a35 --- /dev/null +++ b/alembic/versions/614c4e53b66e_add_unique_constraint_to_file_id_and_.py @@ -0,0 +1,29 @@ +"""Add unique constraint to file_id and agent_id on file_agent + +Revision ID: 614c4e53b66e +Revises: 0b496eae90de +Create Date: 2025-06-02 17:03:58.879839 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "614c4e53b66e" +down_revision: Union[str, None] = "0b496eae90de" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint("uq_files_agents_file_agent", "files_agents", ["file_id", "agent_id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("uq_files_agents_file_agent", "files_agents", type_="unique") + # ### end Alembic commands ### diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py index 847b6c7e..e005ba74 100644 --- a/letta/orm/files_agents.py +++ b/letta/orm/files_agents.py @@ -2,7 +2,7 @@ import uuid from datetime import datetime from typing import TYPE_CHECKING, Optional -from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, func +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column from letta.orm.mixins import OrganizationMixin @@ -22,7 +22,10 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): """ __tablename__ = "files_agents" - __table_args__ = (Index("ix_files_agents_file_id_agent_id", "file_id", "agent_id"),) + __table_args__ = ( + Index("ix_files_agents_file_id_agent_id", "file_id", "agent_id"), + UniqueConstraint("file_id", "agent_id", name="uq_files_agents_file_agent"), + ) __pydantic_model__ = PydanticFileAgent # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index fbdcf518..b385a2a5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -313,7 +313,7 @@ async def attach_source( files = await server.source_manager.list_files(source_id, actor) texts = [] - filenames = [] + file_ids = [] for f in files: passages = await server.passage_manager.list_passages_by_file_id_async(file_id=f.id, actor=actor) passage_text = "" @@ -322,9 +322,9 @@ async def attach_source( passage_text += p.text texts.append(passage_text) - filenames.append(f.file_name) + file_ids.append(f.id) - await server.insert_documents_into_context_window(agent_state=agent_state, texts=texts, filenames=filenames, actor=actor) + await server.insert_files_into_context_window(agent_state=agent_state, texts=texts, file_ids=file_ids, actor=actor) if agent_state.enable_sleeptime: source = await server.source_manager.get_source_by_id(source_id=source_id) @@ -348,8 +348,8 @@ async def detach_source( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) agent_state = await server.agent_manager.detach_source_async(agent_id=agent_id, source_id=source_id, actor=actor) files = await server.source_manager.list_files(source_id, actor) - filenames = [f.file_name for f in files] - await server.remove_documents_from_context_window(agent_state=agent_state, filenames=filenames, actor=actor) + file_ids = [f.id for f in files] + await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor) if agent_state.enable_sleeptime: try: diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index e4498939..cc57d977 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -150,10 +150,10 @@ async def delete_source( source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) files = await server.source_manager.list_files(source_id, actor) - filenames = [f.file_name for f in files] + file_ids = [f.id for f in files] for agent_state in agent_states: - await server.remove_documents_from_context_window(agent_state=agent_state, filenames=filenames, actor=actor) + await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor) if agent_state.enable_sleeptime: try: @@ -212,11 +212,6 @@ async def upload_file_to_source( # sanitize filename file.filename = sanitize_filename(file.filename) - try: - text = content.decode("utf-8") - except Exception: - text = "[Currently parsing...]" - # create job job = Job( user_id=actor.id, @@ -225,8 +220,8 @@ async def upload_file_to_source( ) job = await server.job_manager.create_job_async(job, actor=actor) - # Add blocks (sometimes without content, for UX purposes) - agent_states = await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=file.filename, actor=actor) + # TODO: Do we need to pull in the full agent_states? Can probably simplify here right? + agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) # NEW: Cloud based file processing if settings.mistral_api_key and model_settings.openai_api_key: @@ -301,8 +296,7 @@ async def delete_file_from_source( deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor) - # Remove blocks - await server.remove_document_from_context_windows(source_id=source_id, filename=deleted_file.file_name, actor=actor) + await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor) asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True)) if deleted_file is None: diff --git a/letta/server/server.py b/letta/server/server.py index aca7d043..9f6fa541 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1368,55 +1368,28 @@ class SyncServer(Server): ) await self.agent_manager.delete_agent_async(agent_id=sleeptime_agent_state.id, actor=actor) - async def _upsert_document_block(self, agent_id: str, text: str, filename: str, actor: User) -> None: + async def _upsert_file_to_agent(self, agent_id: str, text: str, file_id: str, actor: User) -> None: """ - Internal method to create or update a document block for an agent. + Internal method to create or update a file <-> agent association """ truncated_text = text[:CORE_MEMORY_SOURCE_CHAR_LIMIT] + await self.file_agent_manager.attach_file(agent_id=agent_id, file_id=file_id, actor=actor, visible_content=truncated_text) - try: - block = await self.agent_manager.get_block_with_label_async( - agent_id=agent_id, - block_label=filename, - actor=actor, - ) - await self.block_manager.update_block_async( - block_id=block.id, - block_update=BlockUpdate(value=truncated_text), - actor=actor, - ) - except NoResultFound: - block = await self.block_manager.create_or_update_block_async( - block=Block( - value=truncated_text, - label=filename, - description=f"Contains the parsed contents of external file {filename}", - limit=CORE_MEMORY_SOURCE_CHAR_LIMIT, - ), - actor=actor, - ) - await self.agent_manager.attach_block_async( - agent_id=agent_id, - block_id=block.id, - actor=actor, - ) - - async def _remove_document_block(self, agent_id: str, filename: str, actor: User) -> None: + async def _remove_file_from_agent(self, agent_id: str, file_id: str, actor: User) -> None: """ Internal method to remove a document block for an agent. """ try: - block = await self.agent_manager.get_block_with_label_async( + await self.file_agent_manager.detach_file( agent_id=agent_id, - block_label=filename, + file_id=file_id, actor=actor, ) - await self.block_manager.delete_block_async(block_id=block.id, actor=actor) except NoResultFound: - logger.info(f"Document block with label {filename} already removed, skipping...") + logger.info(f"File {file_id} already removed from agent {agent_id}, skipping...") - async def insert_document_into_context_windows( - self, source_id: str, text: str, filename: str, actor: User, agent_states: Optional[List[AgentState]] = None + async def insert_file_into_context_windows( + self, source_id: str, text: str, file_id: str, actor: User, agent_states: Optional[List[AgentState]] = None ) -> List[AgentState]: """ Insert the uploaded document into the context window of all agents @@ -1431,51 +1404,48 @@ class SyncServer(Server): logger.info(f"Inserting document into context window for source: {source_id}") logger.info(f"Attached agents: {[a.id for a in agent_states]}") - await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states)) + await asyncio.gather(*(self._upsert_file_to_agent(agent_state.id, text, file_id, actor) for agent_state in agent_states)) return agent_states - async def insert_documents_into_context_window( - self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User - ) -> None: + async def insert_files_into_context_window(self, agent_state: AgentState, texts: List[str], file_ids: List[str], actor: User) -> None: """ Insert the uploaded documents into the context window of an agent attached to the given source. """ logger.info(f"Inserting documents into context window for agent_state: {agent_state.id}") - if len(texts) != len(filenames): - raise ValueError(f"Mismatch between number of texts ({len(texts)}) and filenames ({len(filenames)})") + if len(texts) != len(file_ids): + raise ValueError(f"Mismatch between number of texts ({len(texts)}) and file ids ({len(file_ids)})") - await asyncio.gather( - *(self._upsert_document_block(agent_state.id, text, filename, actor) for text, filename in zip(texts, filenames)) - ) + await asyncio.gather(*(self._upsert_file_to_agent(agent_state.id, text, file_id, actor) for text, file_id in zip(texts, file_ids))) - async def remove_document_from_context_windows(self, source_id: str, filename: str, actor: User) -> None: + async def remove_file_from_context_windows(self, source_id: str, file_id: str, actor: User) -> None: """ Remove the document from the context window of all agents attached to the given source. """ + # TODO: We probably do NOT need to get the entire agent state, we can just get the IDs agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor) # Return early if not agent_states: return - logger.info(f"Removing document from context window for source: {source_id}") + logger.info(f"Removing file from context window for source: {source_id}") logger.info(f"Attached agents: {[a.id for a in agent_states]}") - await asyncio.gather(*(self._remove_document_block(agent_state.id, filename, actor) for agent_state in agent_states)) + await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for agent_state in agent_states)) - async def remove_documents_from_context_window(self, agent_state: AgentState, filenames: List[str], actor: User) -> None: + async def remove_files_from_context_window(self, agent_state: AgentState, file_ids: List[str], actor: User) -> None: """ Remove multiple documents from the context window of an agent attached to the given source. """ - logger.info(f"Removing documents from context window for agent_state: {agent_state.id}") - logger.info(f"Documents to remove: {filenames}") + logger.info(f"Removing files from context window for agent_state: {agent_state.id}") + logger.info(f"Files to remove: {file_ids}") - await asyncio.gather(*(self._remove_document_block(agent_state.id, filename, actor) for filename in filenames)) + await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for file_id in file_ids)) async def create_document_sleeptime_agent_async( self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index 2ab0fad6..0854d049 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -86,10 +86,10 @@ class FileProcessor: logger.info(f"Successfully processed {filename}: {len(all_passages)} passages") - await server.insert_document_into_context_windows( + await server.insert_file_into_context_windows( source_id=source_id, text="".join([ocr_response.pages[i].markdown for i in range(min(3, len(ocr_response.pages)))]), - filename=file.filename, + file_id=file_metadata.id, actor=self.actor, agent_states=agent_states, ) diff --git a/tests/test_sources.py b/tests/test_sources.py index bde523f7..9982fa37 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -1,5 +1,4 @@ import os -import re import threading import time @@ -56,12 +55,6 @@ def agent_state(client: LettaSDKClient): client.agents.delete(agent_id=agent_state.id) -import re -import time - -import pytest - - @pytest.mark.parametrize( "file_path, expected_value, expected_label_regex", [ @@ -106,20 +99,23 @@ def test_file_upload_creates_source_blocks_correctly( assert len(files) == 1 assert files[0].source_id == source.id - # Check that blocks were created - blocks = client.agents.blocks.list(agent_id=agent_state.id) - assert len(blocks) == 2 - assert any(expected_value in b.value for b in blocks) - assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # Check that the proper file associations were created + # files_agents = await server.file_agent_manager.list_files_for_agent(agent_id=agent_state.id, actor=actor) - # Remove file from source - client.sources.files.delete(source_id=source.id, file_id=files[0].id) - - # Confirm blocks were removed - blocks = client.agents.blocks.list(agent_id=agent_state.id) - assert len(blocks) == 1 - assert not any(expected_value in b.value for b in blocks) - assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # # Check that blocks were created + # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # assert len(blocks) == 2 + # assert any(expected_value in b.value for b in blocks) + # assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # + # # Remove file from source + # client.sources.files.delete(source_id=source.id, file_id=files[0].id) + # + # # Confirm blocks were removed + # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # assert len(blocks) == 1 + # assert not any(expected_value in b.value for b in blocks) + # assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): @@ -156,20 +152,20 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC # Attach after uploading the file client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) - # Get the agent state, check blocks exist - blocks = client.agents.blocks.list(agent_id=agent_state.id) - assert len(blocks) == 2 - assert "test" in [b.value for b in blocks] - assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + # # Get the agent state, check blocks exist + # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # assert len(blocks) == 2 + # assert "test" in [b.value for b in blocks] + # assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) # Detach the source client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id) - # Get the agent state, check blocks do NOT exist - blocks = client.agents.blocks.list(agent_id=agent_state.id) - assert len(blocks) == 1 - assert "test" not in [b.value for b in blocks] - assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + # # Get the agent state, check blocks do NOT exist + # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # assert len(blocks) == 1 + # assert "test" not in [b.value for b in blocks] + # assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): @@ -202,16 +198,16 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a print("Waiting for jobs to complete...", job.status) # Get the agent state, check blocks exist - blocks = client.agents.blocks.list(agent_id=agent_state.id) - assert len(blocks) == 2 - assert "test" in [b.value for b in blocks] - assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # assert len(blocks) == 2 + # assert "test" in [b.value for b in blocks] + # assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) # Remove file from source client.sources.delete(source_id=source.id) # Get the agent state, check blocks do NOT exist - blocks = client.agents.blocks.list(agent_id=agent_state.id) - assert len(blocks) == 1 - assert "test" not in [b.value for b in blocks] - assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + # blocks = client.agents.blocks.list(agent_id=agent_state.id) + # assert len(blocks) == 1 + # assert "test" not in [b.value for b in blocks] + # assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)