feat: refactor the POST agent/messages API to take multiple messages (#1882)

This commit is contained in:
Charles Packer
2024-10-14 13:29:07 -07:00
committed by GitHub
parent 93aacc087e
commit f408436669
4 changed files with 87 additions and 22 deletions

View File

@@ -1,13 +1,13 @@
from typing import List
from typing import List, Union
from pydantic import BaseModel, Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate
class LettaRequest(BaseModel):
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.")
run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement
stream_steps: bool = Field(

View File

@@ -2,7 +2,7 @@ import copy
import json
import warnings
from datetime import datetime, timezone
from typing import List, Optional
from typing import List, Literal, Optional
from pydantic import Field, field_validator
@@ -57,7 +57,11 @@ class BaseMessage(LettaBase):
class MessageCreate(BaseMessage):
"""Request to create a message"""
role: MessageRole = Field(..., description="The role of the participant.")
# In the simplified format, only allow simple roles
role: Literal[
MessageRole.user,
MessageRole.system,
] = Field(..., description="The role of the participant.")
text: str = Field(..., description="The text of the message.")
name: Optional[str] = Field(None, description="The name of the participant.")

View File

@@ -8,7 +8,7 @@ from starlette.responses import StreamingResponse
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
LegacyLettaMessage,
LettaMessage,
@@ -23,7 +23,7 @@ from letta.schemas.memory import (
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, UpdateMessage
from letta.schemas.message import Message, MessageCreate, UpdateMessage
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.server.rest_api.interface import StreamingServerInterface
@@ -326,14 +326,15 @@ async def send_message(
# TODO(charles): support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
message = request.messages[0]
request.messages[0]
return await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
role=message.role,
message=message.text,
# role=message.role,
# message=message.text,
messages=request.messages,
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
return_message_object=request.return_message_object,
@@ -349,8 +350,8 @@ async def send_message_to_agent(
server: SyncServer,
agent_id: str,
user_id: str,
role: MessageRole,
message: str,
# role: MessageRole,
messages: Union[List[Message], List[MessageCreate]],
stream_steps: bool,
stream_tokens: bool,
# related to whether or not we return `LettaMessage`s or `Message`s
@@ -367,14 +368,6 @@ async def send_message_to_agent(
# TODO: @charles is this the correct way to handle?
include_final_message = True
# determine role
if role == MessageRole.user:
message_func = server.user_message
elif role == MessageRole.system:
message_func = server.system_message
else:
raise HTTPException(status_code=500, detail=f"Bad role {role}")
if not stream_steps and stream_tokens:
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
@@ -413,7 +406,8 @@ async def send_message_to_agent(
# Offload the synchronous message_func to a separate thread
streaming_interface.stream_start()
task = asyncio.create_task(
asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
# asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
asyncio.to_thread(server.send_messages, user_id=user_id, agent_id=agent_id, messages=messages)
)
if stream_steps:

View File

@@ -72,7 +72,7 @@ from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, UpdateMessage
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage
@@ -141,6 +141,11 @@ class Server(object):
"""Process a message from the system, internally calls step"""
raise NotImplementedError
@abstractmethod
def send_messages(self, user_id: str, agent_id: str, messages: Union[List[MessageCreate], List[Message]]) -> LettaUsageStatistics:
    """Send a list of messages to the agent and return the usage statistics from stepping it.

    `messages` is a homogeneous list of either MessageCreate (simplified request
    schema) or fully-formed Message objects, matching the concrete
    SyncServer.send_messages implementation.
    """
    raise NotImplementedError
@abstractmethod
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
"""Run a command on the agent, e.g. /memory
@@ -725,6 +730,68 @@ class SyncServer(Server):
# Run the agent state forward
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message)
def send_messages(
    self,
    user_id: str,
    agent_id: str,
    messages: Union[List[MessageCreate], List[Message]],
    # whether or not to wrap user and system message as MemGPT-style stringified JSON
    wrap_user_message: bool = True,
    wrap_system_message: bool = True,
) -> LettaUsageStatistics:
    """Send a list of messages to the agent.

    If the messages are of type MessageCreate, we need to turn them into
    Message objects first before sending them through step.
    Otherwise, we can pass them in directly.

    Args:
        user_id: ID of the sending user; must exist in the metadata store.
        agent_id: ID of the target agent; must exist and belong to the user.
        messages: Homogeneous list of MessageCreate or Message objects.
        wrap_user_message: If True, wrap user message text with MemGPT-style
            stringified-JSON metadata before building the Message object.
        wrap_system_message: Same, for system messages.

    Returns:
        The LettaUsageStatistics produced by stepping the agent forward.

    Raises:
        ValueError: If the user or agent does not exist, the list is empty,
            the list mixes Message and MessageCreate, or a MessageCreate
            carries an unsupported role.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    # Guard against an empty batch: `all(...)` is vacuously True on [], so an
    # empty list would otherwise slip through and step the agent with no input.
    if not messages:
        raise ValueError("messages must be a non-empty list")

    message_objects: List[Message] = []

    if all(isinstance(m, MessageCreate) for m in messages):
        for message in messages:
            assert isinstance(message, MessageCreate)

            # If wrapping is enabled, wrap with metadata before placing content inside the Message object.
            # NOTE: role validation is deliberately independent of the wrap flags —
            # previously `role == user and wrap_user_message` fell through to the
            # "invalid role" error whenever wrapping was disabled for a valid role.
            if message.role == MessageRole.user:
                if wrap_user_message:
                    message.text = system.package_user_message(user_message=message.text)
            elif message.role == MessageRole.system:
                if wrap_system_message:
                    message.text = system.package_system_message(system_message=message.text)
            else:
                raise ValueError(f"Invalid message role: {message.role}")

            # Create the Message object
            message_objects.append(
                Message(
                    user_id=user_id,
                    agent_id=agent_id,
                    role=message.role,
                    text=message.text,
                    name=message.name,
                    # assigned later?
                    model=None,
                    # irrelevant
                    tool_calls=None,
                    tool_call_id=None,
                )
            )

    elif all(isinstance(m, Message) for m in messages):
        for message in messages:
            assert isinstance(message, Message)
            message_objects.append(message)

    else:
        # Report the offending element types — `type(messages)` is always `list`
        # and told the caller nothing about the mixed contents.
        raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(m) for m in messages]}")

    # Run the agent state forward
    return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects)
# @LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
"""Run a command on the agent"""