feat: Add mistral for cloud document parsing (#2562)
This commit is contained in:
@@ -8,6 +8,7 @@ from starlette import status
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -15,6 +16,11 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.file_processor import FileProcessor
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.settings import model_settings, settings
|
||||
from letta.utils import safe_create_task, sanitize_filename
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -171,12 +177,15 @@ async def upload_file_to_source(
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
if source is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
|
||||
bytes = file.file.read()
|
||||
content = await file.read()
|
||||
|
||||
# sanitize filename
|
||||
file.filename = sanitize_filename(file.filename)
|
||||
|
||||
try:
|
||||
text = bytes.decode("utf-8")
|
||||
text = content.decode("utf-8")
|
||||
except Exception:
|
||||
text = ""
|
||||
text = "[Currently parsing...]"
|
||||
|
||||
# create job
|
||||
job = Job(
|
||||
@@ -184,26 +193,28 @@ async def upload_file_to_source(
|
||||
metadata={"type": "embedding", "filename": file.filename, "source_id": source_id},
|
||||
completed_at=None,
|
||||
)
|
||||
job_id = job.id
|
||||
await server.job_manager.create_job_async(job, actor=actor)
|
||||
job = await server.job_manager.create_job_async(job, actor=actor)
|
||||
|
||||
# sanitize filename
|
||||
sanitized_filename = sanitize_filename(file.filename)
|
||||
# Add blocks (sometimes without content, for UX purposes)
|
||||
agent_states = await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=file.filename, actor=actor)
|
||||
|
||||
# Add blocks
|
||||
await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=sanitized_filename, actor=actor)
|
||||
|
||||
# create background tasks
|
||||
safe_create_task(
|
||||
load_file_to_source_async(server, source_id=source.id, filename=sanitized_filename, job_id=job.id, bytes=bytes, actor=actor),
|
||||
logger=logger,
|
||||
label="load_file_to_source_async",
|
||||
)
|
||||
# NEW: Cloud based file processing
|
||||
if settings.mistral_api_key and model_settings.openai_api_key:
|
||||
logger.info("Running experimental cloud based file processing...")
|
||||
safe_create_task(
|
||||
load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor),
|
||||
logger=logger,
|
||||
label="file_processor.process",
|
||||
)
|
||||
else:
|
||||
# create background tasks
|
||||
safe_create_task(
|
||||
load_file_to_source_async(server, source_id=source.id, filename=file.filename, job_id=job.id, bytes=content, actor=actor),
|
||||
logger=logger,
|
||||
label="load_file_to_source_async",
|
||||
)
|
||||
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
|
||||
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Job with id={job_id} not found.")
|
||||
return job
|
||||
|
||||
|
||||
@@ -287,3 +298,13 @@ async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, ac
|
||||
for agent in agents:
|
||||
if agent.enable_sleeptime:
|
||||
await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
|
||||
|
||||
|
||||
async def load_file_to_source_cloud(
|
||||
server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User
|
||||
):
|
||||
file_processor = MistralFileParser()
|
||||
text_chunker = LlamaIndexChunker()
|
||||
embedder = OpenAIEmbedder()
|
||||
file_processor = FileProcessor(file_parser=file_processor, text_chunker=text_chunker, embedder=embedder, actor=actor)
|
||||
await file_processor.process(server=server, agent_states=agent_states, source_id=source_id, content=content, file=file, job=job)
|
||||
|
||||
@@ -1340,6 +1340,9 @@ class SyncServer(Server):
|
||||
|
||||
return job
|
||||
|
||||
async def load_file_to_source_via_mistral(self):
|
||||
pass
|
||||
|
||||
async def sleeptime_document_ingest_async(
|
||||
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
||||
) -> None:
|
||||
@@ -1410,22 +1413,26 @@ class SyncServer(Server):
|
||||
except NoResultFound:
|
||||
logger.info(f"Document block with label {filename} already removed, skipping...")
|
||||
|
||||
async def insert_document_into_context_windows(self, source_id: str, text: str, filename: str, actor: User) -> None:
|
||||
async def insert_document_into_context_windows(
|
||||
self, source_id: str, text: str, filename: str, actor: User, agent_states: Optional[List[AgentState]] = None
|
||||
) -> List[AgentState]:
|
||||
"""
|
||||
Insert the uploaded document into the context window of all agents
|
||||
attached to the given source.
|
||||
"""
|
||||
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
agent_states = agent_states or await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
|
||||
# Return early
|
||||
if not agent_states:
|
||||
return
|
||||
return []
|
||||
|
||||
logger.info(f"Inserting document into context window for source: {source_id}")
|
||||
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
|
||||
|
||||
await asyncio.gather(*(self._upsert_document_block(agent_state.id, text, filename, actor) for agent_state in agent_states))
|
||||
|
||||
return agent_states
|
||||
|
||||
async def insert_documents_into_context_window(
|
||||
self, agent_state: AgentState, texts: List[str], filenames: List[str], actor: User
|
||||
) -> None:
|
||||
|
||||
0
letta/services/file_processor/__init__.py
Normal file
0
letta/services/file_processor/__init__.py
Normal file
0
letta/services/file_processor/chunker/__init__.py
Normal file
0
letta/services/file_processor/chunker/__init__.py
Normal file
29
letta/services/file_processor/chunker/llama_index_chunker.py
Normal file
29
letta/services/file_processor/chunker/llama_index_chunker.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import List
|
||||
|
||||
from mistralai import OCRPageObject
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LlamaIndexChunker:
|
||||
"""LlamaIndex-based text chunking"""
|
||||
|
||||
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
self.parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
|
||||
# TODO: Make this more general beyond Mistral
|
||||
def chunk_text(self, page: OCRPageObject) -> List[str]:
|
||||
"""Chunk text using LlamaIndex splitter"""
|
||||
try:
|
||||
return self.parser.split_text(page.markdown)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chunking failed: {str(e)}")
|
||||
raise
|
||||
0
letta/services/file_processor/embedder/__init__.py
Normal file
0
letta/services/file_processor/embedder/__init__.py
Normal file
84
letta/services/file_processor/embedder/openai_embedder.py
Normal file
84
letta/services/file_processor/embedder/openai_embedder.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIEmbedder:
|
||||
"""OpenAI-based embedding generation"""
|
||||
|
||||
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
|
||||
self.embedding_config = embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# TODO: Unify to global OpenAI client
|
||||
self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
|
||||
self.max_batch = 1024
|
||||
self.max_concurrent_requests = 20
|
||||
|
||||
async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
|
||||
"""Embed a single batch and return embeddings with their original indices"""
|
||||
response = await self.client.embeddings.create(model=self.embedding_config.embedding_model, input=batch)
|
||||
return [(idx, res.embedding) for idx, res in zip(batch_indices, response.data)]
|
||||
|
||||
async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
|
||||
"""Generate embeddings for chunks with batching and concurrent processing"""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
logger.info(f"Generating embeddings for {len(chunks)} chunks using {self.embedding_config.embedding_model}")
|
||||
|
||||
# Create batches with their original indices
|
||||
batches = []
|
||||
batch_indices = []
|
||||
|
||||
for i in range(0, len(chunks), self.max_batch):
|
||||
batch = chunks[i : i + self.max_batch]
|
||||
indices = list(range(i, min(i + self.max_batch, len(chunks))))
|
||||
batches.append(batch)
|
||||
batch_indices.append(indices)
|
||||
|
||||
logger.info(f"Processing {len(batches)} batches")
|
||||
|
||||
async def process(batch: List[str], indices: List[int]):
|
||||
try:
|
||||
return await self._embed_batch(batch, indices)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed batch of size {len(batch)}: {str(e)}")
|
||||
raise
|
||||
|
||||
# Execute all batches concurrently with semaphore control
|
||||
tasks = [process(batch, indices) for batch, indices in zip(batches, batch_indices)]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Flatten results and sort by original index
|
||||
indexed_embeddings = []
|
||||
for batch_result in results:
|
||||
indexed_embeddings.extend(batch_result)
|
||||
|
||||
# Sort by index to maintain original order
|
||||
indexed_embeddings.sort(key=lambda x: x[0])
|
||||
|
||||
# Create Passage objects in original order
|
||||
passages = []
|
||||
for (idx, embedding), text in zip(indexed_embeddings, chunks):
|
||||
passage = Passage(
|
||||
text=text,
|
||||
file_id=file_id,
|
||||
source_id=source_id,
|
||||
embedding=embedding,
|
||||
embedding_config=self.embedding_config,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
logger.info(f"Successfully generated {len(passages)} embeddings")
|
||||
return passages
|
||||
123
letta/services/file_processor/file_processor.py
Normal file
123
letta/services/file_processor/file_processor.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import mimetypes
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FileProcessor:
|
||||
"""Main PDF processing orchestrator"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_parser: MistralFileParser,
|
||||
text_chunker: LlamaIndexChunker,
|
||||
embedder: OpenAIEmbedder,
|
||||
actor: User,
|
||||
max_file_size: int = 50 * 1024 * 1024, # 50MB default
|
||||
):
|
||||
self.file_parser = file_parser
|
||||
self.text_chunker = text_chunker
|
||||
self.embedder = embedder
|
||||
self.max_file_size = max_file_size
|
||||
self.source_manager = SourceManager()
|
||||
self.passage_manager = PassageManager()
|
||||
self.job_manager = JobManager()
|
||||
self.actor = actor
|
||||
|
||||
# TODO: Factor this function out of SyncServer
|
||||
async def process(
|
||||
self,
|
||||
server: SyncServer,
|
||||
agent_states: List[AgentState],
|
||||
source_id: str,
|
||||
content: bytes,
|
||||
file: UploadFile,
|
||||
job: Optional[Job] = None,
|
||||
) -> List[Passage]:
|
||||
file_metadata = self._extract_upload_file_metadata(file, source_id=source_id)
|
||||
file_metadata = await self.source_manager.create_file(file_metadata, self.actor)
|
||||
filename = file_metadata.file_name
|
||||
|
||||
try:
|
||||
# Ensure we're working with bytes
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
if len(content) > self.max_file_size:
|
||||
raise ValueError(f"PDF size exceeds maximum allowed size of {self.max_file_size} bytes")
|
||||
|
||||
logger.info(f"Starting OCR extraction for {filename}")
|
||||
ocr_response = await self.file_parser.extract_text(content, mime_type=file_metadata.file_type)
|
||||
|
||||
if not ocr_response or len(ocr_response.pages) == 0:
|
||||
raise ValueError("No text extracted from PDF")
|
||||
|
||||
logger.info("Chunking extracted text")
|
||||
all_passages = []
|
||||
for page in ocr_response.pages:
|
||||
chunks = self.text_chunker.chunk_text(page)
|
||||
|
||||
if not chunks:
|
||||
raise ValueError("No chunks created from text")
|
||||
|
||||
passages = await self.embedder.generate_embedded_passages(
|
||||
file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
|
||||
)
|
||||
all_passages.extend(passages)
|
||||
|
||||
all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor)
|
||||
|
||||
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
|
||||
|
||||
await server.insert_document_into_context_windows(
|
||||
source_id=source_id,
|
||||
text="".join([ocr_response.pages[i].markdown for i in range(min(3, len(ocr_response.pages)))]),
|
||||
filename=file.filename,
|
||||
actor=self.actor,
|
||||
agent_states=agent_states,
|
||||
)
|
||||
|
||||
# update job status
|
||||
if job:
|
||||
job.status = JobStatus.completed
|
||||
job.metadata["num_passages"] = len(all_passages)
|
||||
await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)
|
||||
|
||||
return all_passages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"PDF processing failed for {filename}: {str(e)}")
|
||||
|
||||
# update job status
|
||||
if job:
|
||||
job.status = JobStatus.failed
|
||||
job.metadata["error"] = str(e)
|
||||
await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)
|
||||
|
||||
return []
|
||||
|
||||
def _extract_upload_file_metadata(self, file: UploadFile, source_id: str) -> FileMetadata:
|
||||
file_metadata = {
|
||||
"file_name": file.filename,
|
||||
"file_path": None,
|
||||
"file_type": mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown",
|
||||
"file_size": file.size if file.size is not None else None,
|
||||
}
|
||||
return FileMetadata(**file_metadata, source_id=source_id)
|
||||
0
letta/services/file_processor/parser/__init__.py
Normal file
0
letta/services/file_processor/parser/__init__.py
Normal file
9
letta/services/file_processor/parser/base_parser.py
Normal file
9
letta/services/file_processor/parser/base_parser.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class FileParser(ABC):
|
||||
"""Abstract base class for file parser"""
|
||||
|
||||
@abstractmethod
|
||||
async def extract_text(self, content: bytes, mime_type: str):
|
||||
"""Extract text from PDF content"""
|
||||
54
letta/services/file_processor/parser/mistral_parser.py
Normal file
54
letta/services/file_processor/parser/mistral_parser.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import base64
|
||||
|
||||
from mistralai import Mistral, OCRPageObject, OCRResponse, OCRUsageInfo
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.services.file_processor.parser.base_parser import FileParser
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MistralFileParser(FileParser):
|
||||
"""Mistral-based OCR extraction"""
|
||||
|
||||
def __init__(self, model: str = "mistral-ocr-latest"):
|
||||
self.model = model
|
||||
|
||||
# TODO: Make this return something general if we add more file parsers
|
||||
async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
|
||||
"""Extract text using Mistral OCR or shortcut for plain text."""
|
||||
try:
|
||||
logger.info(f"Extracting text using Mistral OCR model: {self.model}")
|
||||
|
||||
# TODO: Kind of hacky...we try to exit early here?
|
||||
# TODO: Create our internal file parser representation we return instead of OCRResponse
|
||||
if mime_type == "text/plain":
|
||||
text = content.decode("utf-8", errors="replace")
|
||||
return OCRResponse(
|
||||
model=self.model,
|
||||
pages=[
|
||||
OCRPageObject(
|
||||
index=0,
|
||||
markdown=text,
|
||||
images=[],
|
||||
dimensions=None,
|
||||
)
|
||||
],
|
||||
usage_info=OCRUsageInfo(pages_processed=1), # You might need to construct this properly
|
||||
document_annotation=None,
|
||||
)
|
||||
|
||||
base64_encoded_content = base64.b64encode(content).decode("utf-8")
|
||||
document_url = f"data:{mime_type};base64,{base64_encoded_content}"
|
||||
|
||||
async with Mistral(api_key=settings.mistral_api_key) as mistral:
|
||||
ocr_response = await mistral.ocr.process_async(
|
||||
model="mistral-ocr-latest", document={"type": "document_url", "document_url": document_url}, include_image_base64=False
|
||||
)
|
||||
|
||||
return ocr_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OCR extraction failed: {str(e)}")
|
||||
raise
|
||||
0
letta/services/file_processor/types.py
Normal file
0
letta/services/file_processor/types.py
Normal file
@@ -234,6 +234,9 @@ class Settings(BaseSettings):
|
||||
poll_lock_retry_interval_seconds: int = 5 * 60
|
||||
batch_job_polling_lookback_weeks: int = 2
|
||||
|
||||
# for OCR
|
||||
mistral_api_key: Optional[str] = None
|
||||
|
||||
@property
|
||||
def letta_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
38
poetry.lock
generated
38
poetry.lock
generated
@@ -1432,6 +1432,20 @@ files = [
|
||||
dnspython = ">=2.0.0"
|
||||
idna = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "eval-type-backport"
|
||||
version = "0.2.2"
|
||||
description = "Like `typing._eval_type`, but lets older Python versions use newer typing features."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a"},
|
||||
{file = "eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.2"
|
||||
@@ -3633,6 +3647,28 @@ files = [
|
||||
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mistralai"
|
||||
version = "1.8.1"
|
||||
description = "Python Client SDK for the Mistral AI API."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "mistralai-1.8.1-py3-none-any.whl", hash = "sha256:badfc7e6832d894b3e9071d92ad621212b7cccd7df622c6cacdb525162ae338f"},
|
||||
{file = "mistralai-1.8.1.tar.gz", hash = "sha256:b967ca443726b71ec45632cb33825ee2e55239a652e73c2bda11f7cc683bf6e5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
eval-type-backport = ">=0.2.0"
|
||||
httpx = ">=0.28.1"
|
||||
pydantic = ">=2.10.3"
|
||||
python-dateutil = ">=2.8.2"
|
||||
typing-inspection = ">=0.4.0"
|
||||
|
||||
[package.extras]
|
||||
agents = ["authlib (>=1.5.2,<2.0)", "griffe (>=1.7.3,<2.0)", "mcp (>=1.0,<2.0)"]
|
||||
gcp = ["google-auth (>=2.27.0)", "requests (>=2.32.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "msgpack"
|
||||
version = "1.1.0"
|
||||
@@ -7202,4 +7238,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "186dbd44cfb8a4c2f041ab8261475154db27235c560de155a07e125337da3175"
|
||||
content-hash = "3b363dada9fd643d7e2a6ba3f1041c365c22249926d9c0959132336f47718992"
|
||||
|
||||
@@ -93,6 +93,7 @@ matplotlib = "^3.10.1"
|
||||
asyncpg = {version = "^0.30.0", optional = true}
|
||||
tavily-python = "^0.7.2"
|
||||
async-lru = "^2.0.5"
|
||||
mistralai = "^1.8.1"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -56,7 +56,26 @@ def agent_state(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
import re
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"file_path, expected_value, expected_label_regex",
|
||||
[
|
||||
("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"),
|
||||
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"),
|
||||
],
|
||||
)
|
||||
def test_file_upload_creates_source_blocks_correctly(
|
||||
client: LettaSDKClient,
|
||||
agent_state: AgentState,
|
||||
file_path: str,
|
||||
expected_value: str,
|
||||
expected_label_regex: str,
|
||||
):
|
||||
# Clear existing sources
|
||||
for source in client.sources.list():
|
||||
client.sources.delete(source_id=source.id)
|
||||
@@ -72,38 +91,35 @@ def test_file_upload_creates_source_blocks_correctly(client: LettaSDKClient, age
|
||||
# Attach
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/test.txt"
|
||||
|
||||
# Upload the files
|
||||
# Upload the file
|
||||
with open(file_path, "rb") as f:
|
||||
job = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
# Wait for the jobs to complete
|
||||
# Wait for the job to complete
|
||||
while job.status != "completed":
|
||||
time.sleep(1)
|
||||
job = client.jobs.retrieve(job_id=job.id)
|
||||
print("Waiting for jobs to complete...", job.status)
|
||||
|
||||
# Get the first file with pagination
|
||||
# Get uploaded files
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].source_id == source.id
|
||||
|
||||
# Get the agent state, check blocks exist
|
||||
# Check that blocks were created
|
||||
blocks = client.agents.blocks.list(agent_id=agent_state.id)
|
||||
assert len(blocks) == 2
|
||||
assert "test" in [b.value for b in blocks]
|
||||
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
assert any(expected_value in b.value for b in blocks)
|
||||
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# Remove file from source
|
||||
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
|
||||
|
||||
# Get the agent state, check blocks do NOT exist
|
||||
# Confirm blocks were removed
|
||||
blocks = client.agents.blocks.list(agent_id=agent_state.id)
|
||||
assert len(blocks) == 1
|
||||
assert "test" not in [b.value for b in blocks]
|
||||
assert not any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
assert not any(expected_value in b.value for b in blocks)
|
||||
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState):
|
||||
|
||||
Reference in New Issue
Block a user