From 83b897e91d3702441853bbb2fb0a78e83d76f11f Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Tue, 27 May 2025 15:34:27 -0700 Subject: [PATCH] feat: job listing takes source id (#2473) --- letta/orm/sqlalchemy_base.py | 23 ++++++++++++++++++----- letta/server/rest_api/routers/v1/jobs.py | 15 ++++++--------- letta/services/job_manager.py | 4 ++++ tests/test_managers.py | 9 +++++++++ 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index c629d283..f3ec891b 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -280,15 +280,28 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id) # Apply filtering logic from kwargs + # 1 part: // 2 parts: . OR . // 3 parts:
.. + # TODO (cliandy): can make this more robust down the line for key, value in kwargs.items(): - if "." in key: - # Handle joined table columns - table_name, column_name = key.split(".") + parts = key.split(".") + if len(parts) == 1: + column = getattr(cls, key) + elif len(parts) == 2: + if locals().get(parts[0]) or globals().get(parts[0]): + # It's a joined table column + joined_table = locals().get(parts[0]) or globals().get(parts[0]) + column = getattr(joined_table, parts[1]) + else: + # It's a JSON field on the main table + column = getattr(cls, parts[0]) + column = column.op("->>")(parts[1]) + elif len(parts) == 3: + table_name, column_name, json_key = parts joined_table = locals().get(table_name) or globals().get(table_name) column = getattr(joined_table, column_name) + column = column.op("->>")(json_key) else: - # Handle columns from main table - column = getattr(cls, key) + raise ValueError(f"Unhandled column name {key}") if isinstance(value, (list, tuple, set)): query = query.where(column.in_(value)) diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index d0c7a2e6..4c5595fa 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -23,26 +23,23 @@ async def list_jobs( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: add filtering by status - jobs = await server.job_manager.list_jobs_async(actor=actor) - - if source_id: - # can't be in the ORM since we have source_id stored in the metadata - # TODO: Probably change this - jobs = [job for job in jobs if job.metadata.get("source_id") == source_id] - return jobs + return await server.job_manager.list_jobs_async( + actor=actor, + source_id=source_id, + ) @router.get("/active", response_model=List[Job], operation_id="list_active_jobs") async def list_active_jobs( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), ): """ List all active jobs. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - - return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running]) + return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running], source_id=source_id) @router.get("/{job_id}", response_model=Job, operation_id="retrieve_job") diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 1cdc5c58..11361b35 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -169,6 +169,7 @@ class JobManager: statuses: Optional[List[JobStatus]] = None, job_type: JobType = JobType.JOB, ascending: bool = True, + source_id: Optional[str] = None, ) -> List[PydanticJob]: """List all jobs with optional pagination and status filter.""" async with db_registry.async_session() as session: @@ -178,6 +179,9 @@ class JobManager: if statuses: filter_kwargs["status"] = statuses + if source_id: + filter_kwargs["metadata_.source_id"] = source_id + jobs = await JobModel.list_async( db_session=session, before=before, diff --git a/tests/test_managers.py b/tests/test_managers.py index c9f578e6..f268c76b 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4339,6 +4339,15 @@ async def test_list_jobs(server: SyncServer, default_user, event_loop): assert all(job.metadata["type"].startswith("test") for job in jobs) +async def test_list_jobs_with_metadata(server: SyncServer, default_user, event_loop): + for i in range(3): + job_data = PydanticJob(status=JobStatus.created, metadata={"source_id": f"source-test-{i}"}) + await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user) + jobs = await server.job_manager.list_jobs_async(actor=default_user, source_id="source-test-2") + assert len(jobs) == 1 + assert jobs[0].metadata["source_id"] == "source-test-2" + + @pytest.mark.asyncio async def test_update_job_by_id(server: SyncServer, default_user, event_loop): """Test updating a job by its ID."""