feat: refactor the POST agent/messages API to take multiple messages (#1882)
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
|
||||
|
||||
class LettaRequest(BaseModel):
|
||||
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
||||
messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.")
|
||||
run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement
|
||||
|
||||
stream_steps: bool = Field(
|
||||
|
||||
@@ -2,7 +2,7 @@ import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -57,7 +57,11 @@ class BaseMessage(LettaBase):
|
||||
class MessageCreate(BaseMessage):
|
||||
"""Request to create a message"""
|
||||
|
||||
role: MessageRole = Field(..., description="The role of the participant.")
|
||||
# In the simplified format, only allow simple roles
|
||||
role: Literal[
|
||||
MessageRole.user,
|
||||
MessageRole.system,
|
||||
] = Field(..., description="The role of the participant.")
|
||||
text: str = Field(..., description="The text of the message.")
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
LegacyLettaMessage,
|
||||
LettaMessage,
|
||||
@@ -23,7 +23,7 @@ from letta.schemas.memory import (
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message, UpdateMessage
|
||||
from letta.schemas.message import Message, MessageCreate, UpdateMessage
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
@@ -326,14 +326,15 @@ async def send_message(
|
||||
|
||||
# TODO(charles): support sending multiple messages
|
||||
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
|
||||
message = request.messages[0]
|
||||
request.messages[0]
|
||||
|
||||
return await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
role=message.role,
|
||||
message=message.text,
|
||||
# role=message.role,
|
||||
# message=message.text,
|
||||
messages=request.messages,
|
||||
stream_steps=request.stream_steps,
|
||||
stream_tokens=request.stream_tokens,
|
||||
return_message_object=request.return_message_object,
|
||||
@@ -349,8 +350,8 @@ async def send_message_to_agent(
|
||||
server: SyncServer,
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
role: MessageRole,
|
||||
message: str,
|
||||
# role: MessageRole,
|
||||
messages: Union[List[Message], List[MessageCreate]],
|
||||
stream_steps: bool,
|
||||
stream_tokens: bool,
|
||||
# related to whether or not we return `LettaMessage`s or `Message`s
|
||||
@@ -367,14 +368,6 @@ async def send_message_to_agent(
|
||||
# TODO: @charles is this the correct way to handle?
|
||||
include_final_message = True
|
||||
|
||||
# determine role
|
||||
if role == MessageRole.user:
|
||||
message_func = server.user_message
|
||||
elif role == MessageRole.system:
|
||||
message_func = server.system_message
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Bad role {role}")
|
||||
|
||||
if not stream_steps and stream_tokens:
|
||||
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
|
||||
|
||||
@@ -413,7 +406,8 @@ async def send_message_to_agent(
|
||||
# Offload the synchronous message_func to a separate thread
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
|
||||
# asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
|
||||
asyncio.to_thread(server.send_messages, user_id=user_id, agent_id=agent_id, messages=messages)
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
|
||||
@@ -72,7 +72,7 @@ from letta.schemas.job import Job
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, UpdateMessage
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization, OrganizationCreate
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -141,6 +141,11 @@ class Server(object):
|
||||
"""Process a message from the system, internally calls step"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_messages(self, user_id: str, agent_id: str, messages: Union[MessageCreate, List[Message]]) -> None:
|
||||
"""Send a list of messages to the agent"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
||||
"""Run a command on the agent, e.g. /memory
|
||||
@@ -725,6 +730,68 @@ class SyncServer(Server):
|
||||
# Run the agent state forward
|
||||
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message)
|
||||
|
||||
def send_messages(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
messages: Union[List[MessageCreate], List[Message]],
|
||||
# whether or not to wrap user and system message as MemGPT-style stringified JSON
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
) -> LettaUsageStatistics:
|
||||
"""Send a list of messages to the agent
|
||||
|
||||
If the messages are of type MessageCreate, we need to turn them into
|
||||
Message objects first before sending them through step.
|
||||
|
||||
Otherwise, we can pass them in directly.
|
||||
"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
|
||||
message_objects: List[Message] = []
|
||||
|
||||
if all(isinstance(m, MessageCreate) for m in messages):
|
||||
for message in messages:
|
||||
assert isinstance(message, MessageCreate)
|
||||
|
||||
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
|
||||
if message.role == MessageRole.user and wrap_user_message:
|
||||
message.text = system.package_user_message(user_message=message.text)
|
||||
elif message.role == MessageRole.system and wrap_system_message:
|
||||
message.text = system.package_system_message(system_message=message.text)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
# Create the Message object
|
||||
message_objects.append(
|
||||
Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
text=message.text,
|
||||
name=message.name,
|
||||
# assigned later?
|
||||
model=None,
|
||||
# irrelevant
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
)
|
||||
)
|
||||
|
||||
elif all(isinstance(m, Message) for m in messages):
|
||||
for message in messages:
|
||||
assert isinstance(message, Message)
|
||||
message_objects.append(message)
|
||||
|
||||
else:
|
||||
raise ValueError(f"All messages must be of type Message or MessageCreate, got {type(messages)}")
|
||||
|
||||
# Run the agent state forward
|
||||
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects)
|
||||
|
||||
# @LockingServer.agent_lock_decorator
|
||||
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
|
||||
"""Run a command on the agent"""
|
||||
|
||||
Reference in New Issue
Block a user