feat: Create POST v1/agents/messages/batches (#1722)
This commit is contained in:
@@ -117,11 +117,13 @@ class LettaAgentBatch:
|
||||
self.max_steps = max_steps
|
||||
|
||||
async def step_until_request(
|
||||
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Dict[str, AgentStepState]
|
||||
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None
|
||||
) -> LettaBatchResponse:
|
||||
# Basic checks
|
||||
if not batch_requests:
|
||||
raise ValueError("Empty list of batch_requests passed in!")
|
||||
if agent_step_state_mapping is None:
|
||||
agent_step_state_mapping = {}
|
||||
|
||||
agent_messages_mapping: Dict[str, List[Message]] = {}
|
||||
agent_tools_mapping: Dict[str, List[dict]] = {}
|
||||
@@ -134,6 +136,13 @@ class LettaAgentBatch:
|
||||
agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
|
||||
agent_state=agent_state, input_messages=batch_request.messages
|
||||
)
|
||||
|
||||
# TODO: Think about a cleaner way to do this?
|
||||
if agent_id not in agent_step_state_mapping:
|
||||
agent_step_state_mapping[agent_id] = AgentStepState(
|
||||
step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
|
||||
)
|
||||
|
||||
agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(
|
||||
agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -20,8 +21,8 @@ from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_request import LettaBatchRequest, LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
@@ -818,3 +819,31 @@ async def list_agent_groups(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
print("in list agents with manager_type", manager_type)
|
||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||
|
||||
|
||||
# Batch APIs
|
||||
|
||||
|
||||
@router.post("/messages/batches", response_model=LettaBatchResponse, operation_id="create_batch_message_request")
async def send_batch_messages(
    batch_requests: List[LettaBatchRequest] = Body(..., description="Messages and config for all agents"),
    server: SyncServer = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Submit a batch of agent messages for asynchronous processing.

    Creates a job that will fan out messages to all listed agents and process them in parallel.
    """
    # Resolve the acting user; falls back to the server default when no
    # `user_id` header is supplied.
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    # Assemble the batch runner from the server's managers and hand off the
    # whole request list in one call. An empty `batch_requests` list raises
    # ValueError inside step_until_request (see LettaAgentBatch).
    return await LettaAgentBatch(
        message_manager=server.message_manager,
        agent_manager=server.agent_manager,
        block_manager=server.block_manager,
        passage_manager=server.passage_manager,
        batch_manager=server.batch_manager,
        sandbox_config_manager=server.sandbox_config_manager,
        actor=actor,
    ).step_until_request(batch_requests=batch_requests)
|
||||
|
||||
Reference in New Issue
Block a user