feat: Add mistral for cloud document parsing (#2562)

This commit is contained in:
Matthew Zhou
2025-05-30 21:06:28 -07:00
committed by GitHub
parent d2264e73fb
commit aaf06174f8
16 changed files with 419 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ from starlette import status
import letta.constants as constants
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.passage import Passage
@@ -15,6 +16,11 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.user import User
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.file_processor import FileProcessor
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.settings import model_settings, settings
from letta.utils import safe_create_task, sanitize_filename
logger = get_logger(__name__)
@@ -171,12 +177,15 @@ async def upload_file_to_source(
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if source is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
bytes = file.file.read()
content = await file.read()
# sanitize filename
file.filename = sanitize_filename(file.filename)
try:
text = bytes.decode("utf-8")
text = content.decode("utf-8")
except Exception:
text = ""
text = "[Currently parsing...]"
# create job
job = Job(
@@ -184,26 +193,28 @@ async def upload_file_to_source(
metadata={"type": "embedding", "filename": file.filename, "source_id": source_id},
completed_at=None,
)
job_id = job.id
await server.job_manager.create_job_async(job, actor=actor)
job = await server.job_manager.create_job_async(job, actor=actor)
# sanitize filename
sanitized_filename = sanitize_filename(file.filename)
# Add blocks (sometimes without content, for UX purposes)
agent_states = await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=file.filename, actor=actor)
# Add blocks
await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=sanitized_filename, actor=actor)
# create background tasks
safe_create_task(
load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor),
logger=logger,
label="load_file_to_source_async",
)
# NEW: Cloud based file processing
if settings.mistral_api_key and model_settings.openai_api_key:
logger.info("Running experimental cloud based file processing...")
safe_create_task(
load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor),
logger=logger,
label="file_processor.process",
)
else:
# create background tasks
safe_create_task(
load_file_to_source_async(server, source_id=source.id, filename=file.filename, job_id=job.id, bytes=content, actor=actor),
logger=logger,
label="load_file_to_source_async",
)
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
if job is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Job with id={job_id} not found.")
return job
@@ -287,3 +298,13 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
for agent in agents:
if agent.enable_sleeptime:
await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
async def load_file_to_source_cloud(
server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User
):
file_processor = MistralFileParser()
text_chunker = LlamaIndexChunker()
embedder = OpenAIEmbedder()
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
await file_processor.process(server=server, agent_states=agent_states, source_id=source_id, content=content, file=file, job=job)

View File

@@ -1340,6 +1340,9 @@ class SyncServer(Server):
return job
async def load_file_to_source_via_mistral(self):
pass
async def sleeptime_document_ingest_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
) -> None:
@@ -1410,22 +1413,26 @@ class SyncServer(Server):
except NoResultFound:
logger.info(f"Document block with label {filename} already removed, skipping...")
async def insert_document_into_context_windows(self, source_id: str, text: str, filename: str, actor: User) -> None:
async def insert_document_into_context_windows(
self, source_id: str, text: str, filename: str, actor: User, agent_states: Optional[List[AgentState]] = None
) -> List[AgentState]:
"""
Insert the uploaded document into the context window of all agents
attached to the given source.
"""
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
agent_states = agent_states or await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
# Return early
if not agent_states:
return
return []
logger.info(f"Inserting document into context window for source: {source_id}")
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states))
return agent_states
async def insert_documents_into_context_window(
self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User
) -> None:

View File

@@ -0,0 +1,29 @@
from typing import List
from mistralai import OCRPageObject
from letta.log import get_logger
logger = get_logger(__name__)
class LlamaIndexChunker:
"""LlamaIndex-based text chunking"""
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
from llama_index.core.node_parser import SentenceSplitter
self.parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
# TODO: Make this more general beyond Mistral
def chunk_text(self, page: OCRPageObject) -> List[str]:
"""Chunk text using LlamaIndex splitter"""
try:
return self.parser.split_text(page.markdown)
except Exception as e:
logger.error(f"Chunking failed: {str(e)}")
raise

View File

