diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index c137e615..c6f83853 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -8,6 +8,7 @@ from starlette import status import letta.constants as constants from letta.log import get_logger +from letta.schemas.agent import AgentState from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.passage import Passage @@ -15,6 +16,11 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer +from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker +from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder +from letta.services.file_processor.file_processor import FileProcessor +from letta.services.file_processor.parser.mistral_parser import MistralFileParser +from letta.settings import model_settings, settings from letta.utils import safe_create_task, sanitize_filename logger = get_logger(__name__) @@ -171,12 +177,15 @@ async def upload_file_to_source( source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) if source is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.") - bytes = file.file.read() + content = await file.read() + + # sanitize filename + file.filename = sanitize_filename(file.filename) try: - text = bytes.decode("utf-8") + text = content.decode("utf-8") except Exception: - text = "" + text = "[Currently parsing...]" # create job job = Job( @@ -184,26 +193,28 @@ async def upload_file_to_source( metadata={"type": "embedding", "filename": file.filename, "source_id": source_id}, completed_at=None, ) - job_id = job.id - await server.job_manager.create_job_async(job, 
async def load_file_to_source_cloud(
    server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User
):
    """Process an uploaded file through the cloud pipeline: Mistral OCR -> chunking -> OpenAI embeddings.

    Builds a FileProcessor from its three collaborators and delegates to it. The
    processor persists the resulting passages, inserts document blocks into the
    given agents' context windows, and updates `job` status on success/failure.

    Args:
        server: Running server instance (used for context-window insertion).
        agent_states: Agents attached to the source whose context windows get the document.
        content: Raw uploaded file bytes.
        file: The FastAPI upload handle (filename/content-type metadata).
        job: Tracking job to mark completed/failed.
        source_id: Source the file belongs to.
        actor: User performing the upload.
    """
    # BUG FIX: the parser used to be bound to a variable named `file_processor`
    # and then immediately shadowed by the FileProcessor orchestrator below,
    # which was misleading. Use a distinct name for the parser.
    file_parser = MistralFileParser()
    text_chunker = LlamaIndexChunker()
    embedder = OpenAIEmbedder()
    file_processor = FileProcessor(file_parser=file_parser, text_chunker=text_chunker, embedder=embedder, actor=actor)
    await file_processor.process(server=server, agent_states=agent_states, source_id=source_id, content=content, file=file, job=job)
""" - agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor) + agent_states = agent_states or await self.source_manager.list_attached_agents(source_id=source_id, actor=actor) # Return early if not agent_states: - return + return [] logger.info(f"Inserting document into context window for source: {source_id}") logger.info(f"Attached agents: {[a.id for a in agent_states]}") await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states)) + return agent_states + async def insert_documents_into_context_window( self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User ) -> None: diff --git a/letta/services/file_processor/__init__.py b/letta/services/file_processor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/file_processor/chunker/__init__.py b/letta/services/file_processor/chunker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/file_processor/chunker/llama_index_chunker.py b/letta/services/file_processor/chunker/llama_index_chunker.py new file mode 100644 index 00000000..94f45e0a --- /dev/null +++ b/letta/services/file_processor/chunker/llama_index_chunker.py @@ -0,0 +1,29 @@ +from typing import List + +from mistralai import OCRPageObject + +from letta.log import get_logger + +logger = get_logger(__name__) + + +class LlamaIndexChunker: + """LlamaIndex-based text chunking""" + + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + from llama_index.core.node_parser import SentenceSplitter + + self.parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + # TODO: Make this more general beyond Mistral + def chunk_text(self, page: OCRPageObject) -> List[str]: + """Chunk text using LlamaIndex splitter""" + try: + return self.parser.split_text(page.markdown) + + 
class OpenAIEmbedder:
    """OpenAI-based embedding generation with batching and bounded concurrency."""

    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
        self.embedding_config = embedding_config or EmbeddingConfig.default_config(provider="openai")

        # TODO: Unify to global OpenAI client
        self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
        self.max_batch = 1024  # max inputs per embeddings API call
        self.max_concurrent_requests = 20  # cap on in-flight embedding requests

    async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
        """Embed a single batch and return embeddings paired with their original chunk indices."""
        response = await self.client.embeddings.create(model=self.embedding_config.embedding_model, input=batch)
        return [(idx, res.embedding) for idx, res in zip(batch_indices, response.data)]

    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Generate embeddings for chunks and wrap them in Passage objects.

        Chunks are split into batches of `max_batch`, embedded concurrently
        (bounded by `max_concurrent_requests`), then reassembled in their
        original order.

        Args:
            file_id: File the chunks came from.
            source_id: Source the passages belong to.
            chunks: Text chunks to embed, in document order.
            actor: User whose organization owns the passages.

        Returns:
            Passages in the same order as `chunks`; empty list for empty input.
        """
        if not chunks:
            return []

        logger.info(f"Generating embeddings for {len(chunks)} chunks using {self.embedding_config.embedding_model}")

        # Create batches with their original indices so order can be restored later.
        batches = []
        batch_indices = []
        for i in range(0, len(chunks), self.max_batch):
            batches.append(chunks[i : i + self.max_batch])
            batch_indices.append(list(range(i, min(i + self.max_batch, len(chunks)))))

        logger.info(f"Processing {len(batches)} batches")

        # BUG FIX: max_concurrent_requests was defined (and a comment claimed
        # "semaphore control") but no semaphore existed, so all batches were
        # fired at once. Bound the in-flight requests.
        semaphore = asyncio.Semaphore(self.max_concurrent_requests)

        async def process(batch: List[str], indices: List[int]):
            async with semaphore:
                try:
                    return await self._embed_batch(batch, indices)
                except Exception as e:
                    logger.error(f"Failed to embed batch of size {len(batch)}: {str(e)}")
                    raise

        # Execute all batches concurrently under the semaphore.
        results = await asyncio.gather(*(process(batch, indices) for batch, indices in zip(batches, batch_indices)))

        # Flatten results and sort by original index to restore document order.
        indexed_embeddings = [pair for batch_result in results for pair in batch_result]
        indexed_embeddings.sort(key=lambda x: x[0])

        # Create Passage objects in original chunk order.
        passages = []
        for (idx, embedding), text in zip(indexed_embeddings, chunks):
            passages.append(
                Passage(
                    text=text,
                    file_id=file_id,
                    source_id=source_id,
                    embedding=embedding,
                    embedding_config=self.embedding_config,
                    organization_id=actor.organization_id,
                )
            )

        logger.info(f"Successfully generated {len(passages)} embeddings")
        return passages
