341 lines
14 KiB
Python
341 lines
14 KiB
Python
import asyncio
|
||
import mimetypes
|
||
import os
|
||
import tempfile
|
||
from pathlib import Path
|
||
from typing import List, Optional
|
||
|
||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
|
||
from starlette import status
|
||
|
||
import letta.constants as constants
|
||
from letta.log import get_logger
|
||
from letta.schemas.agent import AgentState
|
||
from letta.schemas.file import FileMetadata
|
||
from letta.schemas.job import Job
|
||
from letta.schemas.passage import Passage
|
||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||
from letta.schemas.user import User
|
||
from letta.server.rest_api.utils import get_letta_server
|
||
from letta.server.server import SyncServer
|
||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||
from letta.services.file_processor.file_processor import FileProcessor
|
||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||
from letta.settings import model_settings, settings
|
||
from letta.utils import safe_create_task, sanitize_filename
|
||
|
||
# Module-level logger, obtained from letta's `get_logger` under this module's name.
logger = get_logger(__name__)
|
||
|
||
|
||
# Router that every endpoint in this module registers on; all paths get the "/sources" prefix and the "sources" tag.
router = APIRouter(prefix="/sources", tags=["sources"])
|
||
|
||
|
||
@router.get("/count", response_model=int, operation_id="count_sources")
async def count_sources(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # optional `user_id` request header; None when absent
):
    """
    Return how many data sources the acting user has.
    """
    # Resolve the header value (possibly None) into a user object first.
    requesting_user = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    source_count = await server.source_manager.size_async(actor=requesting_user)
    return source_count
