feat: another iteration of chat web UI (#648)
* autogenerate openapi file on server startup * added endpoint for paginated retrieval of in-context agent messages * missing diff * added ability to pass system messages via message endpoint * patched bad depends into queries to fix the param info not showing up in get requests, fixed some bad copy paste
This commit is contained in:
committed by
robingotz
parent
83973ecfc8
commit
93a897b43c
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
@@ -8,8 +8,8 @@ router = APIRouter()
|
||||
|
||||
|
||||
class AgentConfigRequest(BaseModel):
|
||||
user_id: str = Field(..., description="Unique identifier of the user issuing the command.")
|
||||
agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.")
|
||||
user_id: str = Field(..., description="Unique identifier of the user requesting the config.")
|
||||
agent_id: str = Field(..., description="Identifier of the agent whose config is requested.")
|
||||
|
||||
|
||||
class AgentConfigResponse(BaseModel):
|
||||
@@ -18,12 +18,17 @@ class AgentConfigResponse(BaseModel):
|
||||
|
||||
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
|
||||
def get_agent_config(request: AgentConfigRequest = Depends()):
|
||||
def get_agent_config(
|
||||
user_id: str = Query(..., description="Unique identifier of the user requesting the config."),
|
||||
agent_id: str = Query(..., description="Identifier of the agent whose config is requested."),
|
||||
):
|
||||
"""
|
||||
Retrieve the configuration for a specific agent.
|
||||
|
||||
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
||||
"""
|
||||
request = AgentConfigRequest(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
interface.clear()
|
||||
config = server.get_agent_config(user_id=request.user_id, agent_id=request.agent_id)
|
||||
return AgentConfigResponse(config=config)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||
from fastapi import APIRouter, Depends, Body, Query, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
@@ -29,12 +29,14 @@ class CreateAgentResponse(BaseModel):
|
||||
|
||||
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
|
||||
def list_agents(request: ListAgentsRequest = Depends()):
|
||||
def list_agents(user_id: str = Query(..., description="Unique identifier of the user.")):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
request = ListAgentsRequest(user_id=user_id)
|
||||
|
||||
interface.clear()
|
||||
agents_data = server.list_agents(user_id=request.user_id)
|
||||
return ListAgentsResponse(**agents_data)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from fastapi import APIRouter, Depends, Body, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
@@ -29,8 +29,8 @@ class GetAgentMemoryResponse(BaseModel):
|
||||
class UpdateAgentMemoryRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
human: Optional[str] = Field(None, description="Human element of the core memory.")
|
||||
persona: Optional[str] = Field(None, description="Persona element of the core memory.")
|
||||
human: str = Field(None, description="Human element of the core memory.")
|
||||
persona: str = Field(None, description="Persona element of the core memory.")
|
||||
|
||||
|
||||
class UpdateAgentMemoryResponse(BaseModel):
|
||||
@@ -40,12 +40,18 @@ class UpdateAgentMemoryResponse(BaseModel):
|
||||
|
||||
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/agents/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
|
||||
def get_agent_memory(request: GetAgentMemoryRequest = Depends()):
|
||||
def get_agent_memory(
|
||||
user_id: str = Query(..., description="The unique identifier of the user."),
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent."),
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
|
||||
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
# Validate with the Pydantic model (optional)
|
||||
request = GetAgentMemoryRequest(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
interface.clear()
|
||||
memory = server.get_agent_memory(user_id=request.user_id, agent_id=request.agent_id)
|
||||
return GetAgentMemoryResponse(**memory)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
import json
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException, Query
|
||||
from pydantic import BaseModel, Field, constr, validator
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
@@ -12,26 +13,67 @@ from memgpt.server.server import SyncServer
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MessageRoleType(str, Enum):
|
||||
user = "user"
|
||||
system = "system"
|
||||
|
||||
|
||||
class UserMessageRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
message: str = Field(..., description="The message content to be processed by the agent.")
|
||||
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
|
||||
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
|
||||
|
||||
|
||||
class UserMessageResponse(BaseModel):
|
||||
messages: List[str] = Field(..., description="List of messages generated by the agent.")
|
||||
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
|
||||
|
||||
|
||||
class GetAgentMessagesRequest(BaseModel):
|
||||
user_id: str = Field(..., description="The unique identifier of the user.")
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
start: int = Field(..., description="Message index to start on (reverse chronological).")
|
||||
count: int = Field(..., description="How many messages to retrieve.")
|
||||
|
||||
|
||||
class GetAgentMessagesResponse(BaseModel):
|
||||
messages: list = Field(..., description="List of message objects.")
|
||||
|
||||
|
||||
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
|
||||
def get_agent_messages(
|
||||
user_id: str = Query(..., description="The unique identifier of the user."),
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent."),
|
||||
start: int = Query(..., description="Message index to start on (reverse chronological)."),
|
||||
count: int = Query(..., description="How many messages to retrieve."),
|
||||
):
|
||||
"""
|
||||
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
|
||||
"""
|
||||
# Validate with the Pydantic model (optional)
|
||||
request = GetAgentMessagesRequest(user_id=user_id, agent_id=agent_id, start=start, count=count)
|
||||
|
||||
interface.clear()
|
||||
messages = server.get_agent_messages(user_id=request.user_id, agent_id=request.agent_id, start=request.start, count=request.count)
|
||||
return GetAgentMessagesResponse(messages=messages)
|
||||
|
||||
@router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
|
||||
async def user_message(request: UserMessageRequest = Body(...)):
|
||||
async def send_message(request: UserMessageRequest = Body(...)):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It can optionally stream the response if 'stream' is set to True.
|
||||
"""
|
||||
if request.role == "user" or request.role is None:
|
||||
message_func = server.user_message
|
||||
elif request.role == "system":
|
||||
message_func = server.system_message
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Bad role {request.role}")
|
||||
|
||||
if request.stream:
|
||||
# For streaming response
|
||||
try:
|
||||
@@ -39,15 +81,13 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface)
|
||||
# This should be a non-blocking call or run in a background task
|
||||
|
||||
# Check if server.user_message is an async function
|
||||
if asyncio.iscoroutinefunction(server.user_message):
|
||||
if asyncio.iscoroutinefunction(message_func):
|
||||
# Start the async task
|
||||
await asyncio.create_task(
|
||||
server.user_message(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
|
||||
)
|
||||
await asyncio.create_task(message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message))
|
||||
else:
|
||||
# Run the synchronous function in a thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_in_executor(None, server.user_message, request.user_id, request.agent_id, request.message)
|
||||
loop.run_in_executor(None, message_func, request.user_id, request.agent_id, request.message)
|
||||
|
||||
async def formatted_message_generator():
|
||||
async for message in interface.message_generator():
|
||||
@@ -65,7 +105,7 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface)
|
||||
else:
|
||||
interface.clear()
|
||||
try:
|
||||
server.user_message(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
|
||||
message_func(user_id=request.user_id, agent_id=request.agent_id, message=request.message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -9,7 +9,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
class ConfigRequest(BaseModel):
|
||||
user_id: str = Field(..., description="Unique identifier of the user issuing the command.")
|
||||
user_id: str = Field(..., description="Unique identifier of the user requesting the config.")
|
||||
|
||||
|
||||
class ConfigResponse(BaseModel):
|
||||
@@ -18,12 +18,14 @@ class ConfigResponse(BaseModel):
|
||||
|
||||
def setup_config_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/config", tags=["config"], response_model=ConfigResponse)
|
||||
def get_server_config(user_id: ConfigRequest = Depends()):
|
||||
def get_server_config(user_id: str = Query(..., description="Unique identifier of the user requesting the config.")):
|
||||
"""
|
||||
Retrieve the base configuration for the server.
|
||||
"""
|
||||
request = ConfigRequest(user_id=user_id)
|
||||
|
||||
interface.clear()
|
||||
response = server.get_server_config(user_id=user_id)
|
||||
response = server.get_server_config(user_id=request.user_id)
|
||||
return ConfigResponse(config=response)
|
||||
|
||||
return router
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
@@ -58,6 +59,22 @@ app.include_router(setup_config_index_router(server, interface), prefix=API_PREF
|
||||
mount_static_files(app)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
# Update the OpenAPI schema
|
||||
if not app.openapi_schema:
|
||||
app.openapi_schema = app.openapi()
|
||||
|
||||
if app.openapi_schema:
|
||||
app.openapi_schema["servers"] = [{"url": "http://localhost:8283"}]
|
||||
app.openapi_schema["info"]["title"] = "MemGPT API"
|
||||
|
||||
# Write out the OpenAPI schema to a file
|
||||
with open("openapi.json", "w") as file:
|
||||
print(f"Writing out openapi.json file")
|
||||
json.dump(app.openapi_schema, file, indent=2)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def on_shutdown():
|
||||
global server
|
||||
|
||||
@@ -36,6 +36,11 @@ class Server(object):
|
||||
"""List all available agents to a user"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
|
||||
"""Paginated query of in-context messages in agent message queue"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
|
||||
"""Return the memory of an agent (core memory + non-core statistics)"""
|
||||
@@ -72,6 +77,11 @@ class Server(object):
|
||||
"""Process a message from the user, internally calls step"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def system_message(self, user_id: str, agent_id: str, message: str) -> None:
|
||||
"""Process a message from the system, internally calls step"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
||||
"""Run a command on the agent, e.g. /memory
|
||||
@@ -406,6 +416,26 @@ class SyncServer(LockingServer):
|
||||
# Run the agent state forward
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
|
||||
|
||||
@LockingServer.agent_lock_decorator
|
||||
def system_message(self, user_id: str, agent_id: str, message: str) -> None:
|
||||
"""Process an incoming system message and feed it through the MemGPT agent"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
# Basic input sanitization
|
||||
if not isinstance(message, str) or len(message) == 0:
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# If the input begins with a command prefix, reject
|
||||
elif message.startswith("/"):
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# Else, process it as a user message to be fed to the agent
|
||||
else:
|
||||
# Package the user message first
|
||||
packaged_system_message = package_system_message(system_message=message)
|
||||
# Run the agent state forward
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message)
|
||||
|
||||
@LockingServer.agent_lock_decorator
|
||||
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
||||
"""Run a command on the agent"""
|
||||
@@ -505,6 +535,29 @@ class SyncServer(LockingServer):
|
||||
|
||||
return memory_obj
|
||||
|
||||
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
|
||||
"""Paginated query of in-context messages in agent message queue"""
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
if start < 0 or count < 0:
|
||||
raise ValueError("Start and count values should be non-negative")
|
||||
|
||||
# Reverse the list to make it in reverse chronological order
|
||||
reversed_messages = memgpt_agent.messages[::-1]
|
||||
|
||||
# Check if start is within the range of the list
|
||||
if start >= len(reversed_messages):
|
||||
raise IndexError("Start index is out of range")
|
||||
|
||||
# Calculate the end index, ensuring it does not exceed the list length
|
||||
end_index = min(start + count, len(reversed_messages))
|
||||
|
||||
# Slice the list for pagination
|
||||
paginated_messages = reversed_messages[start:end_index]
|
||||
|
||||
return paginated_messages
|
||||
|
||||
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
|
||||
"""Return the config of an agent"""
|
||||
# Get the agent object (loaded in memory)
|
||||
|
||||
@@ -108,6 +108,17 @@ def package_function_response(was_success, response_string, timestamp=None):
|
||||
return json.dumps(packaged_message)
|
||||
|
||||
|
||||
def package_system_message(system_message, message_type="system_alert", time=None):
|
||||
formatted_time = time if time else get_local_time()
|
||||
packaged_message = {
|
||||
"type": message_type,
|
||||
"message": system_message,
|
||||
"time": formatted_time,
|
||||
}
|
||||
|
||||
return json.dumps(packaged_message)
|
||||
|
||||
|
||||
def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
|
||||
context_message = (
|
||||
f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"
|
||||
|
||||
Reference in New Issue
Block a user