feat: list active runs by agent id (#1572)
This commit is contained in:
@@ -155,7 +155,10 @@ class SleeptimeMultiAgent(Agent):
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.completed,
|
||||
completed_at=datetime.utcnow(),
|
||||
metadata={"result": result.model_dump(mode="json")}, # Store the result in metadata
|
||||
metadata={
|
||||
"result": result.model_dump(mode="json"),
|
||||
"agent_id": participant_agent.agent_state.id,
|
||||
},
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
|
||||
return result
|
||||
|
||||
@@ -19,6 +19,7 @@ router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
@router.get("/", response_model=List[Run], operation_id="list_runs")
|
||||
def list_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
@@ -26,12 +27,18 @@ def list_runs(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
|
||||
runs = [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
|
||||
|
||||
if not agent_ids:
|
||||
return runs
|
||||
|
||||
return [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Run], operation_id="list_active_runs")
|
||||
def list_active_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
@@ -41,7 +48,12 @@ def list_active_runs(
|
||||
|
||||
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
|
||||
|
||||
return [Run.from_job(job) for job in active_runs]
|
||||
active_runs = [Run.from_job(job) for job in active_runs]
|
||||
|
||||
if not agent_ids:
|
||||
return active_runs
|
||||
|
||||
return [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
|
||||
|
||||
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Provider, Step
|
||||
from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
@@ -19,6 +18,7 @@ from letta.schemas.group import (
|
||||
SupervisorManager,
|
||||
)
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.run import Run
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -518,7 +518,10 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
assert len(response.usage.run_ids or []) == i % 2
|
||||
run_ids.extend(response.usage.run_ids or [])
|
||||
|
||||
time.sleep(5)
|
||||
jobs = server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)
|
||||
runs = [Run.from_job(job) for job in jobs]
|
||||
agent_runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] == sleeptime_agent_id]
|
||||
assert len(agent_runs) == len(run_ids)
|
||||
|
||||
for run_id in run_ids:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
|
||||
Reference in New Issue
Block a user