diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index bd77ac3e..efba90c3 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -404,9 +404,15 @@ async def get_file_metadata( file_status = file_metadata.processing_status else: file_status = FileProcessingStatus.COMPLETED - file_metadata = await server.file_manager.update_file_status( - file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status - ) + try: + file_metadata = await server.file_manager.update_file_status( + file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status + ) + except ValueError as e: + # state transition was blocked - this is a race condition + # log it but don't fail the request since we're just reading metadata + logger.warning(f"Race condition detected in get_file_metadata: {str(e)}") + # return the current file state without updating return file_metadata diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index a56be24e..632fd77b 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -18,6 +18,7 @@ from letta.schemas.agent_file import ( ToolSchema, ) from letta.schemas.block import Block +from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata from letta.schemas.message import Message from letta.schemas.source import Source @@ -422,6 +423,11 @@ class AgentSerializationManager: file_data = file_schema.model_dump(exclude={"id", "content"}) # Remap source_id from file ID to database ID file_data["source_id"] = file_to_db_ids[file_schema.source_id] + # Set processing status to PARSING since we have parsed content but need to re-embed + file_data["processing_status"] = FileProcessingStatus.PARSING + file_data["error_message"] = None + file_data["total_chunks"] = None + file_data["chunks_embedded"] = None file_metadata = FileMetadata(**file_data) created_file = await self.file_manager.create_file(file_metadata, actor, text=file_schema.content) file_to_db_ids[file_schema.id] = created_file.id diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index 29335412..144d2e72 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -143,12 +143,31 @@ class FileManager: error_message: Optional[str] = None, total_chunks: Optional[int] = None, chunks_embedded: Optional[int] = None, - ) -> PydanticFileMetadata: + enforce_state_transitions: bool = True, + ) -> Optional[PydanticFileMetadata]: """ Update processing_status, error_message, total_chunks, and/or chunks_embedded on a FileMetadata row. - * 1st round-trip → UPDATE - * 2nd round-trip → SELECT fresh row (same as read_async) + Enforces state transition rules (when enforce_state_transitions=True): + - PENDING -> PARSING -> EMBEDDING -> COMPLETED (normal flow) + - Any non-terminal state -> ERROR + - ERROR and COMPLETED are terminal (no transitions allowed) + + Args: + file_id: ID of the file to update + actor: User performing the update + processing_status: New processing status to set + error_message: Error message to set (if any) + total_chunks: Total number of chunks in the file + chunks_embedded: Number of chunks already embedded + enforce_state_transitions: Whether to enforce state transition rules (default: True). + Set to False to bypass validation for testing or special cases. + + Returns: + Updated file metadata, or None if the update was blocked + + * 1st round-trip → UPDATE with optional state validation + * 2nd round-trip → SELECT fresh row (same as read_async) if update succeeded """ if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None: @@ -164,23 +183,79 @@ class FileManager: if chunks_embedded is not None: values["chunks_embedded"] = chunks_embedded + # validate state transitions before making any database calls + if enforce_state_transitions and processing_status == FileProcessingStatus.PENDING: + # PENDING cannot be set after initial creation + raise ValueError(f"Cannot transition to PENDING state for file {file_id} - PENDING is only valid as initial state") + async with db_registry.async_session() as session: - # Fast in-place update – no ORM hydration + # build where conditions + where_conditions = [ + FileMetadataModel.id == file_id, + FileMetadataModel.organization_id == actor.organization_id, + ] + + # only add state transition validation if enforce_state_transitions is True + if enforce_state_transitions: + # prevent updates to terminal states (ERROR, COMPLETED) + where_conditions.append( + FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED]) + ) + + if processing_status is not None: + # enforce specific transitions based on target status + if processing_status == FileProcessingStatus.PARSING: + where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PENDING) + elif processing_status == FileProcessingStatus.EMBEDDING: + where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PARSING) + elif processing_status == FileProcessingStatus.COMPLETED: + where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.EMBEDDING) + # ERROR can be set from any non-terminal state (already handled by terminal check above) + + # fast in-place update with state validation stmt = ( update(FileMetadataModel) - .where( - FileMetadataModel.id == file_id, - FileMetadataModel.organization_id == actor.organization_id, - ) + .where(*where_conditions) .values(**values) + .returning(FileMetadataModel.id) # return id if update succeeded ) - await session.execute(stmt) + result = await session.execute(stmt) + updated_id = result.scalar() + + if not updated_id: + # update was blocked + await session.commit() + + if enforce_state_transitions: + # update was blocked by state transition rules - raise error + # fetch current state to provide informative error + current_file = await FileMetadataModel.read_async( + db_session=session, + identifier=file_id, + actor=actor, + ) + current_status = current_file.processing_status + + # build informative error message + if processing_status is not None: + if current_status in [FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED]: + raise ValueError( + f"Cannot update file {file_id} status from terminal state {current_status} to {processing_status}" + ) + else: + raise ValueError(f"Invalid state transition for file {file_id}: {current_status} -> {processing_status}") + else: + raise ValueError(f"Cannot update file {file_id} in terminal state {current_status}") + else: + # validation was bypassed but update still failed (e.g., file doesn't exist) + return None + await session.commit() # invalidate cache for this file await self._invalidate_file_caches(file_id, actor) - # Reload via normal accessor so we return a fully-attached object + # reload via normal accessor so we return a fully-attached object file_orm = await FileMetadataModel.read_async( db_session=session, identifier=file_id, diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index 14e492be..92f6c086 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -300,7 +300,7 @@ class FileProcessor: # Create OCR response from existing content ocr_response = self._create_ocr_response_from_content(content) - # Update file status to embedding + # Update file status to embedding (valid transition from PARSING) file_metadata = await self.file_manager.update_file_status( file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.EMBEDDING ) @@ -320,12 +320,14 @@ class FileProcessor: ) log_event("file_processor.import_passages_created", {"filename": filename, "total_passages": len(all_passages)}) - # Update file status to completed + # Update file status to completed (valid transition from EMBEDDING) if not self.using_pinecone: await self.file_manager.update_file_status( file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.COMPLETED ) else: + # For Pinecone, update chunk counts but keep status at EMBEDDING + # The status will be updated to COMPLETED later when chunks are confirmed embedded await self.file_manager.update_file_status( file_id=file_metadata.id, actor=self.actor, total_chunks=len(all_passages), chunks_embedded=0 ) diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 2a548b68..5d0b8570 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -5,6 +5,7 @@ import pytest from letta.config import LettaConfig from letta.errors import AgentFileExportError, AgentFileImportError +from letta.helpers.pinecone_utils import should_use_pinecone from letta.orm import Base from letta.schemas.agent import CreateAgent from letta.schemas.agent_file import ( @@ -1056,7 +1057,11 @@ class TestAgentFileImportWithProcessing: imported_file_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-")) imported_file = await server.file_manager.get_file_by_id(imported_file_id, other_user) - assert imported_file.processing_status.value == "completed" + # When using Pinecone, status stays at embedding until chunks are confirmed uploaded + if should_use_pinecone(): + assert imported_file.processing_status.value == "embedding" + else: + assert imported_file.processing_status.value == "completed" async def test_import_passage_creation(self, server, agent_serialization_manager, default_user, other_user): """Test that import creates passages for file content.""" @@ -1073,11 +1078,16 @@ class TestAgentFileImportWithProcessing: imported_file_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-")) passages = await server.passage_manager.list_passages_by_file_id_async(imported_file_id, other_user) - assert len(passages) > 0 - for passage in passages: - assert passage.embedding is not None - assert len(passage.embedding) > 0 + if should_use_pinecone(): + # With Pinecone, passages are stored in Pinecone, not locally + assert len(passages) == 0 + else: + # Without Pinecone, passages are stored locally + assert len(passages) > 0 + for passage in passages: + assert passage.embedding is not None + assert len(passage.embedding) > 0 async def test_import_file_status_updates(self, server, agent_serialization_manager, default_user, other_user): """Test that file processing status is updated correctly during import.""" @@ -1093,9 +1103,15 @@ class TestAgentFileImportWithProcessing: imported_file_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id.startswith("file-")) imported_file = await server.file_manager.get_file_by_id(imported_file_id, other_user) - assert imported_file.processing_status.value == "completed" - assert imported_file.total_chunks is None - assert imported_file.chunks_embedded is None + # When using Pinecone, status stays at embedding until chunks are confirmed uploaded + if should_use_pinecone(): + assert imported_file.processing_status.value == "embedding" + assert imported_file.total_chunks == 1 # Pinecone tracks chunk counts + assert imported_file.chunks_embedded == 0 + else: + assert imported_file.processing_status.value == "completed" + assert imported_file.total_chunks is None + assert imported_file.chunks_embedded is None async def test_import_multiple_files_processing(self, server, agent_serialization_manager, default_user, other_user): """Test import processes multiple files efficiently.""" @@ -1114,7 +1130,11 @@ class TestAgentFileImportWithProcessing: for file_id in imported_file_ids: imported_file = await server.file_manager.get_file_by_id(file_id, other_user) - assert imported_file.processing_status.value == "completed" + # When using Pinecone, status stays at embedding until chunks are confirmed uploaded + if should_use_pinecone(): + assert imported_file.processing_status.value == "embedding" + else: + assert imported_file.processing_status.value == "completed" class TestAgentFileRoundTrip: diff --git a/tests/test_managers.py b/tests/test_managers.py index acc9d6d2..db6ac730 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3660,7 +3660,12 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e @pytest.mark.asyncio async def test_upsert_base_tools(server: SyncServer, default_user, event_loop): tools = await server.tool_manager.upsert_base_tools_async(actor=default_user) - expected_tool_names = sorted(LETTA_TOOL_SET) + + # Calculate expected tools accounting for production filtering + if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION": + expected_tool_names = sorted(LETTA_TOOL_SET - set(LOCAL_ONLY_MULTI_AGENT_TOOLS)) + else: + expected_tool_names = sorted(LETTA_TOOL_SET) assert sorted([t.name for t in tools]) == expected_tool_names @@ -3708,7 +3713,12 @@ async def test_upsert_base_tools(server: SyncServer, default_user, event_loop): async def test_upsert_filtered_base_tools(server: SyncServer, default_user, tool_type, expected_names): tools = await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={tool_type}) tool_names = sorted([t.name for t in tools]) - expected_sorted = sorted(expected_names) + + # Adjust expected names for multi-agent tools in production + if tool_type == ToolType.LETTA_MULTI_AGENT_CORE and os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION": + expected_sorted = sorted(set(expected_names) - set(LOCAL_ONLY_MULTI_AGENT_TOOLS)) + else: + expected_sorted = sorted(expected_names) assert tool_names == expected_sorted assert all(t.tool_type == tool_type for t in tools) @@ -6229,7 +6239,15 @@ async def test_update_file_status_with_chunks(server, default_user, default_sour ) created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) - # Update with chunk progress + # First transition: PENDING -> PARSING + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + assert updated.processing_status == FileProcessingStatus.PARSING + + # Next transition: PARSING -> EMBEDDING with chunk progress updated = await server.file_manager.update_file_status( file_id=created.id, actor=default_user, @@ -6252,6 +6270,428 @@ async def test_update_file_status_with_chunks(server, default_user, default_sour assert updated.processing_status == FileProcessingStatus.EMBEDDING # unchanged +@pytest.mark.asyncio +async def test_file_status_valid_transitions(server, default_user, default_source): + """Test valid state transitions follow the expected flow.""" + meta = PydanticFileMetadata( + file_name="valid_transitions.txt", + file_path="/tmp/valid_transitions.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + assert created.processing_status == FileProcessingStatus.PENDING + + # PENDING -> PARSING + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + assert updated.processing_status == FileProcessingStatus.PARSING + + # PARSING -> EMBEDDING + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.EMBEDDING, + ) + assert updated.processing_status == FileProcessingStatus.EMBEDDING + + # EMBEDDING -> COMPLETED + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.COMPLETED, + ) + assert updated.processing_status == FileProcessingStatus.COMPLETED + + +@pytest.mark.asyncio +async def test_file_status_invalid_transitions(server, default_user, default_source): + """Test that invalid state transitions are blocked.""" + # Test PENDING -> COMPLETED (skipping PARSING and EMBEDDING) + meta = PydanticFileMetadata( + file_name="invalid_pending_to_completed.txt", + file_path="/tmp/invalid1.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + + with pytest.raises(ValueError, match="Invalid state transition.*pending.*COMPLETED"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.COMPLETED, + ) + + # Test PARSING -> COMPLETED (skipping EMBEDDING) + meta2 = PydanticFileMetadata( + file_name="invalid_parsing_to_completed.txt", + file_path="/tmp/invalid2.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created2 = await server.file_manager.create_file(file_metadata=meta2, actor=default_user) + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + + with pytest.raises(ValueError, match="Invalid state transition.*parsing.*COMPLETED"): + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.COMPLETED, + ) + + # Test PENDING -> EMBEDDING (skipping PARSING) + meta3 = PydanticFileMetadata( + file_name="invalid_pending_to_embedding.txt", + file_path="/tmp/invalid3.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created3 = await server.file_manager.create_file(file_metadata=meta3, actor=default_user) + + with pytest.raises(ValueError, match="Invalid state transition.*pending.*EMBEDDING"): + await server.file_manager.update_file_status( + file_id=created3.id, + actor=default_user, + processing_status=FileProcessingStatus.EMBEDDING, + ) + + +@pytest.mark.asyncio +async def test_file_status_terminal_states(server, default_user, default_source): + """Test that terminal states (COMPLETED and ERROR) cannot be updated.""" + # Test COMPLETED is terminal + meta = PydanticFileMetadata( + file_name="completed_terminal.txt", + file_path="/tmp/completed_terminal.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + + # Move through valid transitions to COMPLETED + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.PARSING) + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING) + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.COMPLETED) + + # Cannot transition from COMPLETED to any state + with pytest.raises(ValueError, match="Cannot update.*terminal state completed"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.EMBEDDING, + ) + + with pytest.raises(ValueError, match="Cannot update.*terminal state completed"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Should not work", + ) + + # Test ERROR is terminal + meta2 = PydanticFileMetadata( + file_name="error_terminal.txt", + file_path="/tmp/error_terminal.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created2 = await server.file_manager.create_file(file_metadata=meta2, actor=default_user) + + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Test error", + ) + + # Cannot transition from ERROR to any state + with pytest.raises(ValueError, match="Cannot update.*terminal state error"): + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + + +@pytest.mark.asyncio +async def test_file_status_error_transitions(server, default_user, default_source): + """Test that any non-terminal state can transition to ERROR.""" + # PENDING -> ERROR + meta1 = PydanticFileMetadata( + file_name="pending_to_error.txt", + file_path="/tmp/pending_error.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created1 = await server.file_manager.create_file(file_metadata=meta1, actor=default_user) + + updated1 = await server.file_manager.update_file_status( + file_id=created1.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Failed at PENDING", + ) + assert updated1.processing_status == FileProcessingStatus.ERROR + assert updated1.error_message == "Failed at PENDING" + + # PARSING -> ERROR + meta2 = PydanticFileMetadata( + file_name="parsing_to_error.txt", + file_path="/tmp/parsing_error.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created2 = await server.file_manager.create_file(file_metadata=meta2, actor=default_user) + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + + updated2 = await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Failed at PARSING", + ) + assert updated2.processing_status == FileProcessingStatus.ERROR + assert updated2.error_message == "Failed at PARSING" + + # EMBEDDING -> ERROR + meta3 = PydanticFileMetadata( + file_name="embedding_to_error.txt", + file_path="/tmp/embedding_error.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created3 = await server.file_manager.create_file(file_metadata=meta3, actor=default_user) + await server.file_manager.update_file_status(file_id=created3.id, actor=default_user, processing_status=FileProcessingStatus.PARSING) + await server.file_manager.update_file_status(file_id=created3.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING) + + updated3 = await server.file_manager.update_file_status( + file_id=created3.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Failed at EMBEDDING", + ) + assert updated3.processing_status == FileProcessingStatus.ERROR + assert updated3.error_message == "Failed at EMBEDDING" + + +@pytest.mark.asyncio +async def test_file_status_terminal_state_non_status_updates(server, default_user, default_source): + """Test that terminal states block ALL updates, not just status changes.""" + # Create file and move to COMPLETED + meta = PydanticFileMetadata( + file_name="terminal_blocks_all.txt", + file_path="/tmp/terminal_all.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.PARSING) + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING) + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.COMPLETED) + + # Cannot update chunks_embedded in COMPLETED state + with pytest.raises(ValueError, match="Cannot update.*terminal state completed"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + chunks_embedded=50, + ) + + # Cannot update total_chunks in COMPLETED state + with pytest.raises(ValueError, match="Cannot update.*terminal state completed"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + total_chunks=100, + ) + + # Cannot update error_message in COMPLETED state + with pytest.raises(ValueError, match="Cannot update.*terminal state completed"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + error_message="This should fail", + ) + + # Test same for ERROR state + meta2 = PydanticFileMetadata( + file_name="error_blocks_all.txt", + file_path="/tmp/error_all.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created2 = await server.file_manager.create_file(file_metadata=meta2, actor=default_user) + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Initial error", + ) + + # Cannot update chunks_embedded in ERROR state + with pytest.raises(ValueError, match="Cannot update.*terminal state error"): + await server.file_manager.update_file_status( + file_id=created2.id, + actor=default_user, + chunks_embedded=25, + ) + + +@pytest.mark.asyncio +async def test_file_status_race_condition_prevention(server, default_user, default_source): + """Test that race conditions are prevented when multiple updates happen.""" + meta = PydanticFileMetadata( + file_name="race_condition_test.txt", + file_path="/tmp/race_test.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + + # Move to PARSING + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + + # Simulate race condition: Try to update from PENDING again (stale read) + # This should fail because the file is already in PARSING + with pytest.raises(ValueError, match="Invalid state transition.*parsing.*PARSING"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + + # Move to ERROR + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.ERROR, + error_message="Simulated error", + ) + + # Try to continue with EMBEDDING as if error didn't happen (race condition) + # This should fail because file is in ERROR state + with pytest.raises(ValueError, match="Cannot update.*terminal state error"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.EMBEDDING, + ) + + +@pytest.mark.asyncio +async def test_file_status_backwards_transitions(server, default_user, default_source): + """Test that backwards transitions are not allowed.""" + meta = PydanticFileMetadata( + file_name="backwards_transitions.txt", + file_path="/tmp/backwards.txt", + file_type="text/plain", + file_size=100, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + + # Move to EMBEDDING + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.PARSING) + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING) + + # Cannot go back to PARSING + with pytest.raises(ValueError, match="Invalid state transition.*embedding.*PARSING"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PARSING, + ) + + # Cannot go back to PENDING + with pytest.raises(ValueError, match="Cannot transition to PENDING state.*PENDING is only valid as initial state"): + await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.PENDING, + ) + + +@pytest.mark.asyncio +async def test_file_status_update_with_chunks_progress(server, default_user, default_source): + """Test updating chunk progress during EMBEDDING state.""" + meta = PydanticFileMetadata( + file_name="chunk_progress.txt", + file_path="/tmp/chunks.txt", + file_type="text/plain", + file_size=1000, + source_id=default_source.id, + ) + created = await server.file_manager.create_file(file_metadata=meta, actor=default_user) + + # Move to EMBEDDING with initial chunk info + await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.PARSING) + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.EMBEDDING, + total_chunks=100, + chunks_embedded=0, + ) + assert updated.total_chunks == 100 + assert updated.chunks_embedded == 0 + + # Update chunk progress without changing status + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + chunks_embedded=50, + ) + assert updated.chunks_embedded == 50 + assert updated.processing_status == FileProcessingStatus.EMBEDDING + + # Update to completion + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + chunks_embedded=100, + ) + assert updated.chunks_embedded == 100 + + # Move to COMPLETED + updated = await server.file_manager.update_file_status( + file_id=created.id, + actor=default_user, + processing_status=FileProcessingStatus.COMPLETED, + ) + assert updated.processing_status == FileProcessingStatus.COMPLETED + assert updated.chunks_embedded == 100 # preserved + + @pytest.mark.asyncio async def test_upsert_file_content_basic(server: SyncServer, default_user, default_source, async_session): """Test creating and updating file content with upsert_file_content()."""