248 lines
9.7 KiB
Python
248 lines
9.7 KiB
Python
import os
|
|
import tempfile
|
|
from typing import List, Optional
|
|
|
|
from fastapi import (
|
|
APIRouter,
|
|
BackgroundTasks,
|
|
Depends,
|
|
Header,
|
|
HTTPException,
|
|
Query,
|
|
UploadFile,
|
|
)
|
|
from letta.schemas.file import FileMetadata
|
|
from letta.schemas.job import Job
|
|
from letta.schemas.passage import Passage
|
|
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
|
from letta.schemas.user import User
|
|
from letta.server.rest_api.utils import get_letta_server
|
|
from letta.server.server import SyncServer
|
|
from letta.utils import sanitize_filename
|
|
|
|
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
|
|
|
|
|
router = APIRouter(prefix="/sources", tags=["sources"])
|
|
|
|
|
|
@router.get("/{source_id}", response_model=Source, operation_id="get_source")
|
|
def get_source(
|
|
source_id: str,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Get all sources
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
|
|
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
|
if not source:
|
|
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
|
|
return source
|
|
|
|
|
|
@router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name")
|
|
def get_source_id_by_name(
|
|
source_name: str,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Get a source by name
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
|
|
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
|
if not source:
|
|
raise HTTPException(status_code=404, detail=f"Source with name={source_name} not found.")
|
|
return source.id
|
|
|
|
|
|
@router.get("/", response_model=List[Source], operation_id="list_sources")
|
|
def list_sources(
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
List all data sources created by a user.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
|
|
return server.list_all_sources(actor=actor)
|
|
|
|
|
|
@router.post("/", response_model=Source, operation_id="create_source")
|
|
def create_source(
|
|
source_create: SourceCreate,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Create a new data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
source = Source(**source_create.model_dump())
|
|
|
|
return server.source_manager.create_source(source=source, actor=actor)
|
|
|
|
|
|
@router.patch("/{source_id}", response_model=Source, operation_id="update_source")
|
|
def update_source(
|
|
source_id: str,
|
|
source: SourceUpdate,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Update the name or documentation of an existing data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
|
|
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
|
|
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
|
|
|
|
|
|
@router.delete("/{source_id}", response_model=None, operation_id="delete_source")
|
|
def delete_source(
|
|
source_id: str,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Delete a data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
|
|
server.delete_source(source_id=source_id, actor=actor)
|
|
|
|
|
|
@router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source")
|
|
def attach_source_to_agent(
|
|
source_id: str,
|
|
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Attach a data source to an existing agent.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor)
|
|
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
|
|
|
|
|
@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
|
|
def detach_source_from_agent(
|
|
source_id: str,
|
|
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
) -> None:
|
|
"""
|
|
Detach a data source from an existing agent.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
|
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
|
|
|
|
|
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
|
|
def upload_file_to_source(
|
|
file: UploadFile,
|
|
source_id: str,
|
|
background_tasks: BackgroundTasks,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Upload a file to a data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
|
|
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
|
assert source is not None, f"Source with id={source_id} not found."
|
|
bytes = file.file.read()
|
|
|
|
# create job
|
|
job = Job(
|
|
user_id=actor.id,
|
|
metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id},
|
|
completed_at=None,
|
|
)
|
|
job_id = job.id
|
|
server.job_manager.create_job(job, actor=actor)
|
|
|
|
# create background task
|
|
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor)
|
|
|
|
# return job information
|
|
# Is this necessary? Can we just return the job from create_job?
|
|
job = server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
|
|
assert job is not None, "Job not found"
|
|
return job
|
|
|
|
|
|
@router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages")
|
|
def list_passages(
|
|
source_id: str,
|
|
server: SyncServer = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
List all passages associated with a data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
|
|
return passages
|
|
|
|
|
|
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_files_from_source")
|
|
def list_files_from_source(
|
|
source_id: str,
|
|
limit: int = Query(1000, description="Number of files to return"),
|
|
cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
List paginated files associated with a data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)
|
|
|
|
|
|
# it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
|
|
# it's still good practice to return a status indicating the success or failure of the deletion
|
|
@router.delete("/{source_id}/{file_id}", status_code=204, operation_id="delete_file_from_source")
|
|
def delete_file_from_source(
|
|
source_id: str,
|
|
file_id: str,
|
|
server: "SyncServer" = Depends(get_letta_server),
|
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
):
|
|
"""
|
|
Delete a data source.
|
|
"""
|
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
|
|
|
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
|
|
if deleted_file is None:
|
|
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
|
|
|
|
|
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
|
|
# Create a temporary directory (deleted after the context manager exits)
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
# Sanitize the filename
|
|
sanitized_filename = sanitize_filename(file.filename)
|
|
file_path = os.path.join(tmpdirname, sanitized_filename)
|
|
|
|
# Write the file to the sanitized path
|
|
with open(file_path, "wb") as buffer:
|
|
buffer.write(bytes)
|
|
|
|
# Pass the file to load_file_to_source
|
|
server.load_file_to_source(source_id, file_path, job_id, actor)
|