feat: job listing takes source id (#2473)
This commit is contained in:
@@ -280,15 +280,28 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
|
||||
|
||||
# Apply filtering logic from kwargs
|
||||
# 1 part: <column> // 2 parts: <table>.<column> OR <column>.<json_key> // 3 parts: <table>.<column>.<json_key>
|
||||
# TODO (cliandy): can make this more robust down the line
|
||||
for key, value in kwargs.items():
|
||||
if "." in key:
|
||||
# Handle joined table columns
|
||||
table_name, column_name = key.split(".")
|
||||
parts = key.split(".")
|
||||
if len(parts) == 1:
|
||||
column = getattr(cls, key)
|
||||
elif len(parts) == 2:
|
||||
if locals().get(parts[0]) or globals().get(parts[0]):
|
||||
# It's a joined table column
|
||||
joined_table = locals().get(parts[0]) or globals().get(parts[0])
|
||||
column = getattr(joined_table, parts[1])
|
||||
else:
|
||||
# It's a JSON field on the main table
|
||||
column = getattr(cls, parts[0])
|
||||
column = column.op("->>")(parts[1])
|
||||
elif len(parts) == 3:
|
||||
table_name, column_name, json_key = parts
|
||||
joined_table = locals().get(table_name) or globals().get(table_name)
|
||||
column = getattr(joined_table, column_name)
|
||||
column = column.op("->>")(json_key)
|
||||
else:
|
||||
# Handle columns from main table
|
||||
column = getattr(cls, key)
|
||||
raise ValueError(f"Unhandled column name {key}")
|
||||
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
query = query.where(column.in_(value))
|
||||
|
||||
@@ -23,26 +23,23 @@ async def list_jobs(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# TODO: add filtering by status
|
||||
jobs = await server.job_manager.list_jobs_async(actor=actor)
|
||||
|
||||
if source_id:
|
||||
# can't be in the ORM since we have source_id stored in the metadata
|
||||
# TODO: Probably change this
|
||||
jobs = [job for job in jobs if job.metadata.get("source_id") == source_id]
|
||||
return jobs
|
||||
return await server.job_manager.list_jobs_async(
|
||||
actor=actor,
|
||||
source_id=source_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
|
||||
async def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running])
|
||||
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running], source_id=source_id)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
|
||||
|
||||
@@ -169,6 +169,7 @@ class JobManager:
|
||||
statuses: Optional[List[JobStatus]] = None,
|
||||
job_type: JobType = JobType.JOB,
|
||||
ascending: bool = True,
|
||||
source_id: Optional[str] = None,
|
||||
) -> List[PydanticJob]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -178,6 +179,9 @@ class JobManager:
|
||||
if statuses:
|
||||
filter_kwargs["status"] = statuses
|
||||
|
||||
if source_id:
|
||||
filter_kwargs["metadata_.source_id"] = source_id
|
||||
|
||||
jobs = await JobModel.list_async(
|
||||
db_session=session,
|
||||
before=before,
|
||||
|
||||
@@ -4339,6 +4339,15 @@ async def test_list_jobs(server: SyncServer, default_user, event_loop):
|
||||
assert all(job.metadata["type"].startswith("test") for job in jobs)
|
||||
|
||||
|
||||
async def test_list_jobs_with_metadata(server: SyncServer, default_user, event_loop):
|
||||
for i in range(3):
|
||||
job_data = PydanticJob(status=JobStatus.created, metadata={"source_id": f"source-test-{i}"})
|
||||
await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
|
||||
jobs = await server.job_manager.list_jobs_async(actor=default_user, source_id="source-test-2")
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].metadata["source_id"] == "source-test-2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_by_id(server: SyncServer, default_user, event_loop):
|
||||
"""Test updating a job by its ID."""
|
||||
|
||||
Reference in New Issue
Block a user