class FileProcessor:
    """Main file-processing orchestrator: parse (OCR) -> chunk -> embed -> persist."""

    def __init__(
        self,
        file_parser: MistralFileParser,
        text_chunker: LlamaIndexChunker,
        embedder: OpenAIEmbedder,
        actor: User,
        max_file_size: int = 50 * 1024 * 1024,  # 50MB default
    ):
        self.file_parser = file_parser
        self.text_chunker = text_chunker
        self.embedder = embedder
        self.max_file_size = max_file_size
        self.source_manager = SourceManager()
        self.passage_manager = PassageManager()
        self.job_manager = JobManager()
        self.actor = actor

    # TODO: Factor this function out of SyncServer
    async def process(
        self,
        server: SyncServer,
        agent_states: List[AgentState],
        source_id: str,
        content: bytes,
        file: UploadFile,
        job: Optional[Job] = None,
    ) -> List[Passage]:
        """Run the full pipeline for one uploaded file and return the created passages.

        Registers file metadata, OCR-extracts the content, chunks and embeds each
        page, persists the passages, and inserts the first pages into the attached
        agents' context windows. On success marks `job` completed; on any failure
        logs the error, marks `job` failed, and returns an empty list (the job
        record carries the error detail).
        """
        file_metadata = self._extract_upload_file_metadata(file, source_id=source_id)
        file_metadata = await self.source_manager.create_file(file_metadata, self.actor)
        filename = file_metadata.file_name

        try:
            # Ensure we're working with bytes
            if isinstance(content, str):
                content = content.encode("utf-8")

            if len(content) > self.max_file_size:
                # BUG FIX: message said "PDF" although any mime type is accepted here.
                raise ValueError(f"File size exceeds maximum allowed size of {self.max_file_size} bytes")

            # BUG FIX: `filename` was assigned but never used; the log lines lacked it.
            logger.info(f"Starting OCR extraction for {filename}")
            ocr_response = await self.file_parser.extract_text(content, mime_type=file_metadata.file_type)

            if not ocr_response or len(ocr_response.pages) == 0:
                raise ValueError("No text extracted from file")

            logger.info("Chunking extracted text")
            all_passages = []
            for page in ocr_response.pages:
                chunks = self.text_chunker.chunk_text(page)

                if not chunks:
                    raise ValueError("No chunks created from text")

                passages = await self.embedder.generate_embedded_passages(
                    file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
                )
                all_passages.extend(passages)

            all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor)

            logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")

            # Only the first few pages go into context windows, for UX purposes.
            await server.insert_document_into_context_windows(
                source_id=source_id,
                text="".join([ocr_response.pages[i].markdown for i in range(min(3, len(ocr_response.pages)))]),
                filename=file.filename,
                actor=self.actor,
                agent_states=agent_states,
            )

            # update job status
            if job:
                job.status = JobStatus.completed
                job.metadata["num_passages"] = len(all_passages)
                await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)

            return all_passages

        except Exception as e:
            logger.error(f"File processing failed for {filename}: {str(e)}")

            # update job status
            if job:
                job.status = JobStatus.failed
                job.metadata["error"] = str(e)
                await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)

            return []

    def _extract_upload_file_metadata(self, file: UploadFile, source_id: str) -> FileMetadata:
        """Build a FileMetadata record from the upload handle, guessing mime type from the name."""
        file_metadata = {
            "file_name": file.filename,
            "file_path": None,
            "file_type": mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown",
            "file_size": file.size if file.size is not None else None,
        }
        return FileMetadata(**file_metadata, source_id=source_id)
class MistralFileParser(FileParser):
    """Mistral-based OCR extraction."""

    def __init__(self, model: str = "mistral-ocr-latest"):
        self.model = model

    # TODO: Make this return something general if we add more file parsers
    async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
        """Extract text using Mistral OCR, with a shortcut for plain text.

        Args:
            content: Raw file bytes.
            mime_type: The file's mime type; "text/plain" bypasses OCR entirely.

        Returns:
            An OCRResponse; plain text is wrapped in a synthetic single-page
            response so callers see a uniform return type.

        Raises:
            Re-raises any decoding or Mistral API error after logging it.
        """
        try:
            logger.info(f"Extracting text using Mistral OCR model: {self.model}")

            # TODO: Kind of hacky...we try to exit early here?
            # TODO: Create our internal file parser representation we return instead of OCRResponse
            if mime_type == "text/plain":
                text = content.decode("utf-8", errors="replace")
                return OCRResponse(
                    model=self.model,
                    pages=[
                        OCRPageObject(
                            index=0,
                            markdown=text,
                            images=[],
                            dimensions=None,
                        )
                    ],
                    # NOTE(review): synthetic usage info — confirm required fields if the SDK changes.
                    usage_info=OCRUsageInfo(pages_processed=1),
                    document_annotation=None,
                )

            # Mistral OCR consumes the file as a base64 data URL.
            base64_encoded_content = base64.b64encode(content).decode("utf-8")
            document_url = f"data:{mime_type};base64,{base64_encoded_content}"

            async with Mistral(api_key=settings.mistral_api_key) as mistral:
                # BUG FIX: the model name was hardcoded to "mistral-ocr-latest",
                # silently ignoring the `model` constructor parameter.
                ocr_response = await mistral.ocr.process_async(
                    model=self.model, document={"type": "document_url", "document_url": document_url}, include_image_base64=False
                )

            return ocr_response

        except Exception as e:
            logger.error(f"OCR extraction failed: {str(e)}")
            raise
