diff --git a/alembic/versions/9792f94e961d_add_file_processing_status_to_.py b/alembic/versions/9792f94e961d_add_file_processing_status_to_.py new file mode 100644 index 00000000..f259da66 --- /dev/null +++ b/alembic/versions/9792f94e961d_add_file_processing_status_to_.py @@ -0,0 +1,41 @@ +"""Add file processing status to FileMetadata and related indices + +Revision ID: 9792f94e961d +Revises: cdd4a1c11aee +Create Date: 2025-06-05 18:51:57.022594 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "9792f94e961d" +down_revision: Union[str, None] = "cdd4a1c11aee" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint("uq_file_contents_file_id", "file_contents", ["file_id"]) + op.add_column("files", sa.Column("processing_status", sa.String(), nullable=False)) + op.add_column("files", sa.Column("error_message", sa.Text(), nullable=True)) + op.create_index("ix_files_org_created", "files", ["organization_id", sa.literal_column("created_at DESC")], unique=False) + op.create_index("ix_files_processing_status", "files", ["processing_status"], unique=False) + op.create_index("ix_files_source_created", "files", ["source_id", sa.literal_column("created_at DESC")], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_files_source_created", table_name="files") + op.drop_index("ix_files_processing_status", table_name="files") + op.drop_index("ix_files_org_created", table_name="files") + op.drop_column("files", "error_message") + op.drop_column("files", "processing_status") + op.drop_constraint("uq_file_contents_file_id", "file_contents", type_="unique") + # ### end Alembic commands ### diff --git a/letta/orm/file.py b/letta/orm/file.py index 017acc02..2e8e5088 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -1,12 +1,13 @@ import uuid from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import ForeignKey, Integer, String, Text +from sqlalchemy import ForeignKey, Index, Integer, String, Text, UniqueConstraint, desc from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin, SourceMixin from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata as PydanticFileMetadata if TYPE_CHECKING: @@ -23,13 +24,13 @@ class FileContent(SqlalchemyBase): """Holds the full text content of a file (potentially large).""" __tablename__ = "file_contents" + __table_args__ = (UniqueConstraint("file_id", name="uq_file_contents_file_id"),) # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase # TODO: Some still rely on the Pydantic object to do this id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"file_content-{uuid.uuid4()}") - file_id: Mapped[str] = mapped_column( - ForeignKey("files.id", ondelete="CASCADE"), primary_key=True, doc="Foreign key to files table; also serves as primary key." - ) + file_id: Mapped[str] = mapped_column(ForeignKey("files.id", ondelete="CASCADE"), nullable=False, doc="Foreign key to files table.") + text: Mapped[str] = mapped_column(Text, nullable=False, doc="Full plain-text content of the file (e.g., extracted from a PDF).") # back-reference to FileMetadata @@ -41,6 +42,11 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): __tablename__ = "files" __pydantic_model__ = PydanticFileMetadata + __table_args__ = ( + Index("ix_files_org_created", "organization_id", desc("created_at")), + Index("ix_files_source_created", "source_id", desc("created_at")), + Index("ix_files_processing_status", "processing_status"), + ) file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.") file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.") @@ -48,6 +54,11 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.") file_creation_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The creation date of the file.") file_last_modified_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The last modified date of the file.") + processing_status: Mapped[FileProcessingStatus] = mapped_column( + String, default=FileProcessingStatus.PENDING, nullable=False, doc="The current processing status of the file." + ) + + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Any error message encountered during processing.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") @@ -93,6 +104,8 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): file_size=self.file_size, file_creation_date=self.file_creation_date, file_last_modified_date=self.file_last_modified_date, + processing_status=self.processing_status, + error_message=self.error_message, created_at=self.created_at, updated_at=self.updated_at, is_deleted=self.is_deleted, diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 555ffadd..c51bcbca 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -87,3 +87,11 @@ class ToolRuleType(str, Enum): constrain_child_tools = "constrain_child_tools" max_count_per_step = "max_count_per_step" parent_last_tool = "parent_last_tool" + + +class FileProcessingStatus(str, Enum): + PENDING = "pending" + PARSING = "parsing" + EMBEDDING = "embedding" + COMPLETED = "completed" + ERROR = "error" diff --git a/letta/schemas/file.py b/letta/schemas/file.py index e7db7621..a4170b0a 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -4,6 +4,7 @@ from typing import Optional from pydantic import Field +from letta.schemas.enums import FileProcessingStatus from letta.schemas.letta_base import LettaBase @@ -34,6 +35,11 @@ class FileMetadata(FileMetadataBase): file_size: Optional[int] = Field(None, description="The size of the file in bytes.") file_creation_date: Optional[str] = Field(None, description="The creation date of the file.") file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.") + processing_status: FileProcessingStatus = Field( + default=FileProcessingStatus.PENDING, + description="The current processing status of the file (e.g. pending, parsing, embedding, completed, error).", + ) + error_message: Optional[str] = Field(default=None, description="Optional error message if the file failed processing.") # orm metadata, optional fields created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.") diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index dde76546..533723f1 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -5,7 +5,7 @@ from fastapi import UploadFile from letta.log import get_logger from letta.schemas.agent import AgentState -from letta.schemas.enums import JobStatus +from letta.schemas.enums import FileProcessingStatus, JobStatus from letta.schemas.file import FileMetadata from letta.schemas.job import Job, JobUpdate from letta.schemas.passage import Passage @@ -56,6 +56,10 @@ class FileProcessor: file_metadata = self._extract_upload_file_metadata(file, source_id=source_id) filename = file_metadata.file_name + # Create file as early as possible with no content + file_metadata.processing_status = FileProcessingStatus.PARSING # Parsing now + file_metadata = await self.source_manager.create_file(file_metadata, self.actor) + try: # Ensure we're working with bytes if isinstance(content, str): @@ -67,9 +71,28 @@ class FileProcessor: logger.info(f"Starting OCR extraction for {filename}") ocr_response = await self.file_parser.extract_text(content, mime_type=file_metadata.file_type) - # persist file with raw text + # update file with raw text raw_markdown_text = "".join([page.markdown for page in ocr_response.pages]) - await self.source_manager.create_file(file_metadata, self.actor, text=raw_markdown_text) + file_metadata = await self.source_manager.upsert_file_content( + file_id=file_metadata.id, text=raw_markdown_text, actor=self.actor + ) + file_metadata = await self.source_manager.update_file_status( + file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.EMBEDDING + ) + + # Insert to agent context window + # TODO: Rethink this line chunking mechanism + content_lines = self.line_chunker.chunk_text(text=raw_markdown_text) + visible_content = "\n".join(content_lines) + + await server.insert_file_into_context_windows( + source_id=source_id, + text=visible_content, + file_id=file_metadata.id, + file_name=file_metadata.file_name, + actor=self.actor, + agent_states=agent_states, + ) if not ocr_response or len(ocr_response.pages) == 0: raise ValueError("No text extracted from PDF") @@ -92,29 +115,20 @@ class FileProcessor: logger.info(f"Successfully processed {filename}: {len(all_passages)} passages") - # TODO: Rethink this line chunking mechanism - content_lines = self.line_chunker.chunk_text(text=raw_markdown_text) - visible_content = "\n".join(content_lines) - - await server.insert_file_into_context_windows( - source_id=source_id, - text=visible_content, - file_id=file_metadata.id, - file_name=file_metadata.file_name, - actor=self.actor, - agent_states=agent_states, - ) - # update job status if job: job.status = JobStatus.completed job.metadata["num_passages"] = len(all_passages) await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor) + await self.source_manager.update_file_status( + file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED + ) + return all_passages except Exception as e: - logger.error(f"PDF processing failed for {filename}: {str(e)}") + logger.error(f"File processing failed for {filename}: {str(e)}") # update job status if job: @@ -122,6 +136,10 @@ class FileProcessor: job.metadata["error"] = str(e) await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor) + await self.source_manager.update_file_status( + file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e) + ) + return [] def _extract_upload_file_metadata(self, file: UploadFile, source_id: str) -> FileMetadata: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 83131383..614c0c31 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,7 +1,9 @@ import asyncio +from datetime import datetime from typing import List, Optional -from sqlalchemy import select +from sqlalchemy import select, update +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload @@ -12,6 +14,7 @@ from letta.orm.source import Source as SourceModel from letta.orm.sqlalchemy_base import AccessType from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState +from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate @@ -226,6 +229,95 @@ class SourceManager: except NoResultFound: return None + @enforce_types + @trace_method + async def update_file_status( + self, + *, + file_id: str, + actor: PydanticUser, + processing_status: Optional[FileProcessingStatus] = None, + error_message: Optional[str] = None, + ) -> PydanticFileMetadata: + """ + Update processing_status and/or error_message on a FileMetadata row. + + * 1st round-trip → UPDATE + * 2nd round-trip → SELECT fresh row (same as read_async) + """ + + if processing_status is None and error_message is None: + raise ValueError("Nothing to update") + + values: dict[str, object] = {"updated_at": datetime.utcnow()} + if processing_status is not None: + values["processing_status"] = processing_status + if error_message is not None: + values["error_message"] = error_message + + async with db_registry.async_session() as session: + # Fast in-place update – no ORM hydration + stmt = ( + update(FileMetadataModel) + .where( + FileMetadataModel.id == file_id, + FileMetadataModel.organization_id == actor.organization_id, + ) + .values(**values) + ) + await session.execute(stmt) + await session.commit() + + # Reload via normal accessor so we return a fully-attached object + file_orm = await FileMetadataModel.read_async( + db_session=session, + identifier=file_id, + actor=actor, + ) + return await file_orm.to_pydantic_async() + + @enforce_types + @trace_method + async def upsert_file_content( + self, + *, + file_id: str, + text: str, + actor: PydanticUser, + ) -> PydanticFileMetadata: + async with db_registry.async_session() as session: + await FileMetadataModel.read_async(session, file_id, actor) + + dialect_name = session.bind.dialect.name + + if dialect_name == "postgresql": + stmt = ( + pg_insert(FileContentModel) + .values(file_id=file_id, text=text) + .on_conflict_do_update( + index_elements=[FileContentModel.file_id], + set_={"text": text}, + ) + ) + await session.execute(stmt) + else: + # Emulate upsert for SQLite and others + stmt = select(FileContentModel).where(FileContentModel.file_id == file_id) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + await session.execute(update(FileContentModel).where(FileContentModel.file_id == file_id).values(text=text)) + else: + session.add(FileContentModel(file_id=file_id, text=text)) + + await session.commit() + + # Reload with content + query = select(FileMetadataModel).options(selectinload(FileMetadataModel.content)).where(FileMetadataModel.id == file_id) + result = await session.execute(query) + return await result.scalar_one().to_pydantic_async(include_content=True) + @enforce_types @trace_method async def list_files( diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index f2ebe000..eb223145 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -15,7 +15,7 @@ from letta.helpers.datetime_helpers import get_utc_time from letta.orm import FileMetadata, Source from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole, MessageStreamStatus +from letta.schemas.enums import MessageRole from letta.schemas.letta_message import ( AssistantMessage, LettaMessage, @@ -25,10 +25,8 @@ from letta.schemas.letta_message import ( ToolReturnMessage, UserMessage, ) -from letta.schemas.letta_response import LettaStreamingResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate -from letta.schemas.usage import LettaUsageStatistics from letta.services.helpers.agent_manager_helper import initialize_message_sequence from letta.services.organization_manager import OrganizationManager from letta.services.user_manager import UserManager @@ -214,74 +212,6 @@ def test_core_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState) assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}" -@pytest.mark.parametrize( - "stream_tokens,model", - [ - (True, "gpt-4o-mini"), - (True, "claude-3-sonnet-20240229"), - (False, "gpt-4o-mini"), - (False, "claude-3-sonnet-20240229"), - ], -) -def test_streaming_send_message( - disable_e2b_api_key, - client: RESTClient, - agent: AgentState, - stream_tokens: bool, - model: str, -): - # Update agent's model - agent.llm_config.model = model - - # First, try streaming just steps - - # Next, try streaming both steps and tokens - response = client.send_message( - agent_id=agent.id, - message="This is a test. Repeat after me: 'banana'", - role="user", - stream_steps=True, - stream_tokens=stream_tokens, - ) - - # Some manual checks to run - # 1. Check that there were inner thoughts - inner_thoughts_exist = False - inner_thoughts_count = 0 - # 2. Check that the agent runs `send_message` - send_message_ran = False - # 3. Check that we get all the start/stop/end tokens we want - # This includes all of the MessageStreamStatus enums - done = False - - assert response, "Sending message failed" - for chunk in response: - assert isinstance(chunk, LettaStreamingResponse) - if isinstance(chunk, ReasoningMessage) and chunk.reasoning and chunk.reasoning != "": - inner_thoughts_exist = True - inner_thoughts_count += 1 - if isinstance(chunk, ToolCallMessage) and chunk.tool_call and chunk.tool_call.name == "send_message": - send_message_ran = True - if isinstance(chunk, AssistantMessage): - send_message_ran = True - if isinstance(chunk, MessageStreamStatus): - if chunk == MessageStreamStatus.done: - assert not done, "Message stream already done" - done = True - if isinstance(chunk, LettaUsageStatistics): - # Some rough metrics for a reasonable usage pattern - assert chunk.step_count == 1 - assert chunk.completion_tokens > 10 - assert chunk.prompt_tokens > 1000 - assert chunk.total_tokens > 1000 - - # If stream tokens, we expect at least one inner thought - assert inner_thoughts_count >= 1, "Expected more than one inner thought" - assert inner_thoughts_exist, "No inner thoughts found" - assert send_message_ran, "send_message function call not found" - assert done, "Message stream not done" - - def test_humans_personas(client: RESTClient, agent: AgentState): # _reset_config() diff --git a/tests/test_managers.py b/tests/test_managers.py index 0cbcfcc6..f467bdad 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -49,7 +49,7 @@ from letta.schemas.agent import AgentStepState, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType +from letta.schemas.enums import AgentStepStatus, FileProcessingStatus, JobStatus, MessageRole, ProviderType from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert @@ -4321,6 +4321,105 @@ async def test_delete_file(server: SyncServer, default_user, default_source): assert len(files) == 0 +@pytest.mark.asyncio +async def test_update_file_status_basic(server, default_user, default_source): + """Update processing status and error message for a file.""" + meta = PydanticFileMetadata( + file_name="status_test.txt", + file_path="/tmp/status_test.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.source_manager.create_file(file_metadata=meta, actor=default_user) + + # Update status only + updated = await server.source_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + assert updated.processing_status == FileProcessingStatus.PARSING + assert updated.error_message is None + + # Update both status and error message + updated = await server.source_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Parse failed", + ) + assert updated.processing_status == FileProcessingStatus.ERROR + assert updated.error_message == "Parse failed" + + +@pytest.mark.asyncio +async def test_update_file_status_error_only(server, default_user, default_source): + """Update just the error message, leave status unchanged.""" + meta = PydanticFileMetadata( + file_name="error_only.txt", + file_path="/tmp/error_only.txt", + file_type="text/plain", + file_size=123, + source_id=default_source.id, + ) + created = await server.source_manager.create_file(file_metadata=meta, actor=default_user) + + updated = await server.source_manager.update_file_status( + file_id=created.id, + actor=default_user, + error_message="Timeout while embedding", + ) + assert updated.error_message == "Timeout while embedding" + assert updated.processing_status == FileProcessingStatus.PENDING # default from creation + + +@pytest.mark.asyncio +async def test_upsert_file_content_basic(server: SyncServer, default_user, default_source, async_session): + """Test creating and updating file content with upsert_file_content().""" + initial_text = "Initial content" + updated_text = "Updated content" + + # Step 1: Create file with no content + meta = PydanticFileMetadata( + file_name="upsert_body.txt", + file_path="/tmp/upsert_body.txt", + file_type="text/plain", + file_size=len(initial_text), + source_id=default_source.id, + ) + created = await server.source_manager.create_file(file_metadata=meta, actor=default_user) + assert created.content is None + + # Step 2: Insert new content + file_with_content = await server.source_manager.upsert_file_content( + file_id=created.id, + text=initial_text, + actor=default_user, + ) + assert file_with_content.content == initial_text + + # Verify body row exists + count = await _count_file_content_rows(async_session, created.id) + assert count == 1 + + # Step 3: Update existing content + file_with_updated_content = await server.source_manager.upsert_file_content( + file_id=created.id, + text=updated_text, + actor=default_user, + ) + assert file_with_updated_content.content == updated_text + + # Ensure still only 1 row in content table + count = await _count_file_content_rows(async_session, created.id) + assert count == 1 + + # Ensure `updated_at` is bumped + orm_file = await async_session.get(FileMetadataModel, created.id) + assert orm_file.updated_at > orm_file.created_at + + # ====================================================================================================================== # SandboxConfigManager Tests - Sandbox Configs # ====================================================================================================================== diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index f8c4185e..befab5df 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -679,72 +679,3 @@ def test_many_blocks(client: LettaSDKClient): client.agents.delete(agent1.id) client.agents.delete(agent2.id) - - -def test_sources_crud(client: LettaSDKClient, agent: AgentState): - - # Clear existing sources - for source in client.sources.list(): - client.sources.delete(source_id=source.id) - - # Clear existing jobs - for job in client.jobs.list(): - client.jobs.delete(job_id=job.id) - - # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") - assert len(client.sources.list()) == 1 - - # delete the source - client.sources.delete(source_id=source.id) - assert len(client.sources.list()) == 0 - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") - - # Load files into the source - file_a_path = "tests/data/memgpt_paper.pdf" - file_b_path = "tests/data/test.txt" - - # Upload the files - with open(file_a_path, "rb") as f: - job_a = client.sources.files.upload(source_id=source.id, file=f) - - with open(file_b_path, "rb") as f: - job_b = client.sources.files.upload(source_id=source.id, file=f) - - # Wait for the jobs to complete - while job_a.status != "completed" or job_b.status != "completed": - time.sleep(1) - job_a = client.jobs.retrieve(job_id=job_a.id) - job_b = client.jobs.retrieve(job_id=job_b.id) - print("Waiting for jobs to complete...", job_a.status, job_b.status) - - # Get the first file with pagination - files_a = client.sources.files.list(source_id=source.id, limit=1) - assert len(files_a) == 1 - assert files_a[0].source_id == source.id - - # Use the cursor from files_a to get the remaining file - files_b = client.sources.files.list(source_id=source.id, limit=1, after=files_a[-1].id) - assert len(files_b) == 1 - assert files_b[0].source_id == source.id - - # Check files are different to ensure the cursor works - assert files_a[0].file_name != files_b[0].file_name - - # Use the cursor from files_b to list files, should be empty - files = client.sources.files.list(source_id=source.id, limit=1, after=files_b[-1].id) - assert len(files) == 0 # Should be empty - - # list passages - passages = client.sources.passages.list(source_id=source.id) - assert len(passages) > 0 - - # attach to an agent - assert len(client.agents.passages.list(agent_id=agent.id)) == 0 - client.agents.sources.attach(source_id=source.id, agent_id=agent.id) - assert len(client.agents.passages.list(agent_id=agent.id)) > 0 - assert len(client.agents.sources.list(agent_id=agent.id)) == 1 - - # detach from agent - client.agents.sources.detach(source_id=source.id, agent_id=agent.id) - assert len(client.agents.passages.list(agent_id=agent.id)) == 0 diff --git a/tests/test_sources.py b/tests/test_sources.py index 5fdd69d1..c8246e23 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -215,7 +215,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks assert len(blocks) == 1 - assert "test" in [b.value for b in blocks] + assert any("test" in b.value for b in blocks) assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) # Detach the source @@ -225,7 +225,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks assert len(blocks) == 0 - assert "test" not in [b.value for b in blocks] + assert not any("test" in b.value for b in blocks) assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) @@ -254,7 +254,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks assert len(blocks) == 1 - assert "test" in [b.value for b in blocks] + assert any("test" in b.value for b in blocks) assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) # Remove file from source @@ -264,7 +264,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a agent_state = client.agents.retrieve(agent_id=agent_state.id) blocks = agent_state.memory.file_blocks assert len(blocks) == 0 - assert "test" not in [b.value for b in blocks] + assert not any("test" in b.value for b in blocks) assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)