Files
letta-server/memgpt/server/rest_api/agents/message.py
Charles Packer 93a897b43c feat: another iteration of chat web UI (#648)
* autogenerate openapi file on server startup

* added endpoint for paginated retrieval of in-context agent messages

* missing diff

* added ability to pass system messages via message endpoint

* replaced the incorrect `Depends` parameters with `Query` parameters so the parameter info shows up for GET requests; fixed some bad copy-paste
2024-01-11 14:49:44 +01:00

116 lines
5.2 KiB
Python

import asyncio
from enum import Enum
import json
from typing import List, Optional
from fastapi import APIRouter, Depends, Body, HTTPException, Query
from pydantic import BaseModel, Field, constr, validator
from starlette.responses import StreamingResponse
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
# Module-level router; setup_agents_message_router() attaches the /agents/message
# routes to it and returns it.
# NOTE(review): every call to setup_agents_message_router() registers its routes on
# this same object again, so calling it more than once would add duplicate routes.
router = APIRouter()
class MessageRoleType(str, Enum):
    """Role attached to a message sent through POST /agents/message.

    The role selects which server method processes the message: ``user`` goes to
    ``server.user_message`` and ``system`` goes to ``server.system_message``.
    """

    user = "user"
    system = "system"
class UserMessageRequest(BaseModel):
    """Request body for POST /agents/message (send a message to an agent)."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    message: str = Field(..., description="The message content to be processed by the agent.")
    # When True the endpoint answers with a text/event-stream response instead of a
    # single JSON UserMessageResponse.
    stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
    # Values outside MessageRoleType are rejected by validation before the handler runs.
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
class UserMessageResponse(BaseModel):
    """Response body for a non-streaming POST /agents/message."""

    # Whatever the shared interface buffered while the agent processed the message
    # (built from interface.to_list() by the handler).
    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
class GetAgentMessagesRequest(BaseModel):
    """Validation model for the query parameters of GET /agents/message."""

    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    # NOTE(review): start/count carry no bounds here, so negative values are passed
    # straight through to server.get_agent_messages -- confirm that is intended.
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")
class GetAgentMessagesResponse(BaseModel):
    """Response body for GET /agents/message."""

    # Passed through unchanged from server.get_agent_messages(); the element type is
    # whatever that method returns (not constrained here).
    messages: list = Field(..., description="List of message objects.")
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface):
    """Attach the agent message endpoints to the module-level ``router``.

    Routes registered:

    * ``GET /agents/message`` -- paginated retrieval of an agent's in-context messages.
    * ``POST /agents/message`` -- send a user or system message to an agent and return
      its output, either as one JSON body or as a server-sent event stream.

    Args:
        server: Server that owns the agents; provides ``get_agent_messages``,
            ``user_message`` and ``system_message``.
        interface: Shared interface that collects the agent's output while a message
            is processed. It is drained with ``to_list()`` (non-streaming) or
            ``message_generator()`` (streaming).

    Returns:
        The module-level ``router`` with both routes attached.
    """
    # Strong references to in-flight background tasks. The event loop only keeps weak
    # references to tasks, so an unreferenced task may be garbage-collected mid-run.
    background_tasks = set()

    @router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages(
        user_id: str = Query(..., description="The unique identifier of the user."),
        agent_id: str = Query(..., description="The unique identifier of the agent."),
        start: int = Query(..., description="Message index to start on (reverse chronological)."),
        count: int = Query(..., description="How many messages to retrieve."),
    ):
        """
        Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
        """
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesRequest(user_id=user_id, agent_id=agent_id, start=start, count=count)
        # Read-only endpoint: it must not clear the shared `interface`, which could
        # discard output an in-flight POST /agents/message stream is still consuming.
        messages = server.get_agent_messages(user_id=request.user_id, agent_id=request.agent_id, start=request.start, count=request.count)
        return GetAgentMessagesResponse(messages=messages)

    @router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
    async def send_message(request: UserMessageRequest = Body(...)):
        """
        Process a user message and return the agent's response.
        This endpoint accepts a message from a user and processes it through the agent.
        It can optionally stream the response if 'stream' is set to True.
        """
        if request.role == "user" or request.role is None:
            message_func = server.user_message
        elif request.role == "system":
            message_func = server.system_message
        else:
            # Normally unreachable (validation restricts role to MessageRoleType), but a
            # bad role is a client error, not a server failure.
            raise HTTPException(status_code=400, detail=f"Bad role {request.role}")

        is_async = asyncio.iscoroutinefunction(message_func)

        if request.stream:
            try:
                # Drop leftovers from earlier requests so the stream only carries this
                # request's output (the non-streaming path does the same).
                interface.clear()

                # Start generation in the background WITHOUT awaiting it: the agent
                # pushes output into `interface` while the response below streams it.
                if is_async:
                    task = asyncio.create_task(message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message))
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                else:
                    # Run the synchronous function in the default thread pool.
                    loop = asyncio.get_running_loop()
                    loop.run_in_executor(None, message_func, request.user_id, request.agent_id, request.message)

                async def formatted_message_generator():
                    # Server-sent events framing: a `data:` line followed by a blank line.
                    async for message in interface.message_generator():
                        yield f"data: {json.dumps(message)}\n\n"
                        # NOTE(review): fixed 1 s pause after every event -- presumably
                        # throttling for the UI; it adds latency per message, confirm.
                        await asyncio.sleep(1)

                return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
            except HTTPException:
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"{e}") from e

        interface.clear()
        try:
            if is_async:
                await message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
            else:
                # NOTE(review): a synchronous call inside an async endpoint blocks the
                # event loop until the agent finishes. That presumably also keeps
                # concurrent requests from interleaving on the shared `interface`;
                # confirm before offloading this to a thread.
                message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        return UserMessageResponse(messages=interface.to_list())

    return router