fix: Fix state transitions for file processing (#3541)
This commit is contained in:
@@ -202,7 +202,7 @@ class OpenAIProvider(Provider):
|
||||
if model_type not in ["text->embedding"]:
|
||||
continue
|
||||
else:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Skipping embedding models for %s by default, as we don't assume embeddings are supported."
|
||||
"Please open an issue on GitHub if support is required.",
|
||||
self.base_url,
|
||||
|
||||
@@ -431,9 +431,11 @@ async def get_file_metadata(
|
||||
else:
|
||||
file_status = FileProcessingStatus.COMPLETED
|
||||
try:
|
||||
print("GETTING PINECONE!!!")
|
||||
file_metadata = await server.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
|
||||
)
|
||||
print(file_metadata)
|
||||
except ValueError as e:
|
||||
# state transition was blocked - this is a race condition
|
||||
# log it but don't fail the request since we're just reading metadata
|
||||
|
||||
@@ -205,11 +205,17 @@ class FileManager:
|
||||
if processing_status is not None:
|
||||
# enforce specific transitions based on target status
|
||||
if processing_status == FileProcessingStatus.PARSING:
|
||||
where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PENDING)
|
||||
where_conditions.append(
|
||||
FileMetadataModel.processing_status.in_([FileProcessingStatus.PENDING, FileProcessingStatus.PARSING])
|
||||
)
|
||||
elif processing_status == FileProcessingStatus.EMBEDDING:
|
||||
where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PARSING)
|
||||
where_conditions.append(
|
||||
FileMetadataModel.processing_status.in_([FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING])
|
||||
)
|
||||
elif processing_status == FileProcessingStatus.COMPLETED:
|
||||
where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.EMBEDDING)
|
||||
where_conditions.append(
|
||||
FileMetadataModel.processing_status.in_([FileProcessingStatus.EMBEDDING, FileProcessingStatus.COMPLETED])
|
||||
)
|
||||
# ERROR can be set from any non-terminal state (already handled by terminal check above)
|
||||
|
||||
# fast in-place update with state validation
|
||||
|
||||
@@ -177,9 +177,7 @@ class FileProcessor:
|
||||
"file_processor.ocr_completed",
|
||||
{"filename": filename, "pages_extracted": len(ocr_response.pages), "text_length": len(raw_markdown_text)},
|
||||
)
|
||||
file_metadata = await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.EMBEDDING
|
||||
)
|
||||
|
||||
file_metadata = await self.file_manager.upsert_file_content(file_id=file_metadata.id, text=raw_markdown_text, actor=self.actor)
|
||||
|
||||
await self.agent_manager.insert_file_into_context_windows(
|
||||
@@ -207,6 +205,11 @@ class FileProcessor:
|
||||
)
|
||||
|
||||
# Chunk and embed with fallback logic
|
||||
if not self.using_pinecone:
|
||||
file_metadata = await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.EMBEDDING
|
||||
)
|
||||
|
||||
all_passages = await self._chunk_and_embed_with_fallback(
|
||||
file_metadata=file_metadata,
|
||||
ocr_response=ocr_response,
|
||||
@@ -243,12 +246,16 @@ class FileProcessor:
|
||||
processing_status=FileProcessingStatus.COMPLETED,
|
||||
)
|
||||
else:
|
||||
await self.file_manager.update_file_status(
|
||||
print("UPDATING HERE!!!!")
|
||||
|
||||
file_metadata = await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id,
|
||||
actor=self.actor,
|
||||
total_chunks=len(all_passages),
|
||||
chunks_embedded=0,
|
||||
processing_status=FileProcessingStatus.EMBEDDING,
|
||||
)
|
||||
print(file_metadata)
|
||||
|
||||
return all_passages
|
||||
|
||||
|
||||
@@ -217,3 +217,77 @@ class TestOpenAIEmbedder:
|
||||
assert passages[2].embedding[:2] == [0.3, 0.3]
|
||||
assert passages[3].text == "chunk 4"
|
||||
assert passages[3].embedding[:2] == [0.4, 0.4]
|
||||
|
||||
|
||||
class TestFileProcessorWithPinecone:
|
||||
"""Test suite for file processor with Pinecone integration"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_processor_sets_chunks_embedded_zero_with_pinecone(self):
|
||||
"""Test that file processor sets total_chunks and chunks_embedded=0 when using Pinecone"""
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.services.file_processor.embedder.pinecone_embedder import PineconeEmbedder
|
||||
from letta.services.file_processor.file_processor import FileProcessor
|
||||
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
|
||||
|
||||
# Mock dependencies
|
||||
mock_actor = Mock()
|
||||
mock_actor.organization_id = "test_org"
|
||||
|
||||
# Create real parser
|
||||
file_parser = MarkitdownFileParser()
|
||||
|
||||
# Create file metadata with content
|
||||
mock_file = FileMetadata(
|
||||
file_name="test.txt",
|
||||
source_id="source-87654321",
|
||||
processing_status=FileProcessingStatus.PARSING,
|
||||
total_chunks=0,
|
||||
chunks_embedded=0,
|
||||
content="This is test content that will be chunked.",
|
||||
)
|
||||
|
||||
# Mock only the Pinecone-specific functionality
|
||||
with patch("letta.services.file_processor.embedder.pinecone_embedder.PINECONE_AVAILABLE", True):
|
||||
with patch("letta.services.file_processor.embedder.pinecone_embedder.upsert_file_records_to_pinecone_index") as mock_upsert:
|
||||
# Mock successful Pinecone upsert
|
||||
mock_upsert.return_value = None
|
||||
|
||||
# Create real Pinecone embedder
|
||||
embedder = PineconeEmbedder()
|
||||
|
||||
# Create file processor with Pinecone enabled
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=mock_actor, using_pinecone=True)
|
||||
|
||||
# Track file manager update calls
|
||||
update_calls = []
|
||||
|
||||
async def track_update(*args, **kwargs):
|
||||
update_calls.append(kwargs)
|
||||
return mock_file
|
||||
|
||||
# Mock managers to track calls
|
||||
with patch.object(file_processor.file_manager, "update_file_status", new=track_update):
|
||||
with patch.object(file_processor.passage_manager, "create_many_source_passages_async", new=AsyncMock()):
|
||||
# Process the imported file (which has content)
|
||||
await file_processor.process_imported_file(mock_file, mock_file.source_id)
|
||||
|
||||
# Find the call that sets total_chunks and chunks_embedded
|
||||
chunk_update_call = None
|
||||
for call in update_calls:
|
||||
if "total_chunks" in call and "chunks_embedded" in call:
|
||||
chunk_update_call = call
|
||||
break
|
||||
|
||||
# Verify the correct values were set
|
||||
assert chunk_update_call is not None, "No update_file_status call found with total_chunks and chunks_embedded"
|
||||
assert chunk_update_call["total_chunks"] > 0, "total_chunks should be greater than 0"
|
||||
assert chunk_update_call["chunks_embedded"] == 0, "chunks_embedded should be 0 when using Pinecone"
|
||||
|
||||
# Verify Pinecone upsert was called
|
||||
mock_upsert.assert_called_once()
|
||||
call_args = mock_upsert.call_args
|
||||
assert call_args.kwargs["file_id"] == mock_file.id
|
||||
assert call_args.kwargs["source_id"] == mock_file.source_id
|
||||
assert len(call_args.kwargs["chunks"]) > 0
|
||||
|
||||
@@ -6574,14 +6574,14 @@ async def test_file_status_race_condition_prevention(server, default_user, defau
|
||||
processing_status=FileProcessingStatus.PARSING,
|
||||
)
|
||||
|
||||
# Simulate race condition: Try to update from PENDING again (stale read)
|
||||
# This should fail because the file is already in PARSING
|
||||
with pytest.raises(ValueError, match="Invalid state transition.*parsing.*PARSING"):
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
processing_status=FileProcessingStatus.PARSING,
|
||||
)
|
||||
# Simulate race condition: Try to update from PARSING to PARSING again (stale read)
|
||||
# This should now be allowed (same-state transition) to prevent race conditions
|
||||
updated_again = await server.file_manager.update_file_status(
|
||||
file_id=created.id,
|
||||
actor=default_user,
|
||||
processing_status=FileProcessingStatus.PARSING,
|
||||
)
|
||||
assert updated_again.processing_status == FileProcessingStatus.PARSING
|
||||
|
||||
# Move to ERROR
|
||||
await server.file_manager.update_file_status(
|
||||
@@ -6685,6 +6685,43 @@ async def test_file_status_update_with_chunks_progress(server, default_user, def
|
||||
assert updated.chunks_embedded == 100 # preserved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_state_transitions_allowed(server, default_user, default_source):
|
||||
"""Test that same-state transitions are allowed to prevent race conditions."""
|
||||
# Create file
|
||||
created = await server.file_manager.create_file(
|
||||
FileMetadata(
|
||||
file_name="same_state_test.txt",
|
||||
source_id=default_source.id,
|
||||
processing_status=FileProcessingStatus.PENDING,
|
||||
),
|
||||
default_user,
|
||||
)
|
||||
|
||||
# Test PARSING -> PARSING
|
||||
await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.PARSING)
|
||||
updated = await server.file_manager.update_file_status(
|
||||
file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.PARSING
|
||||
)
|
||||
assert updated.processing_status == FileProcessingStatus.PARSING
|
||||
|
||||
# Test EMBEDDING -> EMBEDDING
|
||||
await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING)
|
||||
updated = await server.file_manager.update_file_status(
|
||||
file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.EMBEDDING, chunks_embedded=5
|
||||
)
|
||||
assert updated.processing_status == FileProcessingStatus.EMBEDDING
|
||||
assert updated.chunks_embedded == 5
|
||||
|
||||
# Test COMPLETED -> COMPLETED
|
||||
await server.file_manager.update_file_status(file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.COMPLETED)
|
||||
updated = await server.file_manager.update_file_status(
|
||||
file_id=created.id, actor=default_user, processing_status=FileProcessingStatus.COMPLETED, total_chunks=10
|
||||
)
|
||||
assert updated.processing_status == FileProcessingStatus.COMPLETED
|
||||
assert updated.total_chunks == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_file_content_basic(server: SyncServer, default_user, default_source, async_session):
|
||||
"""Test creating and updating file content with upsert_file_content()."""
|
||||
|
||||
Reference in New Issue
Block a user