feat: allow jobs to be filtered by source_id (#1786)

This commit is contained in:
Sarah Wooders
2024-09-25 15:13:14 -07:00
committed by GitHub
parent 7fa1016793
commit 81737fdeec
5 changed files with 18 additions and 7 deletions

View File

@@ -2082,7 +2082,8 @@ class LocalClient(AbstractClient):
Returns:
job (Job): Data loading job including job status and metadata
"""
job = self.server.create_job(user_id=self.user_id)
metadata_ = {"type": "embedding", "filename": filename, "source_id": source_id}
job = self.server.create_job(user_id=self.user_id, metadata=metadata_)
# TODO: implement blocking vs. non-blocking
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id)

View File

@@ -10,7 +10,7 @@ from letta.utils import get_utc_time
class JobBase(LettaBase):
__id_prefix__ = "job"
metadata_: Optional[dict] = Field({}, description="The metadata of the job.")
metadata_: Optional[dict] = Field(None, description="The metadata of the job.")
class Job(JobBase):

View File

@@ -1,6 +1,6 @@
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from letta.schemas.job import Job
from letta.server.rest_api.utils import get_letta_server
@@ -12,6 +12,7 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
@router.get("/", response_model=List[Job], operation_id="list_jobs")
def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
):
"""
List all jobs.
@@ -19,7 +20,14 @@ def list_jobs(
actor = server.get_current_user()
# TODO: add filtering by status
return server.list_jobs(user_id=actor.id)
jobs = server.list_jobs(user_id=actor.id)
# TODO: eventually use ORM
# results = session.query(JobModel).filter(JobModel.user_id == user_id, JobModel.metadata_["source_id"].astext == sourced_id).all()
if source_id:
# can't be in the ORM since we have source_id stored in the metadata_
jobs = [job for job in jobs if job.metadata_.get("source_id") == source_id]
return jobs
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")

View File

@@ -6,7 +6,7 @@ import traceback
import warnings
from abc import abstractmethod
from datetime import datetime
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from fastapi import HTTPException
@@ -1513,11 +1513,12 @@ class SyncServer(Server):
# TODO: delete data from agent passage stores (?)
def create_job(self, user_id: str) -> Job:
def create_job(self, user_id: str, metadata: Optional[Dict] = None) -> Job:
"""Create a new job"""
job = Job(
user_id=user_id,
status=JobStatus.created,
metadata_=metadata,
)
self.ms.create_job(job)
return job

View File

@@ -348,6 +348,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
print(jobs)
assert upload_job.id in [j.id for j in jobs]
assert len(active_jobs) == 1
assert active_jobs[0].metadata_["source_id"] == source.id
# wait for job to finish (with timeout)
timeout = 120