From 81737fdeec88f27fd3955145459d0633c3bc23c3 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 25 Sep 2024 15:13:14 -0700 Subject: [PATCH] feat: allow jobs to be filtered by `source_id` (#1786) --- letta/client/client.py | 3 ++- letta/schemas/job.py | 2 +- letta/server/rest_api/routers/v1/jobs.py | 14 +++++++++++--- letta/server/server.py | 5 +++-- tests/test_client.py | 1 + 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index 96dc4ee2..fa41cbf7 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2082,7 +2082,8 @@ class LocalClient(AbstractClient): Returns: job (Job): Data loading job including job status and metadata """ - job = self.server.create_job(user_id=self.user_id) + metadata_ = {"type": "embedding", "filename": filename, "source_id": source_id} + job = self.server.create_job(user_id=self.user_id, metadata=metadata_) # TODO: implement blocking vs. non-blocking self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id) diff --git a/letta/schemas/job.py b/letta/schemas/job.py index 4c1de730..da83d4be 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -10,7 +10,7 @@ from letta.utils import get_utc_time class JobBase(LettaBase): __id_prefix__ = "job" - metadata_: Optional[dict] = Field({}, description="The metadata of the job.") + metadata_: Optional[dict] = Field(None, description="The metadata of the job.") class Job(JobBase): diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 8658d480..9064052f 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -1,6 +1,6 @@ -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from letta.schemas.job import Job from letta.server.rest_api.utils import get_letta_server @@ -12,6 +12,7 @@ router = APIRouter(prefix="/jobs", tags=["jobs"]) @router.get("/", response_model=List[Job], operation_id="list_jobs") def list_jobs( server: "SyncServer" = Depends(get_letta_server), + source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), ): """ List all jobs. @@ -19,7 +20,14 @@ def list_jobs( actor = server.get_current_user() # TODO: add filtering by status - return server.list_jobs(user_id=actor.id) + jobs = server.list_jobs(user_id=actor.id) + + # TODO: eventually use ORM + # results = session.query(JobModel).filter(JobModel.user_id == user_id, JobModel.metadata_["source_id"].astext == sourced_id).all() + if source_id: + # can't be in the ORM since we have source_id stored in the metadata_ + jobs = [job for job in jobs if job.metadata_.get("source_id") == source_id] + return jobs @router.get("/active", response_model=List[Job], operation_id="list_active_jobs") diff --git a/letta/server/server.py b/letta/server/server.py index 873fc6d1..1e098ef1 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -6,7 +6,7 @@ import traceback import warnings from abc import abstractmethod from datetime import datetime -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from fastapi import HTTPException @@ -1513,11 +1513,12 @@ class SyncServer(Server): # TODO: delete data from agent passage stores (?) - def create_job(self, user_id: str) -> Job: + def create_job(self, user_id: str, metadata: Optional[Dict] = None) -> Job: """Create a new job""" job = Job( user_id=user_id, status=JobStatus.created, + metadata_=metadata, ) self.ms.create_job(job) return job diff --git a/tests/test_client.py b/tests/test_client.py index ba20d984..adc1c875 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -348,6 +348,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): print(jobs) assert upload_job.id in [j.id for j in jobs] assert len(active_jobs) == 1 + assert active_jobs[0].metadata_["source_id"] == source.id # wait for job to finish (with timeout) timeout = 120