* Wait — I forgot to commit locally. * `cp` the entire core directory, then `rm` its `.git` subdirectory.
294 lines
13 KiB
Python
294 lines
13 KiB
Python
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import openai
|
|
import pytest
|
|
|
|
from letta.errors import ErrorCode, LLMBadRequestError
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
|
|
|
|
|
class TestOpenAIEmbedder:
    """Test suite for OpenAI embedder functionality.

    Every test uses the ``embedder`` fixture, whose LLM client is a ``Mock``,
    so no network requests are made; each test installs its own
    ``request_embeddings`` behaviour on ``embedder.client``.
    """

    @pytest.fixture
    def mock_user(self):
        """Create a mock user (actor) that only carries an ``organization_id``."""
        user = Mock()
        user.organization_id = "test_org_id"
        return user

    @pytest.fixture
    def embedding_config(self):
        """Create a small OpenAI embedding config for testing.

        Tests that need a different batch size override it on
        ``embedder.embedding_config`` rather than here.
        """
        return EmbeddingConfig(
            embedding_model="text-embedding-3-small",
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=3,  # small dimension for testing
            embedding_chunk_size=300,
            batch_size=2,  # small batch size for testing
        )

    @pytest.fixture
    def embedder(self, embedding_config):
        """Create an ``OpenAIEmbedder`` whose LLM client is a ``Mock``.

        ``LLMClient.create`` is patched during construction so no real client is
        built, and the mock is also assigned to ``embedder.client`` explicitly.
        ``handle_llm_error`` is a plain ``Mock`` that individual tests configure.
        """
        with patch("letta.services.file_processor.embedder.openai_embedder.LLMClient.create") as mock_create:
            mock_client = Mock()
            mock_client.handle_llm_error = Mock()
            mock_create.return_value = mock_client

            embedder = OpenAIEmbedder(embedding_config)
            embedder.client = mock_client
            return embedder

    @pytest.mark.asyncio
    async def test_successful_embedding_generation(self, embedder, mock_user):
        """Test successful embedding generation for normal cases"""
        # mock successful embedding response
        mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        embedder.client.request_embeddings = AsyncMock(return_value=mock_embeddings)

        chunks = ["chunk 1", "chunk 2"]
        file_id = "test_file"
        source_id = "test_source"

        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)

        assert len(passages) == 2
        assert passages[0].text == "chunk 1"
        assert passages[1].text == "chunk 2"
        # embeddings are padded to MAX_EMBEDDING_DIM, so check first 3 values
        assert passages[0].embedding[:3] == [0.1, 0.2, 0.3]
        assert passages[1].embedding[:3] == [0.4, 0.5, 0.6]
        assert passages[0].file_id == file_id
        assert passages[0].source_id == source_id

    @pytest.mark.asyncio
    async def test_token_limit_retry_splits_batch(self, embedder, mock_user):
        """Test that token limit errors trigger batch splitting and retry.

        With the fixture's ``batch_size=2``, four chunks would already be sent
        as two 2-item batches. The full-batch request that is supposed to fail
        would then never happen, and the retry path would go untested. The
        batch size is therefore raised to 4, so the first request carries every
        chunk and is rejected with a token-limit error.
        """
        from letta.constants import MAX_EMBEDDING_DIM

        # send all four chunks in one request so the token-limit path is reached
        embedder.embedding_config.batch_size = 4

        # create a mock token limit error
        mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
        token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)

        chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
        file_id = "test_file"
        source_id = "test_source"

        # a distinct embedding per chunk so ordering after the split can be verified
        expected_embeddings = {chunk: [float(i), float(i)] for i, chunk in enumerate(chunks, start=1)}
        requested_batches = []

        async def mock_request_embeddings(inputs, embedding_config):
            # Decide purely from the inputs, not from call order: the split
            # halves may be requested concurrently.
            requested_batches.append(list(inputs))
            if len(inputs) == len(chunks):  # the full batch exceeds the token limit
                raise token_limit_error
            return [expected_embeddings[text] for text in inputs]

        embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)

        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)

        # the full batch was attempted first and rejected...
        assert requested_batches[0] == chunks
        # ...then retried as strictly smaller sub-batches that together cover every chunk
        retried_batches = requested_batches[1:]
        assert retried_batches, "expected the oversized batch to be split and retried"
        assert all(len(batch) < len(chunks) for batch in retried_batches)
        assert sorted(text for batch in retried_batches for text in batch) == chunks

        # should still get all 4 passages, in order, each with its own embedding
        assert [p.text for p in passages] == chunks
        for passage in passages:
            assert passage.embedding[:2] == expected_embeddings[passage.text]
            assert len(passage.embedding) == MAX_EMBEDDING_DIM  # padded

    def test_token_limit_error_detection(self, embedder):
        """Test various token limit error detection patterns.

        ``_is_token_limit_error`` is called synchronously here (no ``await``),
        so this is a plain test rather than an asyncio one.
        """
        # test openai BadRequestError with proper structure
        mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
        openai_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)
        assert embedder._is_token_limit_error(openai_error) is True

        # test error with message but no code
        mock_error_body_no_code = {"error": {"message": "max_tokens_per_request exceeded"}}
        openai_error_no_code = openai.BadRequestError(
            message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body_no_code
        )
        assert embedder._is_token_limit_error(openai_error_no_code) is True

        # test fallback string detection
        generic_error = Exception("Requested 100000 tokens, max 50000 tokens per request")
        assert embedder._is_token_limit_error(generic_error) is True

        # test non-token errors
        other_error = Exception("Some other error")
        assert embedder._is_token_limit_error(other_error) is False

        auth_error = openai.AuthenticationError(
            message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
        )
        assert embedder._is_token_limit_error(auth_error) is False

    @pytest.mark.asyncio
    async def test_non_token_error_handling(self, embedder, mock_user):
        """Test that non-token errors are passed to handle_llm_error and the result is raised."""
        # create a non-token error
        auth_error = openai.AuthenticationError(
            message="Invalid API key", response=Mock(status_code=401), body={"error": {"code": "invalid_api_key"}}
        )

        # mock handle_llm_error to return a standardized error
        handled_error = LLMBadRequestError(message="Handled error", code=ErrorCode.UNAUTHENTICATED)
        embedder.client.handle_llm_error.return_value = handled_error
        embedder.client.request_embeddings = AsyncMock(side_effect=auth_error)

        chunks = ["chunk 1"]
        file_id = "test_file"
        source_id = "test_source"

        with pytest.raises(LLMBadRequestError) as exc_info:
            await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)

        assert exc_info.value == handled_error
        embedder.client.handle_llm_error.assert_called_once_with(auth_error)

    @pytest.mark.asyncio
    async def test_single_item_batch_no_retry(self, embedder, mock_user):
        """Test that single-item batches don't retry on token limit errors"""
        # create a token limit error
        mock_error_body = {"error": {"code": "max_tokens_per_request", "message": "Requested 319270 tokens, max 300000 tokens per request"}}
        token_limit_error = openai.BadRequestError(message="Token limit exceeded", response=Mock(status_code=400), body=mock_error_body)

        handled_error = LLMBadRequestError(message="Handled token limit error", code=ErrorCode.INVALID_ARGUMENT)
        embedder.client.handle_llm_error.return_value = handled_error
        embedder.client.request_embeddings = AsyncMock(side_effect=token_limit_error)

        chunks = ["very long chunk that exceeds token limit"]
        file_id = "test_file"
        source_id = "test_source"

        with pytest.raises(LLMBadRequestError) as exc_info:
            await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)

        assert exc_info.value == handled_error
        embedder.client.handle_llm_error.assert_called_once_with(token_limit_error)

    @pytest.mark.asyncio
    async def test_empty_chunks_handling(self, embedder, mock_user):
        """Test handling of empty chunks list"""
        chunks = []
        file_id = "test_file"
        source_id = "test_source"

        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)

        assert passages == []
        # should not call request_embeddings for empty input
        embedder.client.request_embeddings.assert_not_called()

    @pytest.mark.asyncio
    async def test_embedding_order_preservation(self, embedder, mock_user):
        """Test that embedding order is preserved when chunks span multiple batches.

        No error is injected here; the retry path is covered by
        ``test_token_limit_retry_splits_batch``.
        """
        # set up embedder to split batches (batch_size=2)
        embedder.embedding_config.batch_size = 2

        # mock responses for each batch
        async def mock_request_embeddings(inputs, embedding_config):
            # return embeddings that correspond to input order
            if inputs == ["chunk 1", "chunk 2"]:
                return [[0.1, 0.1], [0.2, 0.2]]
            elif inputs == ["chunk 3", "chunk 4"]:
                return [[0.3, 0.3], [0.4, 0.4]]
            else:
                return [[0.1, 0.1]] * len(inputs)

        embedder.client.request_embeddings = AsyncMock(side_effect=mock_request_embeddings)

        chunks = ["chunk 1", "chunk 2", "chunk 3", "chunk 4"]
        file_id = "test_file"
        source_id = "test_source"

        passages = await embedder.generate_embedded_passages(file_id, source_id, chunks, mock_user)

        # verify order is preserved
        assert len(passages) == 4
        assert passages[0].text == "chunk 1"
        assert passages[0].embedding[:2] == [0.1, 0.1]  # check first 2 values before padding
        assert passages[1].text == "chunk 2"
        assert passages[1].embedding[:2] == [0.2, 0.2]
        assert passages[2].text == "chunk 3"
        assert passages[2].embedding[:2] == [0.3, 0.3]
        assert passages[3].text == "chunk 4"
        assert passages[3].embedding[:2] == [0.4, 0.4]
