feat: support markitdown instead of mistral (#3451)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -30,12 +30,8 @@ from letta.server.server import SyncServer
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.embedder.pinecone_embedder import PineconeEmbedder
|
||||
from letta.services.file_processor.file_processor import FileProcessor
|
||||
from letta.services.file_processor.file_types import (
|
||||
get_allowed_media_types,
|
||||
get_extension_to_mime_type_map,
|
||||
is_simple_text_mime_type,
|
||||
register_mime_types,
|
||||
)
|
||||
from letta.services.file_processor.file_types import get_allowed_media_types, get_extension_to_mime_type_map, register_mime_types
|
||||
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task, sanitize_filename
|
||||
@@ -220,17 +216,7 @@ async def upload_file_to_source(
|
||||
"""
|
||||
# NEW: Cloud based file processing
|
||||
# Determine file's MIME type
|
||||
file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||
|
||||
# Check if it's a simple text file
|
||||
is_simple_file = is_simple_text_mime_type(file_mime_type)
|
||||
|
||||
# For complex files, require Mistral API key
|
||||
if not is_simple_file and not settings.mistral_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.",
|
||||
)
|
||||
mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||
|
||||
allowed_media_types = get_allowed_media_types()
|
||||
|
||||
@@ -483,13 +469,18 @@ async def load_file_to_source_cloud(
|
||||
embedding_config: EmbeddingConfig,
|
||||
file_metadata: FileMetadata,
|
||||
):
|
||||
file_processor = MistralFileParser()
|
||||
# Choose parser based on mistral API key availability
|
||||
if settings.mistral_api_key:
|
||||
file_parser = MistralFileParser()
|
||||
else:
|
||||
file_parser = MarkitdownFileParser()
|
||||
|
||||
using_pinecone = should_use_pinecone()
|
||||
if using_pinecone:
|
||||
embedder = PineconeEmbedder()
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=embedding_config)
|
||||
file_processor = FileProcessor(file_parser=file_processor, embedder=embedder, actor=actor, using_pinecone=using_pinecone)
|
||||
file_processor = FileProcessor(file_parser=file_parser, embedder=embedder, actor=actor, using_pinecone=using_pinecone)
|
||||
await file_processor.process(
|
||||
server=server, agent_states=agent_states, source_id=source_id, content=content, file_metadata=file_metadata
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from letta.services.block_manager import BlockManager
|
||||
from letta.services.file_manager import FileManager
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
from letta.services.file_processor.file_processor import FileProcessor
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.services.file_processor.parser.base_parser import FileParser
|
||||
from letta.services.files_agents_manager import FileAgentManager
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
@@ -62,7 +62,7 @@ class AgentFileManager:
|
||||
file_agent_manager: FileAgentManager,
|
||||
message_manager: MessageManager,
|
||||
embedder: BaseEmbedder,
|
||||
file_parser: MistralFileParser,
|
||||
file_parser: FileParser,
|
||||
using_pinecone: bool = False,
|
||||
):
|
||||
self.agent_manager = agent_manager
|
||||
|
||||
@@ -15,7 +15,7 @@ from letta.services.file_manager import FileManager
|
||||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||||
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.services.file_processor.parser.base_parser import FileParser
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
@@ -28,7 +28,7 @@ class FileProcessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_parser: MistralFileParser,
|
||||
file_parser: FileParser,
|
||||
embedder: BaseEmbedder,
|
||||
actor: User,
|
||||
using_pinecone: bool,
|
||||
|
||||
95
letta/services/file_processor/parser/markitdown_parser.py
Normal file
95
letta/services/file_processor/parser/markitdown_parser.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from markitdown import MarkItDown
|
||||
from mistralai import OCRPageObject, OCRResponse, OCRUsageInfo
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.services.file_processor.file_types import is_simple_text_mime_type
|
||||
from letta.services.file_processor.parser.base_parser import FileParser
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Suppress pdfminer warnings that occur during PDF processing
|
||||
logging.getLogger("pdfminer.pdffont").setLevel(logging.ERROR)
|
||||
logging.getLogger("pdfminer.pdfinterp").setLevel(logging.ERROR)
|
||||
logging.getLogger("pdfminer.pdfpage").setLevel(logging.ERROR)
|
||||
logging.getLogger("pdfminer.converter").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
class MarkitdownFileParser(FileParser):
    """Markitdown-based file parsing for documents.

    Simple text MIME types are decoded directly as UTF-8; every other type is
    written to a temporary file and converted to markdown with ``MarkItDown``.
    Results are wrapped in a single-page ``OCRResponse``.
    """

    # MIME type -> file suffix given to the temporary file handed to markitdown.
    _MIME_TO_EXT = {
        "application/pdf": ".pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
        "application/vnd.ms-excel": ".xls",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "text/csv": ".csv",
        "application/json": ".json",
        "text/xml": ".xml",
        "application/xml": ".xml",
    }

    def __init__(self, model: str = "markitdown"):
        """Store the model label reported in every ``OCRResponse``."""
        self.model = model

    @trace_method
    async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
        """Extract text using markitdown.

        Args:
            content: Raw bytes of the uploaded file.
            mime_type: MIME type used to pick the extraction path and the
                temporary file suffix.

        Returns:
            An ``OCRResponse`` holding exactly one page (index 0) whose
            markdown is the extracted text.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        # Function-scope import keeps this block self-contained.
        import asyncio

        try:
            # Handle simple text files directly
            if is_simple_text_mime_type(mime_type):
                logger.info(f"Extracting text directly (no processing needed): {self.model}")
                text = content.decode("utf-8", errors="replace")
                return self._single_page_response(text)

            logger.info(f"Extracting text using markitdown: {self.model}")

            # The conversion is synchronous (file I/O plus parsing); run it in a
            # worker thread so the event loop is not blocked while it works.
            markdown = await asyncio.to_thread(
                self._convert_to_markdown,
                content,
                self._get_file_extension(mime_type),
            )
            return self._single_page_response(markdown)

        except Exception as e:
            logger.error(f"Markitdown text extraction failed: {str(e)}")
            raise

    def _convert_to_markdown(self, content: bytes, suffix: str) -> str:
        """Write ``content`` to a temporary file, convert it, and always clean up."""
        temp_file_path = None
        try:
            # delete=False: the file is closed (and therefore flushed) before
            # markitdown opens it by path; removal happens in ``finally``.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file_path = temp_file.name
                temp_file.write(content)

            md = MarkItDown(enable_plugins=False)
            return md.convert(temp_file_path).text_content
        finally:
            # Covers failures while writing as well as while converting.
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except FileNotFoundError:
                    pass

    def _single_page_response(self, markdown: str) -> OCRResponse:
        """Wrap ``markdown`` in a one-page ``OCRResponse`` tagged with ``self.model``."""
        return OCRResponse(
            model=self.model,
            pages=[
                OCRPageObject(
                    index=0,
                    markdown=markdown,
                    images=[],
                    dimensions=None,
                )
            ],
            usage_info=OCRUsageInfo(pages_processed=1),
            document_annotation=None,
        )

    def _get_file_extension(self, mime_type: str) -> str:
        """Get file extension based on MIME type for markitdown processing.

        The lookup ignores case and any parameters after ``;`` (for example
        ``text/csv; charset=utf-8``). Unmapped types fall back to ``.txt``.
        """
        normalized = mime_type.split(";", 1)[0].strip().lower()
        return self._MIME_TO_EXT.get(normalized, ".txt")
|
||||
5057
poetry.lock
generated
5057
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -28,7 +28,7 @@ pre-commit = {version = "^3.5.0", optional = true }
|
||||
pg8000 = {version = "^1.30.3", optional = true}
|
||||
docstring-parser = ">=0.16,<0.17"
|
||||
httpx = "^0.28.0"
|
||||
numpy = "^1.26.2"
|
||||
numpy = "^2.1.0"
|
||||
demjson3 = "^3.0.6"
|
||||
pyyaml = "^6.0.1"
|
||||
sqlalchemy-json = "^0.7.0"
|
||||
@@ -100,6 +100,7 @@ structlog = "^25.4.0"
|
||||
certifi = "^2025.6.15"
|
||||
aioboto3 = {version = "^14.3.0", optional = true}
|
||||
pinecone = {extras = ["asyncio"], version = "^7.3.0", optional = true}
|
||||
markitdown = {extras = ["docx", "pdf", "pptx"], version = "^0.1.2"}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -28,7 +28,9 @@ from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.agent_file_manager import AgentFileManager
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.settings import settings
|
||||
from tests.utils import create_tool_from_func
|
||||
|
||||
# ------------------------------
|
||||
@@ -169,7 +171,7 @@ def agent_file_manager(server, default_user):
|
||||
file_agent_manager=server.file_agent_manager,
|
||||
message_manager=server.message_manager,
|
||||
embedder=OpenAIEmbedder(),
|
||||
file_parser=MistralFileParser(),
|
||||
file_parser=MistralFileParser() if settings.mistral_api_key else MarkitdownFileParser(),
|
||||
using_pinecone=False,
|
||||
)
|
||||
yield manager
|
||||
|
||||
@@ -167,6 +167,7 @@ def test_auto_attach_detach_files_tools(disable_pinecone, client: LettaSDKClient
|
||||
assert_no_file_tools(agent)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_mistral_parser", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"file_path, expected_value, expected_label_regex",
|
||||
[
|
||||
@@ -190,60 +191,68 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
file_path: str,
|
||||
expected_value: str,
|
||||
expected_label_regex: str,
|
||||
use_mistral_parser: bool,
|
||||
):
|
||||
# skip pdf tests if mistral api key is missing
|
||||
if file_path.endswith(".pdf") and not settings.mistral_api_key:
|
||||
pytest.skip("mistral api key required for pdf processing")
|
||||
# Override mistral API key setting to force parser selection for testing
|
||||
original_mistral_key = settings.mistral_api_key
|
||||
try:
|
||||
if not use_mistral_parser:
|
||||
# Set to None to force markitdown parser selection
|
||||
settings.mistral_api_key = None
|
||||
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
|
||||
# Attach
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
# Attach
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Upload the file
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
# Upload the file
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# Get uploaded files
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
# Get uploaded files
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Check that blocks were created
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
blocks = agent_state.memory.file_blocks
|
||||
assert len(blocks) == 1
|
||||
assert any(expected_value in b.value for b in blocks)
|
||||
assert any(b.value.startswith("[Viewing file start") for b in blocks)
|
||||
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
# Check that blocks were created
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
blocks = agent_state.memory.file_blocks
|
||||
assert len(blocks) == 1
|
||||
assert any(expected_value in b.value for b in blocks)
|
||||
assert any(b.value.startswith("[Viewing file start") for b in blocks)
|
||||
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# verify raw system message contains source information
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
assert "test_source" in raw_system_message
|
||||
assert "<directories>" in raw_system_message
|
||||
# verify file-specific details in raw system message
|
||||
file_name = files[0].file_name
|
||||
assert f'name="test_source/{file_name}"' in raw_system_message
|
||||
assert 'status="open"' in raw_system_message
|
||||
# verify raw system message contains source information
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
assert "test_source" in raw_system_message
|
||||
assert "<directories>" in raw_system_message
|
||||
# verify file-specific details in raw system message
|
||||
file_name = files[0].file_name
|
||||
assert f'name="test_source/{file_name}"' in raw_system_message
|
||||
assert 'status="open"' in raw_system_message
|
||||
|
||||
# Remove file from source
|
||||
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
|
||||
# Remove file from source
|
||||
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
|
||||
|
||||
# Confirm blocks were removed
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
blocks = agent_state.memory.file_blocks
|
||||
assert len(blocks) == 0
|
||||
assert not any(expected_value in b.value for b in blocks)
|
||||
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
# Confirm blocks were removed
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
blocks = agent_state.memory.file_blocks
|
||||
assert len(blocks) == 0
|
||||
assert not any(expected_value in b.value for b in blocks)
|
||||
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# verify raw system message no longer contains source information
|
||||
raw_system_message_after_removal = get_raw_system_message(client, agent_state.id)
|
||||
# this should be in, because we didn't delete the source
|
||||
assert "test_source" in raw_system_message_after_removal
|
||||
assert "<directories>" in raw_system_message_after_removal
|
||||
# verify file-specific details are also removed
|
||||
assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal
|
||||
# verify raw system message no longer contains source information
|
||||
raw_system_message_after_removal = get_raw_system_message(client, agent_state.id)
|
||||
# this should be in, because we didn't delete the source
|
||||
assert "test_source" in raw_system_message_after_removal
|
||||
assert "<directories>" in raw_system_message_after_removal
|
||||
# verify file-specific details are also removed
|
||||
assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal
|
||||
|
||||
finally:
|
||||
# Restore original mistral API key setting
|
||||
settings.mistral_api_key = original_mistral_key
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
|
||||
Reference in New Issue
Block a user