feat: allow jobs to be filtered by source_id (#1786)
This commit is contained in:
@@ -2082,7 +2082,8 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
job (Job): Data loading job including job status and metadata
|
||||
"""
|
||||
job = self.server.create_job(user_id=self.user_id)
|
||||
metadata_ = {"type": "embedding", "filename": filename, "source_id": source_id}
|
||||
job = self.server.create_job(user_id=self.user_id, metadata=metadata_)
|
||||
|
||||
# TODO: implement blocking vs. non-blocking
|
||||
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id)
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.utils import get_utc_time
|
||||
|
||||
class JobBase(LettaBase):
|
||||
__id_prefix__ = "job"
|
||||
metadata_: Optional[dict] = Field({}, description="The metadata of the job.")
|
||||
metadata_: Optional[dict] = Field(None, description="The metadata of the job.")
|
||||
|
||||
|
||||
class Job(JobBase):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.schemas.job import Job
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -12,6 +12,7 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
@router.get("/", response_model=List[Job], operation_id="list_jobs")
|
||||
def list_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
):
|
||||
"""
|
||||
List all jobs.
|
||||
@@ -19,7 +20,14 @@ def list_jobs(
|
||||
actor = server.get_current_user()
|
||||
|
||||
# TODO: add filtering by status
|
||||
return server.list_jobs(user_id=actor.id)
|
||||
jobs = server.list_jobs(user_id=actor.id)
|
||||
|
||||
# TODO: eventually use ORM
|
||||
# results = session.query(JobModel).filter(JobModel.user_id == user_id, JobModel.metadata_["source_id"].astext == sourced_id).all()
|
||||
if source_id:
|
||||
# can't be in the ORM since we have source_id stored in the metadata_
|
||||
jobs = [job for job in jobs if job.metadata_.get("source_id") == source_id]
|
||||
return jobs
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
|
||||
|
||||
@@ -6,7 +6,7 @@ import traceback
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -1513,11 +1513,12 @@ class SyncServer(Server):
|
||||
|
||||
# TODO: delete data from agent passage stores (?)
|
||||
|
||||
def create_job(self, user_id: str) -> Job:
|
||||
def create_job(self, user_id: str, metadata: Optional[Dict] = None) -> Job:
|
||||
"""Create a new job"""
|
||||
job = Job(
|
||||
user_id=user_id,
|
||||
status=JobStatus.created,
|
||||
metadata_=metadata,
|
||||
)
|
||||
self.ms.create_job(job)
|
||||
return job
|
||||
|
||||
@@ -348,6 +348,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
print(jobs)
|
||||
assert upload_job.id in [j.id for j in jobs]
|
||||
assert len(active_jobs) == 1
|
||||
assert active_jobs[0].metadata_["source_id"] == source.id
|
||||
|
||||
# wait for job to finish (with timeout)
|
||||
timeout = 120
|
||||
|
||||
Reference in New Issue
Block a user