feat: Create POST v1/agents/messages/batches (#1722)

This commit is contained in:
Matthew Zhou
2025-04-15 14:36:41 -07:00
committed by GitHub
parent 3593abe677
commit bbd4b087d3
2 changed files with 41 additions and 3 deletions

View File

@@ -117,11 +117,13 @@ class LettaAgentBatch:
self.max_steps = max_steps
async def step_until_request(
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Dict[str, AgentStepState]
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None
) -> LettaBatchResponse:
# Basic checks
if not batch_requests:
raise ValueError("Empty list of batch_requests passed in!")
if agent_step_state_mapping is None:
agent_step_state_mapping = {}
agent_messages_mapping: Dict[str, List[Message]] = {}
agent_tools_mapping: Dict[str, List[dict]] = {}
@@ -134,6 +136,13 @@ class LettaAgentBatch:
agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
agent_state=agent_state, input_messages=batch_request.messages
)
# TODO: Think about a cleaner way to do this?
if agent_id not in agent_step_state_mapping:
agent_step_state_mapping[agent_id] = AgentStepState(
step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
)
agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(
agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver
)

View File

@@ -12,6 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.orm.errors import NoResultFound
@@ -20,8 +21,8 @@ from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_request import LettaBatchRequest, LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage, PassageUpdate
@@ -818,3 +819,31 @@ async def list_agent_groups(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
print("in list agents with manager_type", manager_type)
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
# Batch APIs
@router.post("/messages/batches", response_model=LettaBatchResponse, operation_id="create_batch_message_request")
async def send_batch_messages(
batch_requests: List[LettaBatchRequest] = Body(..., description="Messages and config for all agents"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Submit a batch of agent messages for asynchronous processing.
Creates a job that will fan out messages to all listed agents and process them in parallel.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
batch_runner = LettaAgentBatch(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
actor=actor,
)
return await batch_runner.step_until_request(batch_requests=batch_requests)