|
|
|
|
|
|
class TestFileProcessorWithPinecone:
    """Tests for ``FileProcessor`` when embedding is delegated to a Pinecone embedder."""

    @pytest.mark.asyncio
    async def test_file_processor_sets_chunks_embedded_zero_with_pinecone(self):
        """Processing through Pinecone records total_chunks > 0 but chunks_embedded == 0.

        The parser, embedder and processor are real objects. Only the Pinecone
        availability flag, the Pinecone upsert function, and the two manager
        persistence methods are patched. Status updates are captured by
        keyword so the chunk-count update can be inspected afterwards.
        """
        from letta.schemas.enums import FileProcessingStatus
        from letta.schemas.file import FileMetadata
        from letta.services.file_processor.embedder.pinecone_embedder import PineconeEmbedder
        from letta.services.file_processor.file_processor import FileProcessor
        from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser

        actor = Mock()
        actor.organization_id = "test_org"

        parser = MarkitdownFileParser()

        # a file that already has content, as an imported file would
        imported_file = FileMetadata(
            file_name="test.txt",
            source_id="source-87654321",
            processing_status=FileProcessingStatus.PARSING,
            total_chunks=0,
            chunks_embedded=0,
            content="This is test content that will be chunked.",
        )

        pinecone_module = "letta.services.file_processor.embedder.pinecone_embedder"
        with patch(f"{pinecone_module}.PINECONE_AVAILABLE", True), patch(
            f"{pinecone_module}.upsert_file_records_to_pinecone_index"
        ) as upsert_mock:
            upsert_mock.return_value = None  # the upsert "succeeds"

            processor = FileProcessor(file_parser=parser, embedder=PineconeEmbedder(), actor=actor)

            status_updates = []

            async def capture_status_update(*args, **kwargs):
                # only keyword arguments are inspected by the assertions below
                status_updates.append(kwargs)
                return imported_file

            with patch.object(processor.file_manager, "update_file_status", new=capture_status_update), patch.object(
                processor.passage_manager, "create_many_source_passages_async", new=AsyncMock()
            ):
                await processor.process_imported_file(imported_file, imported_file.source_id)

            # the first status update that carries both chunk counters
            chunk_update = next(
                (update for update in status_updates if "total_chunks" in update and "chunks_embedded" in update),
                None,
            )

            assert chunk_update is not None, "No update_file_status call found with total_chunks and chunks_embedded"
            assert chunk_update["total_chunks"] > 0, "total_chunks should be greater than 0"
            assert chunk_update["chunks_embedded"] == 0, "chunks_embedded should be 0 when using Pinecone"

            # exactly one upsert, scoped to this file and source, with at least one chunk
            upsert_mock.assert_called_once()
            upsert_kwargs = upsert_mock.call_args.kwargs
            assert upsert_kwargs["file_id"] == imported_file.id
            assert upsert_kwargs["source_id"] == imported_file.source_id
            assert len(upsert_kwargs["chunks"]) > 0