diff --git a/letta/groups/sleeptime_multi_agent.py b/letta/groups/sleeptime_multi_agent.py index 4dc27ab7..336591f1 100644 --- a/letta/groups/sleeptime_multi_agent.py +++ b/letta/groups/sleeptime_multi_agent.py @@ -155,7 +155,10 @@ class SleeptimeMultiAgent(Agent): job_update = JobUpdate( status=JobStatus.completed, completed_at=datetime.utcnow(), - metadata={"result": result.model_dump(mode="json")}, # Store the result in metadata + metadata={ + "result": result.model_dump(mode="json"), + "agent_id": participant_agent.agent_state.id, + }, ) self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user) return result diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 91a7dbf4..fd7e5131 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -19,6 +19,7 @@ router = APIRouter(prefix="/runs", tags=["runs"]) @router.get("/", response_model=List[Run], operation_id="list_runs") def list_runs( server: "SyncServer" = Depends(get_letta_server), + agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ @@ -26,12 +27,18 @@ def list_runs( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - return [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)] + runs = [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)] + + if not agent_ids: + return runs + + return [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids] @router.get("/active", response_model=List[Run], operation_id="list_active_runs") def list_active_runs( server: "SyncServer" = Depends(get_letta_server), + agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ @@ -41,7 +48,12 @@ def list_active_runs( active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN) - return [Run.from_job(job) for job in active_runs] + active_runs = [Run.from_job(job) for job in active_runs] + + if not agent_ids: + return active_runs + + return [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids] @router.get("/{run_id}", response_model=Run, operation_id="retrieve_run") diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 3e8cb276..4180c202 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -1,10 +1,9 @@ -import time - import pytest from sqlalchemy import delete from letta.config import LettaConfig from letta.orm import Provider, Step +from letta.orm.enums import JobType from letta.orm.errors import NoResultFound from letta.schemas.agent import CreateAgent from letta.schemas.block import CreateBlock @@ -19,6 +18,7 @@ from letta.schemas.group import ( SupervisorManager, ) from letta.schemas.message import MessageCreate +from letta.schemas.run import Run from letta.server.server import SyncServer @@ -518,7 +518,10 @@ async def test_sleeptime_group_chat(server, actor): assert len(response.usage.run_ids or []) == i % 2 run_ids.extend(response.usage.run_ids or []) - time.sleep(5) + jobs = server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN) + runs = [Run.from_job(job) for job in jobs] + agent_runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] == sleeptime_agent_id] + assert len(agent_runs) == len(run_ids) for run_id in run_ids: job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)