From f4084366698cb9e31854294ca3a7b14245fb9c17 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Mon, 14 Oct 2024 13:29:07 -0700 Subject: [PATCH] feat: refactor the `POST` `agent/messages` API to take multiple messages (#1882) --- letta/schemas/letta_request.py | 6 +- letta/schemas/message.py | 8 ++- letta/server/rest_api/routers/v1/agents.py | 26 ++++---- letta/server/server.py | 69 +++++++++++++++++++++- 4 files changed, 87 insertions(+), 22 deletions(-) diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index a6e49d8b..3c0a3ef9 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -1,13 +1,13 @@ -from typing import List +from typing import List, Union from pydantic import BaseModel, Field from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG -from letta.schemas.message import MessageCreate +from letta.schemas.message import Message, MessageCreate class LettaRequest(BaseModel): - messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.") + messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.") run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement stream_steps: bool = Field( diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 8ce7d7b5..fa7f0be8 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -2,7 +2,7 @@ import copy import json import warnings from datetime import datetime, timezone -from typing import List, Optional +from typing import List, Literal, Optional from pydantic import Field, field_validator @@ -57,7 +57,11 @@ class BaseMessage(LettaBase): class MessageCreate(BaseMessage): """Request to create a message""" - role: MessageRole = Field(..., description="The role of the participant.") + # In the simplified format, only allow simple roles + role: Literal[ + MessageRole.user, + MessageRole.system, + ] = Field(..., description="The role of the participant.") text: str = Field(..., description="The text of the message.") name: Optional[str] = Field(None, description="The name of the participant.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 00e1cce9..6d42c2fb 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -8,7 +8,7 @@ from starlette.responses import StreamingResponse from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState -from letta.schemas.enums import MessageRole, MessageStreamStatus +from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( LegacyLettaMessage, LettaMessage, @@ -23,7 +23,7 @@ from letta.schemas.memory import ( Memory, RecallMemorySummary, ) -from letta.schemas.message import Message, UpdateMessage +from letta.schemas.message import Message, MessageCreate, UpdateMessage from letta.schemas.passage import Passage from letta.schemas.source import Source from letta.server.rest_api.interface import StreamingServerInterface @@ -326,14 +326,15 @@ async def send_message( # TODO(charles): support sending multiple messages assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}" - message = request.messages[0] + request.messages[0] return await send_message_to_agent( server=server, agent_id=agent_id, user_id=actor.id, - role=message.role, - message=message.text, + # role=message.role, + # message=message.text, + messages=request.messages, stream_steps=request.stream_steps, stream_tokens=request.stream_tokens, return_message_object=request.return_message_object, @@ -349,8 +350,8 @@ async def send_message_to_agent( server: SyncServer, agent_id: str, user_id: str, - role: MessageRole, - message: str, + # role: MessageRole, + messages: Union[List[Message], List[MessageCreate]], stream_steps: bool, stream_tokens: bool, # related to whether or not we return `LettaMessage`s or `Message`s @@ -367,14 +368,6 @@ async def send_message_to_agent( # TODO: @charles is this the correct way to handle? include_final_message = True - # determine role - if role == MessageRole.user: - message_func = server.user_message - elif role == MessageRole.system: - message_func = server.system_message - else: - raise HTTPException(status_code=500, detail=f"Bad role {role}") - if not stream_steps and stream_tokens: raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'") @@ -413,7 +406,8 @@ async def send_message_to_agent( # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() task = asyncio.create_task( - asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp) + # asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp) + asyncio.to_thread(server.send_messages, user_id=user_id, agent_id=agent_id, messages=messages) ) if stream_steps: diff --git a/letta/server/server.py b/letta/server/server.py index 68809fca..73254edf 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -72,7 +72,7 @@ from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary -from letta.schemas.message import Message, UpdateMessage +from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.organization import Organization, OrganizationCreate from letta.schemas.passage import Passage @@ -141,6 +141,11 @@ class Server(object): """Process a message from the system, internally calls step""" raise NotImplementedError + @abstractmethod + def send_messages(self, user_id: str, agent_id: str, messages: Union[MessageCreate, List[Message]]) -> None: + """Send a list of messages to the agent""" + raise NotImplementedError + @abstractmethod def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: """Run a command on the agent, e.g. /memory @@ -725,6 +730,68 @@ class SyncServer(Server): # Run the agent state forward return self._step(user_id=user_id, agent_id=agent_id, input_messages=message) + def send_messages( + self, + user_id: str, + agent_id: str, + messages: Union[List[MessageCreate], List[Message]], + # whether or not to wrap user and system message as MemGPT-style stringified JSON + wrap_user_message: bool = True, + wrap_system_message: bool = True, + ) -> LettaUsageStatistics: + """Send a list of messages to the agent + + If the messages are of type MessageCreate, we need to turn them into + Message objects first before sending them through step. + + Otherwise, we can pass them in directly. + """ + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: + raise ValueError(f"Agent agent_id={agent_id} does not exist") + + message_objects: List[Message] = [] + + if all(isinstance(m, MessageCreate) for m in messages): + for message in messages: + assert isinstance(message, MessageCreate) + + # If wrapping is eanbled, wrap with metadata before placing content inside the Message object + if message.role == MessageRole.user and wrap_user_message: + message.text = system.package_user_message(user_message=message.text) + elif message.role == MessageRole.system and wrap_system_message: + message.text = system.package_system_message(system_message=message.text) + else: + raise ValueError(f"Invalid message role: {message.role}") + + # Create the Message object + message_objects.append( + Message( + user_id=user_id, + agent_id=agent_id, + role=message.role, + text=message.text, + name=message.name, + # assigned later? + model=None, + # irrelevant + tool_calls=None, + tool_call_id=None, + ) + ) + + elif all(isinstance(m, Message) for m in messages): + for message in messages: + assert isinstance(message, Message) + message_objects.append(message) + + else: + raise ValueError(f"All messages must be of type Message or MessageCreate, got {type(messages)}") + + # Run the agent state forward + return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects) + # @LockingServer.agent_lock_decorator def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: """Run a command on the agent"""