@@ -0,0 +1,84 @@
import asyncio
from typing import List, Optional, Tuple
import openai
from letta.log import get_logger
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.settings import model_settings
logger = get_logger(__name__)
class OpenAIEmbedder:
"""OpenAI-based embedding generation"""
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
self.embedding_config = embedding_config or EmbeddingConfig.default_config(provider="openai")
# TODO: Unify to global OpenAI client
self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
self.max_batch = 1024
self.max_concurrent_requests = 20
async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
"""Embed a single batch and return embeddings with their original indices"""
response = await self.client.embeddings.create(model=self.embedding_config.embedding_model, input=batch)
return [(idx, res.embedding) for idx, res in zip(batch_indices, response.data)]
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
"""Generate embeddings for chunks with batching and concurrent processing"""
if not chunks:
return []
logger.info(f"Generating embeddings for {len(chunks)} chunks using {self.embedding_config.embedding_model}")
# Create batches with their original indices
batches = []
batch_indices = []
for i in range(0, len(chunks), self.max_batch):
batch = chunks[i : i + self.max_batch]
indices = list(range(i, min(i + self.max_batch, len(chunks))))
batches.append(batch)
batch_indices.append(indices)
logger.info(f"Processing {len(batches)} batches")
async def process(batch: List[str], indices: List[int]):
try:
return await self._embed_batch(batch, indices)
except Exception as e:
logger.error(f"Failed to embed batch of size {len(batch)}: {str(e)}")
raise
# Execute all batches concurrently with semaphore control
tasks = [process(batch, indices) for batch, indices in zip(batches, batch_indices)]
results = await asyncio.gather(*tasks)
# Flatten results and sort by original index
indexed_embeddings = []
for batch_result in results:
indexed_embeddings.extend(batch_result)
# Sort by index to maintain original order
indexed_embeddings.sort(key=lambda x: x[0])
# Create Passage objects in original order
passages = []
for (idx, embedding), text in zip(indexed_embeddings, chunks):
passage = Passage(
text=text,
file_id=file_id,
source_id=source_id,
embedding=embedding,
embedding_config=self.embedding_config,
organization_id=actor.organization_id,
)
passages.append(passage)
logger.info(f"Successfully generated {len(passages)} embeddings")
return passages

View File

@@ -0,0 +1,123 @@
import mimetypes
from typing import List, Optional
from fastapi import UploadFile
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job, JobUpdate
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
from letta.services.job_manager import JobManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
logger = get_logger(__name__)
class FileProcessor:
"""Main PDF processing orchestrator"""
def __init__(
self,
file_parser: MistralFileParser,
text_chunker: LlamaIndexChunker,
embedder: OpenAIEmbedder,
actor: User,
max_file_size: int = 50 * 1024 * 1024, # 50MB default
):
self.file_parser = file_parser
self.text_chunker = text_chunker
self.embedder = embedder
self.max_file_size = max_file_size
self.source_manager = SourceManager()
self.passage_manager = PassageManager()
self.job_manager = JobManager()
self.actor = actor
# TODO: Factor this function out of SyncServer
async def process(
self,
server: SyncServer,
agent_states: List[AgentState],
source_id: str,
content: bytes,
file: UploadFile,
job: Optional[Job] = None,
) -> List[Passage]:
file_metadata = self._extract_upload_file_metadata(file, source_id=source_id)
file_metadata = await self.source_manager.create_file(file_metadata, self.actor)
filename = file_metadata.file_name
try:
# Ensure we're working with bytes
if isinstance(content, str):
content = content.encode("utf-8")
if len(content) > self.max_file_size:
raise ValueError(f"PDF size exceeds maximum allowed size of {self.max_file_size} bytes")
logger.info(f"Starting OCR extraction for {filename}")
ocr_response = await self.file_parser.extract_text(content, mime_type=file_metadata.file_type)
if not ocr_response or len(ocr_response.pages) == 0:
raise ValueError("No text extracted from PDF")
logger.info("Chunking extracted text")
all_passages = []
for page in ocr_response.pages:
chunks = self.text_chunker.chunk_text(page)
if not chunks:
raise ValueError("No chunks created from text")
passages = await self.embedder.generate_embedded_passages(
file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
)
all_passages.extend(passages)
all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor)
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
await server.insert_document_into_context_windows(
source_id=source_id,
text="".join([ocr_response.pages[i].markdown for i in range(min(3, len(ocr_response.pages)))]),
filename=file.filename,
actor=self.actor,
agent_states=agent_states,
)
# update job status
if job:
job.status = JobStatus.completed
job.metadata["num_passages"] = len(all_passages)
await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)
return all_passages
except Exception as e:
logger.error(f"PDF processing failed for {filename}: {str(e)}")
# update job status
if job:
job.status = JobStatus.failed
job.metadata["error"] = str(e)
await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)
return []
def _extract_upload_file_metadata(self, file: UploadFile, source_id: str) -> FileMetadata:
file_metadata = {
"file_name": file.filename,
"file_path": None,
"file_type": mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown",
"file_size": file.size if file.size is not None else None,
}
return FileMetadata(**file_metadata, source_id=source_id)

View File

@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod
class FileParser(ABC):
"""Abstract base class for file parser"""
@abstractmethod
async def extract_text(self, content: bytes, mime_type: str):
"""Extract text from PDF content"""

View File

@@ -0,0 +1,54 @@
import base64
from mistralai import Mistral, OCRPageObject, OCRResponse, OCRUsageInfo
from letta.log import get_logger
from letta.services.file_processor.parser.base_parser import FileParser
from letta.settings import settings
logger = get_logger(__name__)
class MistralFileParser(FileParser):
"""Mistral-based OCR extraction"""
def __init__(self, model: str = "mistral-ocr-latest"):
self.model = model
# TODO: Make this return something general if we add more file parsers
async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
"""Extract text using Mistral OCR or shortcut for plain text."""
try:
logger.info(f"Extracting text using Mistral OCR model: {self.model}")
# TODO: Kind of hacky...we try to exit early here?
# TODO: Create our internal file parser representation we return instead of OCRResponse
if mime_type == "text/plain":
text = content.decode("utf-8", errors="replace")
return OCRResponse(
model=self.model,
pages=[
OCRPageObject(
index=0,
markdown=text,
images=[],
dimensions=None,
)
],
usage_info=OCRUsageInfo(pages_processed=1), # You might need to construct this properly
document_annotation=None,
)
base64_encoded_content = base64.b64encode(content).decode("utf-8")
document_url = f"data:{mime_type};base64,{base64_encoded_content}"
async with Mistral(api_key=settings.mistral_api_key) as mistral:
ocr_response = await mistral.ocr.process_async(
model="mistral-ocr-latest", document={"type": "document_url", "document_url": document_url}, include_image_base64=False
)
return ocr_response
except Exception as e:
logger.error(f"OCR extraction failed: {str(e)}")
raise

