From bbd4b087d30f711c729adadbd40741a00d322d69 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 15 Apr 2025 14:36:41 -0700 Subject: [PATCH] feat: Create POST v1/agents/messages/batches (#1722) --- letta/agents/letta_agent_batch.py | 11 +++++++- letta/server/rest_api/routers/v1/agents.py | 33 ++++++++++++++++++++-- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 7d2f3c30..b792ae3f 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -117,11 +117,13 @@ class LettaAgentBatch: self.max_steps = max_steps async def step_until_request( - self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Dict[str, AgentStepState] + self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None ) -> LettaBatchResponse: # Basic checks if not batch_requests: raise ValueError("Empty list of batch_requests passed in!") + if agent_step_state_mapping is None: + agent_step_state_mapping = {} agent_messages_mapping: Dict[str, List[Message]] = {} agent_tools_mapping: Dict[str, List[dict]] = {} @@ -134,6 +136,13 @@ class LettaAgentBatch: agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent( agent_state=agent_state, input_messages=batch_request.messages ) + + # TODO: Think about a cleaner way to do this? + if agent_id not in agent_step_state_mapping: + agent_step_state_mapping[agent_id] = AgentStepState( + step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules) + ) + agent_tools_mapping[agent_id] = self._prepare_tools_per_agent( agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver ) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 35dcd703..c5ec3533 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -12,6 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent +from letta.agents.letta_agent_batch import LettaAgentBatch from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger from letta.orm.errors import NoResultFound @@ -20,8 +21,8 @@ from letta.schemas.block import Block, BlockUpdate from letta.schemas.group import Group from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion -from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest -from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_request import LettaBatchRequest, LettaRequest, LettaStreamingRequest +from letta.schemas.letta_response import LettaBatchResponse, LettaResponse from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory from letta.schemas.message import MessageCreate from letta.schemas.passage import Passage, PassageUpdate @@ -818,3 +819,31 @@ async def list_agent_groups( actor = server.user_manager.get_user_or_default(user_id=actor_id) print("in list agents with manager_type", manager_type) return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor) + + +# Batch APIs + + +@router.post("/messages/batches", response_model=LettaBatchResponse, operation_id="create_batch_message_request") +async def send_batch_messages( + batch_requests: List[LettaBatchRequest] = Body(..., description="Messages and config for all agents"), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Submit a batch of agent messages for asynchronous processing. + Creates a job that will fan out messages to all listed agents and process them in parallel. + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + batch_runner = LettaAgentBatch( + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + batch_manager=server.batch_manager, + sandbox_config_manager=server.sandbox_config_manager, + actor=actor, + ) + + return await batch_runner.step_until_request(batch_requests=batch_requests)