feat: job listing takes source id (#2473)

This commit is contained in:
Andy Li
2025-05-27 15:34:27 -07:00
committed by GitHub
parent f308aa2b80
commit 08c1fb5a60
4 changed files with 37 additions and 14 deletions

View File

@@ -280,15 +280,28 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
# Apply filtering logic from kwargs
# 1 part: <column> // 2 parts: <table>.<column> OR <column>.<json_key> // 3 parts: <table>.<column>.<json_key>
# TODO (cliandy): can make this more robust down the line
for key, value in kwargs.items():
if "." in key:
# Handle joined table columns
table_name, column_name = key.split(".")
parts = key.split(".")
if len(parts) == 1:
column = getattr(cls, key)
elif len(parts) == 2:
if locals().get(parts[0]) or globals().get(parts[0]):
# It's a joined table column
joined_table = locals().get(parts[0]) or globals().get(parts[0])
column = getattr(joined_table, parts[1])
else:
# It's a JSON field on the main table
column = getattr(cls, parts[0])
column = column.op("->>")(parts[1])
elif len(parts) == 3:
table_name, column_name, json_key = parts
joined_table = locals().get(table_name) or globals().get(table_name)
column = getattr(joined_table, column_name)
column = column.op("->>")(json_key)
else:
# Handle columns from main table
column = getattr(cls, key)
raise ValueError(f"Unhandled column name {key}")
if isinstance(value, (list, tuple, set)):
query = query.where(column.in_(value))

View File

@@ -23,26 +23,23 @@ async def list_jobs(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: add filtering by status
jobs = await server.job_manager.list_jobs_async(actor=actor)
if source_id:
# can't be in the ORM since we have source_id stored in the metadata
# TODO: Probably change this
jobs = [job for job in jobs if job.metadata.get("source_id") == source_id]
return jobs
return await server.job_manager.list_jobs_async(
actor=actor,
source_id=source_id,
)
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
async def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
):
"""
List all active jobs.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running])
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running], source_id=source_id)
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")

View File

@@ -169,6 +169,7 @@ class JobManager:
statuses: Optional[List[JobStatus]] = None,
job_type: JobType = JobType.JOB,
ascending: bool = True,
source_id: Optional[str] = None,
) -> List[PydanticJob]:
"""List all jobs with optional pagination and status filter."""
async with db_registry.async_session() as session:
@@ -178,6 +179,9 @@ class JobManager:
if statuses:
filter_kwargs["status"] = statuses
if source_id:
filter_kwargs["metadata_.source_id"] = source_id
jobs = await JobModel.list_async(
db_session=session,
before=before,

View File

@@ -4339,6 +4339,15 @@ async def test_list_jobs(server: SyncServer, default_user, event_loop):
assert all(job.metadata["type"].startswith("test") for job in jobs)
async def test_list_jobs_with_metadata(server: SyncServer, default_user, event_loop):
for i in range(3):
job_data = PydanticJob(status=JobStatus.created, metadata={"source_id": f"source-test-{i}"})
await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
jobs = await server.job_manager.list_jobs_async(actor=default_user, source_id="source-test-2")
assert len(jobs) == 1
assert jobs[0].metadata["source_id"] == "source-test-2"
@pytest.mark.asyncio
async def test_update_job_by_id(server: SyncServer, default_user, event_loop):
"""Test updating a job by its ID."""