|
||
|
||
|
||
@router.get("/{source_id}", response_model=Source, operation_id="retrieve_source")
async def retrieve_source(
    source_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Retrieve a single data source by its ID.

    Responds with 404 when no source with that ID is found for the acting user.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
    return source
|
||
|
||
|
||
@router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name")
async def get_source_id_by_name(
    source_name: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # optional `user_id` request header; None when absent
):
    """
    Look up a data source by name and return its ID.
    """
    requesting_user = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    matched = await server.source_manager.get_source_by_name(source_name=source_name, actor=requesting_user)

    # Happy path first: hand back only the identifier, not the whole source object.
    if matched:
        return matched.id

    raise HTTPException(status_code=404, detail=f"Source with name={source_name} not found.")
|
||
|
||
|
||
@router.get("/", response_model=List[Source], operation_id="list_sources")
async def list_sources(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # optional `user_id` request header; None when absent
):
    """
    List the acting user's data sources.
    """
    requesting_user = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    sources = await server.source_manager.list_sources(actor=requesting_user)
    return sources
|
||
|
||
|
||
@router.post("/", response_model=Source, operation_id="create_source")
async def create_source(
    source_create: SourceCreate,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Create a new data source.

    The request must carry either a full `embedding_config` or an `embedding` handle;
    when only the handle is given, the config is resolved from it. A request with
    neither is rejected with 400.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    if not source_create.embedding_config:
        if not source_create.embedding:
            # Client error, so answer with an explicit 400 rather than letting a bare
            # ValueError escape the handler.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Must specify either embedding or embedding_config in request",
            )
        # Resolve the handle into a concrete config, using the default chunk size
        # when the request does not set one.
        source_create.embedding_config = await server.get_embedding_config_from_handle_async(
            handle=source_create.embedding,
            embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
            actor=actor,
        )

    source = Source(
        name=source_create.name,
        embedding_config=source_create.embedding_config,
        description=source_create.description,
        instructions=source_create.instructions,
        metadata=source_create.metadata,
    )
    return await server.source_manager.create_source(source=source, actor=actor)
|
||
|
||
|
||
@router.patch("/{source_id}", response_model=Source, operation_id="modify_source")
async def modify_source(
    source_id: str,
    source: SourceUpdate,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # optional `user_id` request header; None when absent
):
    """
    Update an existing data source with the fields supplied in the request body.
    """
    # TODO: allow updating the handle/embedding config
    requesting_user = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    # Confirm the source exists before attempting the update so a miss surfaces as a 404.
    existing = await server.source_manager.get_source_by_id(source_id=source_id, actor=requesting_user)
    if not existing:
        raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")

    updated = await server.source_manager.update_source(source_id=source_id, source_update=source, actor=requesting_user)
    return updated
|
||
|
||
|
||
@router.delete("/{source_id}", response_model=None, operation_id="delete_source")
async def delete_source(
    source_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Delete a data source.

    Before the source itself is deleted, its files are removed from the context window
    of every attached agent. For agents with sleeptime enabled, the block labelled with
    the source's name is deleted as well (best effort). Responds with 404 when the
    source is not found.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
    if not source:
        # Same contract as the retrieve/modify endpoints; also keeps `source.name` below safe.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")

    agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
    files = await server.source_manager.list_files(source_id, actor)
    filenames = [f.file_name for f in files]

    for agent_state in agent_states:
        await server.remove_documents_from_context_window(agent_state=agent_state, filenames=filenames, actor=actor)

        if agent_state.enable_sleeptime:
            try:
                block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
                await server.block_manager.delete_block_async(block.id, actor)
            except Exception:
                # Block cleanup stays best-effort (a failure must not abort the source
                # deletion), but it is logged instead of silently discarded. Catching
                # Exception rather than everything lets task cancellation propagate.
                logger.warning(
                    "Failed to delete sleeptime block %r for agent %s while deleting source %s",
                    source.name,
                    agent_state.id,
                    source_id,
                    exc_info=True,
                )
    await server.delete_source(source_id=source_id, actor=actor)
|
||
|
||
|
||
def _resolve_upload_media_type(content_type: Optional[str], filename: Optional[str], allowed_media_types: set) -> str:
    """
    Work out the media type of an upload from its Content-Type header and filename.

    The header value (with parameters such as charset stripped) is used as-is when it
    is already allowed. Otherwise the filename is consulted: first via `mimetypes`,
    then via a small fixed extension table. The returned value may still be outside
    `allowed_media_types` (possibly the empty string); the caller decides whether to
    reject it.
    """
    media_type = (content_type or "").split(";", 1)[0].strip().lower()

    # Nothing more can be inferred without a filename, so a missing one must never reach Path().
    if media_type in allowed_media_types or not filename:
        return media_type

    guessed, _ = mimetypes.guess_type(filename)
    media_type = (guessed or "").lower()
    if media_type in allowed_media_types:
        return media_type

    # Last resort: map the extension directly, in case `mimetypes` gave no usable answer.
    ext_map = {
        ".pdf": "application/pdf",
        ".txt": "text/plain",
        ".json": "application/json",
    }
    return ext_map.get(Path(filename).suffix.lower(), media_type)


@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
async def upload_file_to_source(
    file: UploadFile,
    source_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Upload a file to a data source.

    Only PDF, plain-text and JSON uploads are accepted (415 otherwise), and the source
    must exist (404 otherwise). The file is processed in a background task; the job
    tracking that work is returned immediately.
    """
    allowed_media_types = {"application/pdf", "text/plain", "application/json"}

    media_type = _resolve_upload_media_type(file.content_type, file.filename, allowed_media_types)

    # If still not allowed, reject with 415.
    if media_type not in allowed_media_types:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=(f"Unsupported file type: {media_type or 'unknown'} " f"(filename: {file.filename}). Only PDF, .txt, or .json allowed."),
        )

    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
    if source is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Source with id={source_id} not found.")
    content = await file.read()

    # sanitize filename
    file.filename = sanitize_filename(file.filename)

    # Content that is not valid UTF-8 (e.g. a PDF) gets placeholder text for now.
    try:
        text = content.decode("utf-8")
    except UnicodeDecodeError:
        text = "[Currently parsing...]"

    # create job
    job = Job(
        user_id=actor.id,
        metadata={"type": "embedding", "filename": file.filename, "source_id": source_id},
        completed_at=None,
    )
    job = await server.job_manager.create_job_async(job, actor=actor)

    # Add blocks (sometimes without content, for UX purposes)
    agent_states = await server.insert_document_into_context_windows(source_id=source_id, text=text, filename=file.filename, actor=actor)

    # NEW: Cloud based file processing
    if settings.mistral_api_key and model_settings.openai_api_key:
        logger.info("Running experimental cloud based file processing...")
        safe_create_task(
            load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor),
            logger=logger,
            label="file_processor.process",
        )
    else:
        # create background tasks
        safe_create_task(
            load_file_to_source_async(server, source_id=source.id, filename=file.filename, job_id=job.id, bytes=content, actor=actor),
            logger=logger,
            label="load_file_to_source_async",
        )
        safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")

    return job
|
||
|
||
|
||
@router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages")
async def list_source_passages(
    source_id: str,
    after: Optional[str] = Query(None, description="Passage after which to retrieve the returned passages."),
    before: Optional[str] = Query(None, description="Passage before which to retrieve the returned passages."),
    limit: int = Query(100, description="Maximum number of passages to retrieve."),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    List all passages associated with a data source.

    `after`, `before` and `limit` are forwarded unchanged to the passage listing call.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    return await server.agent_manager.list_passages_async(
        actor=actor,
        source_id=source_id,
        after=after,
        before=before,
        limit=limit,
    )
|
||
|
||
|
||
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_source_files")
async def list_source_files(
    source_id: str,
    limit: int = Query(1000, description="Number of files to return"),
    after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # optional `user_id` request header; None when absent
):
    """
    Return one page of the files attached to a data source.
    """
    requesting_user = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    # `limit` and `after` are passed through untouched to the source manager.
    page = await server.source_manager.list_files(
        source_id=source_id,
        limit=limit,
        after=after,
        actor=requesting_user,
    )
    return page
|
||
|
||
|
||
# The HTTP verb DELETE already conveys the action, so the URL path has no "/delete" segment.
# The outcome is still reported through the status code: 204 on success, 404 when the file is not found.
@router.delete("/{source_id}/{file_id}", status_code=204, operation_id="delete_file_from_source")
async def delete_file_from_source(
    source_id: str,
    file_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Delete a file from a data source.

    Responds with 404 when the file is not found. On success, the file's document is
    removed from context windows and a background sleeptime re-ingest (with history
    cleared) is scheduled for the source.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)

    deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor)
    # Check for a miss before touching `deleted_file.file_name`; otherwise a missing
    # file fails with AttributeError instead of a 404.
    if deleted_file is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File with id={file_id} not found.")

    # Remove blocks
    await server.remove_document_from_context_windows(source_id=source_id, filename=deleted_file.file_name, actor=actor)

    # Scheduled through safe_create_task, matching how the upload endpoint launches the
    # same background job, instead of a bare asyncio.create_task whose result is dropped.
    safe_create_task(
        sleeptime_document_ingest_async(server, source_id, actor, clear_history=True),
        logger=logger,
        label="sleeptime_document_ingest_async",
    )
|
||
|
||
|
||
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, filename: str, bytes: bytes, actor: User):
    """
    Write the uploaded bytes to a temporary file and pass its path to `server.load_file_to_source`.
    """
    # The scratch directory and the file inside it are removed when this `with` block
    # exits, so the load call has to complete before leaving it.
    with tempfile.TemporaryDirectory() as scratch_dir:
        upload_path = os.path.join(scratch_dir, filename)

        # Persist the payload; the file is closed again before the loader runs.
        Path(upload_path).write_bytes(bytes)

        await server.load_file_to_source(source_id, upload_path, job_id, actor)
|
||
|
||
|
||
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
    """
    Run sleeptime document ingestion for every sleeptime-enabled agent attached to a source.

    Args:
        server: Server whose managers are used for the lookups and the ingest call.
        source_id: ID of the source whose documents should be ingested.
        actor: User on whose behalf the lookups and the ingest are performed.
        clear_history: Forwarded unchanged to `server.sleeptime_document_ingest_async`.
    """
    # Pass the actor here too, like every other source lookup in this module.
    source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
    if source is None:
        # Nothing to ingest; avoid passing None into the ingest call.
        logger.warning("Skipping sleeptime document ingest: source %s not found", source_id)
        return

    agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
    for agent in agents:
        if agent.enable_sleeptime:
            await server.sleeptime_document_ingest_async(agent, source, actor, clear_history)
|
||
|
||
|
||
async def load_file_to_source_cloud(
    server: SyncServer, agent_states: List[AgentState], content: bytes, file: UploadFile, job: Job, source_id: str, actor: User
):
    """
    Process an uploaded file through a FileProcessor built from the Mistral parser,
    the LlamaIndex chunker and the OpenAI embedder.
    """
    # Assemble the processor from its three collaborators in a single expression.
    pipeline = FileProcessor(
        file_parser=MistralFileParser(),
        text_chunker=LlamaIndexChunker(),
        embedder=OpenAIEmbedder(),
        actor=actor,
    )
    await pipeline.process(
        server=server,
        agent_states=agent_states,
        source_id=source_id,
        content=content,
        file=file,
        job=job,
    )
|