View File

View File

@@ -234,6 +234,9 @@ class Settings(BaseSettings):
poll_lock_retry_interval_seconds: int = 5 * 60
batch_job_polling_lookback_weeks: int = 2
# for OCR
mistral_api_key: Optional[str] = None
@property
def letta_pg_uri(self) -> str:
if self.pg_uri:

38
poetry.lock generated
View File

@@ -1432,6 +1432,20 @@ files = [
dnspython = ">=2.0.0"
idna = ">=2.0.0"
[[package]]
name = "eval-type-backport"
version = "0.2.2"
description = "Like `typing._eval_type`, but lets older Python versions use newer typing features."
optional = false
python-versions = ">=3.8"
files = [
{file = "eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a"},
{file = "eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1"},
]
[package.extras]
tests = ["pytest"]
[[package]]
name = "exceptiongroup"
version = "1.2.2"
@@ -3633,6 +3647,28 @@ files = [
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "mistralai"
version = "1.8.1"
description = "Python Client SDK for the Mistral AI API."
optional = false
python-versions = ">=3.9"
files = [
{file = "mistralai-1.8.1-py3-none-any.whl", hash = "sha256:badfc7e6832d894b3e9071d92ad621212b7cccd7df622c6cacdb525162ae338f"},
{file = "mistralai-1.8.1.tar.gz", hash = "sha256:b967ca443726b71ec45632cb33825ee2e55239a652e73c2bda11f7cc683bf6e5"},
]
[package.dependencies]
eval-type-backport = ">=0.2.0"
httpx = ">=0.28.1"
pydantic = ">=2.10.3"
python-dateutil = ">=2.8.2"
typing-inspection = ">=0.4.0"
[package.extras]
agents = ["authlib (>=1.5.2,<2.0)", "griffe (>=1.7.3,<2.0)", "mcp (>=1.0,<2.0)"]
gcp = ["google-auth (>=2.27.0)", "requests (>=2.32.3)"]
[[package]]
name = "msgpack"
version = "1.1.0"
@@ -7202,4 +7238,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.0"
python-versions = "<3.14,>=3.10"
content-hash = "186dbd44cfb8a4c2f041ab8261475154db27235c560de155a07e125337da3175"
content-hash = "3b363dada9fd643d7e2a6ba3f1041c365c22249926d9c0959132336f47718992"

View File

@@ -93,6 +93,7 @@ matplotlib = "^3.10.1"
asyncpg = {version = "^0.30.0", optional = true}
tavily-python = "^0.7.2"
async-lru = "^2.0.5"
mistralai = "^1.8.1"
[tool.poetry.extras]

View File

@@ -56,7 +56,26 @@ def agent_state(client: LettaSDKClient):
client.agents.delete(agent_id=agent_state.id)
def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
import re
import time
import pytest
@pytest.mark.parametrize(
"file_path, expected_value, expected_label_regex",
[
("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"),
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"),
],
)
def test_file_upload_creates_source_blocks_correctly(
client: LettaSDKClient,
agent_state: AgentState,
file_path: str,
expected_value: str,
expected_label_regex: str,
):
# Clear existing sources
for source in client.sources.list():
client.sources.delete(source_id=source.id)
@@ -72,38 +91,35 @@ def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, age
# Attach
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
# Load files into the source
file_path = "tests/data/test.txt"
# Upload the files
# Upload the file
with open(file_path, "rb") as f:
job = client.sources.files.upload(source_id=source.id, file=f)
# Wait for the jobs to complete
# Wait for the job to complete
while job.status != "completed":
time.sleep(1)
job = client.jobs.retrieve(job_id=job.id)
print("Waiting for jobs to complete...", job.status)
# Get the first file with pagination
# Get uploaded files
files = client.sources.files.list(source_id=source.id, limit=1)
assert len(files) == 1
assert files[0].source_id == source.id
# Get the agent state, check blocks exist
# Check that blocks were created
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 2
assert "test" in [b.value for b in blocks]
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
assert any(expected_value in b.value for b in blocks)
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
# Remove file from source
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
# Get the agent state, check blocks do NOT exist
# Confirm blocks were removed
blocks = client.agents.blocks.list(agent_id=agent_state.id)
assert len(blocks) == 1
assert "test" not in [b.value for b in blocks]
assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
assert not any(expected_value in b.value for b in blocks)
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):