feat: another iteration of chat web UI (#648)

* autogenerate openapi file on server startup

* added endpoint for paginated retrieval of in-context agent messages

* missing diff

* added ability to pass system messages via message endpoint

* replaced the incorrect `Depends()` usage with explicit `Query` parameters so that parameter info shows up for GET requests; fixed some copy-paste errors
This commit is contained in:
Charles Packer
2023-12-19 02:44:53 -08:00
committed by robingotz
parent 83973ecfc8
commit 93a897b43c
8 changed files with 161 additions and 25 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from memgpt.server.rest_api.interface import QueuingInterface
@@ -8,8 +8,8 @@ router = APIRouter()
class AgentConfigRequest(BaseModel):
user_id: str = Field(..., description="Unique identifier of the user issuing the command.")
agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.")
user_id: str = Field(..., description="Unique identifier of the user requesting the config.")
agent_id: str = Field(..., description="Identifier of the agent whose config is requested.")
class AgentConfigResponse(BaseModel):
@@ -18,12 +18,17 @@ class AgentConfigResponse(BaseModel):
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
def get_agent_config(request: AgentConfigRequest = Depends()):
def get_agent_config(
user_id: str = Query(..., description="Unique identifier of the user requesting the config."),
agent_id: str = Query(..., description="Identifier of the agent whose config is requested."),
):
"""
Retrieve the configuration for a specific agent.
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
request = AgentConfigRequest(user_id=user_id, agent_id=agent_id)
interface.clear()
config = server.get_agent_config(user_id=request.user_id, agent_id=request.agent_id)
return AgentConfigResponse(config=config)

View File

@@ -1,6 +1,6 @@
from typing import List
from fastapi import APIRouter, Depends, Body, HTTPException
from fastapi import APIRouter, Depends, Body, Query, HTTPException
from pydantic import BaseModel, Field
from memgpt.server.rest_api.interface import QueuingInterface
@@ -29,12 +29,14 @@ class CreateAgentResponse(BaseModel):
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
def list_agents(request: ListAgentsRequest = Depends()):
def list_agents(user_id: str = Query(..., description="Unique identifier of the user.")):
"""
List all agents associated with a given user.
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
request = ListAgentsRequest(user_id=user_id)
interface.clear()
agents_data = server.list_agents(user_id=request.user_id)
return ListAgentsResponse(**agents_data)

View File

@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import APIRouter, Depends, Body
from fastapi import APIRouter, Depends, Body, Query
from pydantic import BaseModel, Field
from memgpt.server.rest_api.interface import QueuingInterface
@@ -29,8 +29,8 @@ class GetAgentMemoryResponse(BaseModel):
class UpdateAgentMemoryRequest(BaseModel):
user_id: str = Field(..., description="The unique identifier of the user.")
agent_id: str = Field(..., description="The unique identifier of the agent.")
human: Optional[str] = Field(None, description="Human element of the core memory.")
persona: Optional[str] = Field(None, description="Persona element of the core memory.")
human: str = Field(None, description="Human element of the core memory.")
persona: str = Field(None, description="Persona element of the core memory.")
class UpdateAgentMemoryResponse(BaseModel):
@@ -40,12 +40,18 @@ class UpdateAgentMemoryResponse(BaseModel):
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
def get_agent_memory(request: GetAgentMemoryRequest = Depends()):
def get_agent_memory(
user_id: str = Query(..., description="The unique identifier of the user."),
agent_id: str = Query(..., description="The unique identifier of the agent."),
):
"""
Retrieve the memory state of a specific agent.
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
# Validate with the Pydantic model (optional)
request = GetAgentMemoryRequest(user_id=user_id, agent_id=agent_id)
interface.clear()
memory = server.get_agent_memory(user_id=request.user_id, agent_id=request.agent_id)
return GetAgentMemoryResponse(**memory)

View File

@@ -1,9 +1,10 @@
import asyncio
from enum import Enum
import json
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Depends, Body, HTTPException
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, Body, HTTPException, Query
from pydantic import BaseModel, Field, constr, validator
from starlette.responses import StreamingResponse
from memgpt.server.rest_api.interface import QueuingInterface
@@ -12,26 +13,67 @@ from memgpt.server.server import SyncServer
router = APIRouter()
# Roles accepted for a message sent to an agent. The `str` mixin makes each
# member compare equal to its plain string value (e.g. `MessageRoleType.user == "user"`).
class MessageRoleType(str, Enum):
    user = "user"
    system = "system"
# Request body of the POST /agents/message endpoint.
# `role` selects whether the message is handled as a user or a system message
# and defaults to `MessageRoleType.user` when omitted.
class UserMessageRequest(BaseModel):
    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    message: str = Field(..., description="The message content to be processed by the agent.")
    stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
class UserMessageResponse(BaseModel):
messages: List[str] = Field(..., description="List of messages generated by the agent.")
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
# Parameters of the paginated in-context message query (GET /agents/message).
# All four fields are required; `start` and `count` together select one page.
class GetAgentMessagesRequest(BaseModel):
    user_id: str = Field(..., description="The unique identifier of the user.")
    agent_id: str = Field(..., description="The unique identifier of the agent.")
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")
# Response wrapper for the paginated message query.
# The element type is left unspecified (plain `list`), so items are not validated individually.
class GetAgentMessagesResponse(BaseModel):
    messages: list = Field(..., description="List of message objects.")
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages(
user_id: str = Query(..., description="The unique identifier of the user."),
agent_id: str = Query(..., description="The unique identifier of the agent."),
start: int = Query(..., description="Message index to start on (reverse chronological)."),
count: int = Query(..., description="How many messages to retrieve."),
):
"""
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
"""
# Validate with the Pydantic model (optional)
request = GetAgentMessagesRequest(user_id=user_id, agent_id=agent_id, start=start, count=count)
interface.clear()
messages = server.get_agent_messages(user_id=request.user_id, agent_id=request.agent_id, start=request.start, count=request.count)
return GetAgentMessagesResponse(messages=messages)
@router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
async def user_message(request: UserMessageRequest = Body(...)):
async def send_message(request: UserMessageRequest = Body(...)):
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream' is set to True.
"""
if request.role == "user" or request.role is None:
message_func = server.user_message
elif request.role == "system":
message_func = server.system_message
else:
raise HTTPException(status_code=500, detail=f"Bad role {request.role}")
if request.stream:
# For streaming response
try:
@@ -39,15 +81,13 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface)
# This should be a non-blocking call or run in a background task
# Check if server.user_message is an async function
if asyncio.iscoroutinefunction(server.user_message):
if asyncio.iscoroutinefunction(message_func):
# Start the async task
await asyncio.create_task(
server.user_message(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
)
await asyncio.create_task(message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message))
else:
# Run the synchronous function in a thread pool
loop = asyncio.get_event_loop()
loop.run_in_executor(None, server.user_message, request.user_id, request.agent_id, request.message)
loop.run_in_executor(None, message_func, request.user_id, request.agent_id, request.message)
async def formatted_message_generator():
async for message in interface.message_generator():
@@ -65,7 +105,7 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface)
else:
interface.clear()
try:
server.user_message(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
except HTTPException:
raise
except Exception as e:

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, Query, HTTPException
from pydantic import BaseModel, Field
@@ -9,7 +9,7 @@ router = APIRouter()
class ConfigRequest(BaseModel):
user_id: str = Field(..., description="Unique identifier of the user issuing the command.")
user_id: str = Field(..., description="Unique identifier of the user requesting the config.")
class ConfigResponse(BaseModel):
@@ -18,12 +18,14 @@ class ConfigResponse(BaseModel):
def setup_config_index_router(server: SyncServer, interface: QueuingInterface):
@router.get("/config", tags=["config"], response_model=ConfigResponse)
def get_server_config(user_id: ConfigRequest = Depends()):
def get_server_config(user_id: str = Query(..., description="Unique identifier of the user requesting the config.")):
"""
Retrieve the base configuration for the server.
"""
request = ConfigRequest(user_id=user_id)
interface.clear()
response = server.get_server_config(user_id=user_id)
response = server.get_server_config(user_id=request.user_id)
return ConfigResponse(config=response)
return router

View File

@@ -1,3 +1,4 @@
import json
from contextlib import asynccontextmanager
from fastapi import FastAPI
@@ -58,6 +59,22 @@ app.include_router(setup_config_index_router(server, interface), prefix=API_PREF
mount_static_files(app)
@app.on_event("startup")
def on_startup():
    """Generate the OpenAPI schema and write it to ``openapi.json`` at startup.

    The schema is built via ``app.openapi()`` only when ``app.openapi_schema``
    is still empty, then patched with a fixed server URL and API title before
    being dumped to ``openapi.json`` (a relative path, so the file lands in the
    process's current working directory).
    """
    # Update the OpenAPI schema
    if not app.openapi_schema:
        app.openapi_schema = app.openapi()

    if app.openapi_schema:
        app.openapi_schema["servers"] = [{"url": "http://localhost:8283"}]
        app.openapi_schema["info"]["title"] = "MemGPT API"

    # Write out the OpenAPI schema to a file; the encoding is pinned so the
    # output does not depend on the platform's default locale encoding.
    with open("openapi.json", "w", encoding="utf-8") as file:
        print("Writing out openapi.json file")
        json.dump(app.openapi_schema, file, indent=2)
@app.on_event("shutdown")
def on_shutdown():
global server

View File

@@ -36,6 +36,11 @@ class Server(object):
"""List all available agents to a user"""
raise NotImplementedError
@abstractmethod
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
    """Paginated query of in-context messages in agent message queue

    Abstract hook: this base version only raises ``NotImplementedError`` and
    must be overridden by concrete servers.

    Args:
        user_id: Identifier of the user.
        agent_id: Identifier of the agent whose messages are queried.
        start: Index of the first message of the page.
        count: Number of messages requested for the page.

    Returns:
        A list of messages for the requested page.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@abstractmethod
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
@@ -72,6 +77,11 @@ class Server(object):
"""Process a message from the user, internally calls step"""
raise NotImplementedError
@abstractmethod
def system_message(self, user_id: str, agent_id: str, message: str) -> None:
    """Process a message from the system, internally calls step

    Abstract hook: this base version only raises ``NotImplementedError`` and
    must be overridden by concrete servers.

    Args:
        user_id: Identifier of the user.
        agent_id: Identifier of the agent that receives the message.
        message: The system message text.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@abstractmethod
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
"""Run a command on the agent, e.g. /memory
@@ -406,6 +416,26 @@ class SyncServer(LockingServer):
# Run the agent state forward
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
@LockingServer.agent_lock_decorator
def system_message(self, user_id: str, agent_id: str, message: str) -> None:
    """Process an incoming system message and feed it through the MemGPT agent

    Args:
        user_id: Identifier of the user; forwarded to ``self._step``.
        agent_id: Identifier of the target agent; forwarded to ``self._step``.
        message: The system message text. Must be a non-empty string that does
            not start with the command prefix ``/``.

    Raises:
        ValueError: If ``message`` is not a string, is empty, or begins with ``/``.
    """
    # Basic input sanitization
    if not isinstance(message, str) or not message:
        raise ValueError(f"Invalid input: '{message}'")

    # Commands (inputs starting with the "/" prefix) are not accepted here
    if message.startswith("/"):
        raise ValueError(f"Invalid input: '{message}'")

    # Package the system message, then run the agent state forward with it
    packaged_system_message = package_system_message(system_message=message)
    self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message)
@LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
"""Run a command on the agent"""
@@ -505,6 +535,29 @@ class SyncServer(LockingServer):
return memory_obj
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
    """Return one page of an agent's in-context messages, in reverse order of storage.

    Args:
        user_id: Identifier of the user; passed to ``self._get_or_load_agent``.
        agent_id: Identifier of the agent; passed to ``self._get_or_load_agent``.
        start: Zero-based offset into the reversed message list.
        count: Maximum number of messages to return.

    Returns:
        The slice ``[start, start + count)`` of the reversed message list,
        truncated at the end of the list.

    Raises:
        ValueError: If ``start`` or ``count`` is negative.
        IndexError: If ``start`` is at or beyond the number of messages.
    """
    # The agent is fetched before any argument validation takes place
    agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

    if min(start, count) < 0:
        raise ValueError("Start and count values should be non-negative")

    # Flip the order so that index 0 is the last stored message
    flipped = agent.messages[::-1]
    total = len(flipped)

    if not start < total:
        raise IndexError("Start index is out of range")

    # Clamp the page end to the list length, then cut out the page
    stop = start + count
    if stop > total:
        stop = total
    return flipped[start:stop]
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
"""Return the config of an agent"""
# Get the agent object (loaded in memory)

View File

@@ -108,6 +108,17 @@ def package_function_response(was_success, response_string, timestamp=None):
return json.dumps(packaged_message)
def package_system_message(system_message, message_type="system_alert", time=None):
    """Serialize a system message into a JSON string.

    Args:
        system_message: Content stored under the ``"message"`` key.
        message_type: Value stored under the ``"type"`` key.
        time: Timestamp stored under the ``"time"`` key. When omitted (or
            otherwise falsy), the result of ``get_local_time()`` is used.

    Returns:
        A JSON string with the keys ``"type"``, ``"message"`` and ``"time"``.
    """
    return json.dumps(
        {
            "type": message_type,
            "message": system_message,
            # Only call get_local_time() when no usable timestamp was supplied
            "time": time or get_local_time(),
        }
    )
def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
context_message = (
f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"