+optional = false +python-versions = ">=3.8" +files = [ + {file = "eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a"}, + {file = "eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -3633,6 +3647,28 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mistralai" +version = "1.8.1" +description = "Python Client SDK for the Mistral AI API." +optional = false +python-versions = ">=3.9" +files = [ + {file = "mistralai-1.8.1-py3-none-any.whl", hash = "sha256:badfc7e6832d894b3e9071d92ad621212b7cccd7df622c6cacdb525162ae338f"}, + {file = "mistralai-1.8.1.tar.gz", hash = "sha256:b967ca443726b71ec45632cb33825ee2e55239a652e73c2bda11f7cc683bf6e5"}, +] + +[package.dependencies] +eval-type-backport = ">=0.2.0" +httpx = ">=0.28.1" +pydantic = ">=2.10.3" +python-dateutil = ">=2.8.2" +typing-inspection = ">=0.4.0" + +[package.extras] +agents = ["authlib (>=1.5.2,<2.0)", "griffe (>=1.7.3,<2.0)", "mcp (>=1.0,<2.0)"] +gcp = ["google-auth (>=2.27.0)", "requests (>=2.32.3)"] + [[package]] name = "msgpack" version = "1.1.0" @@ -7202,4 +7238,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.14,>=3.10" -content-hash = "186dbd44cfb8a4c2f041ab8261475154db27235c560de155a07e125337da3175" +content-hash = "3b363dada9fd643d7e2a6ba3f1041c365c22249926d9c0959132336f47718992" diff --git a/pyproject.toml b/pyproject.toml index 604c6b01..77e9fad8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ matplotlib = "^3.10.1" asyncpg = {version = "^0.30.0", optional = true} tavily-python = "^0.7.2" async-lru = "^2.0.5" +mistralai = "^1.8.1" [tool.poetry.extras] diff --git a/tests/test_sources.py b/tests/test_sources.py index 
0aac53ea..bde523f7 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -56,7 +56,26 @@ def agent_state(client: LettaSDKClient): client.agents.delete(agent_id=agent_state.id) -def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): +import re +import time + +import pytest + + +@pytest.mark.parametrize( + "file_path, expected_value, expected_label_regex", + [ + ("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"), + ("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"), + ], +) +def test_file_upload_creates_source_blocks_correctly( + client: LettaSDKClient, + agent_state: AgentState, + file_path: str, + expected_value: str, + expected_label_regex: str, +): # Clear existing sources for source in client.sources.list(): client.sources.delete(source_id=source.id) @@ -72,38 +91,35 @@ def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, age # Attach client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) - # Load files into the source - file_path = "tests/data/test.txt" - - # Upload the files + # Upload the file with open(file_path, "rb") as f: job = client.sources.files.upload(source_id=source.id, file=f) - # Wait for the jobs to complete + # Wait for the job to complete while job.status != "completed": time.sleep(1) job = client.jobs.retrieve(job_id=job.id) print("Waiting for jobs to complete...", job.status) - # Get the first file with pagination + # Get uploaded files files = client.sources.files.list(source_id=source.id, limit=1) assert len(files) == 1 assert files[0].source_id == source.id - # Get the agent state, check blocks exist + # Check that blocks were created blocks = client.agents.blocks.list(agent_id=agent_state.id) assert len(blocks) == 2 - assert "test" in [b.value for b in blocks] - assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + assert any(expected_value in b.value for b in blocks) + assert 
any(re.fullmatch(expected_label_regex, b.label) for b in blocks) # Remove file from source client.sources.files.delete(source_id=source.id, file_id=files[0].id) - # Get the agent state, check blocks do NOT exist + # Confirm blocks were removed blocks = client.agents.blocks.list(agent_id=agent_state.id) assert len(blocks) == 1 - assert "test" not in [b.value for b in blocks] - assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + assert not any(expected_value in b.value for b in blocks) + assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):