added memgpt server command (#611)

* added memgpt server command

* added the option to specify a port (rest default 8283, ws default 8282)

* fixed import in test

* added agent saving on shutdown

* added basic locking mechanism (assumes only one server.py is running at the same time)

* remove 'STOP' from buffer when converting to list for the non-streaming POST resposne

* removed duplicate on_event (redundant to lifespan)

* added GET agents/memory route

* added GET agent config

* added GET server config

* added PUT route for modifying agent core memory

* refactored to put server loop in separate function called via main
This commit is contained in:
Charles Packer
2023-12-13 00:41:40 -08:00
committed by GitHub
parent b3f1f50a6c
commit 2048ba179b
10 changed files with 372 additions and 23 deletions

View File

@@ -4,6 +4,10 @@ import sys
import io
import logging
import questionary
from pathlib import Path
import os
import subprocess
from enum import Enum
from llama_index import set_global_service_context
from llama_index import ServiceContext
@@ -18,6 +22,76 @@ from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX
from memgpt.agent import Agent
from memgpt.embeddings import embedding_model
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
class ServerChoice(Enum):
rest_api = "rest"
ws_api = "websocket"
def server(
type: ServerChoice = typer.Option("rest", help="Server to run"), port: int = typer.Option(None, help="Port to run the server on")
):
"""Launch a MemGPT server process"""
if type == ServerChoice.rest_api:
if port is None:
port = REST_DEFAULT_PORT
# Change to the desired directory
script_path = Path(__file__).resolve()
script_dir = script_path.parent
server_directory = os.path.join(script_dir.parent, "server", "rest_api")
command = f"uvicorn server:app --reload --port {port}"
# Run the command
print(f"Running REST server: {command} (inside {server_directory})")
try:
# Start the subprocess in a new session
process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory)
process.wait()
except KeyboardInterrupt:
# Handle CTRL-C
print("Terminating the server...")
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
print("Server terminated with kill()")
sys.exit(0)
elif type == ServerChoice.ws_api:
if port is None:
port = WS_DEFAULT_PORT
# Change to the desired directory
script_path = Path(__file__).resolve()
script_dir = script_path.parent
server_directory = os.path.join(script_dir.parent, "server", "ws_api")
command = f"python server.py {port}"
# Run the command
print(f"Running WS (websockets) server: {command} (inside {server_directory})")
try:
# Start the subprocess in a new session
process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory)
process.wait()
except KeyboardInterrupt:
# Handle CTRL-C
print("Terminating the server...")
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
print("Server terminated with kill()")
sys.exit(0)
def run(

View File

@@ -21,7 +21,7 @@ from memgpt.interface import CLIInterface as interface # for printing to termin
import memgpt.agent as agent
import memgpt.system as system
import memgpt.constants as constants
from memgpt.cli.cli import run, attach, version
from memgpt.cli.cli import run, attach, version, server
from memgpt.cli.cli_config import configure, list, add
from memgpt.cli.cli_load import app as load_app
from memgpt.connectors.storage import StorageConnector
@@ -33,6 +33,7 @@ app.command(name="attach")(attach)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)
app.command(name="server")(server)
# load data commands
app.add_typer(load_app, name="load")

View File

@@ -1,3 +1,6 @@
DEFAULT_PORT = 8282
# WebSockets
WS_DEFAULT_PORT = 8282
WS_CLIENT_TIMEOUT = 30
CLIENT_TIMEOUT = 30
# REST
REST_DEFAULT_PORT = 8283

View File

@@ -18,6 +18,8 @@ class QueuingInterface(AgentInterface):
items.append(self.buffer.get_nowait())
except queue.Empty:
break
if len(items) > 1 and items[-1] == "STOP":
items.pop()
return items
def clear(self):

View File

@@ -1,4 +1,5 @@
import asyncio
from contextlib import asynccontextmanager
import json
from typing import Union
@@ -38,17 +39,64 @@ class Command(BaseModel):
command: str
app = FastAPI()
interface = QueuingInterface()
server = SyncServer(default_interface=interface)
class CoreMemory(BaseModel):
user_id: str
agent_id: str
human: str | None = None
persona: str | None = None
server = None
interface = None
@asynccontextmanager
async def lifespan(application: FastAPI):
global server
global interface
interface = QueuingInterface()
server = SyncServer(default_interface=interface)
yield
server.save_agents()
server = None
app = FastAPI(lifespan=lifespan)
# app = FastAPI()
# server = SyncServer(default_interface=interface)
# server.list_agents
@app.get("/agents")
def list_agents(user_id: str):
interface.clear()
agents_list = utils.list_agent_config_files()
return {"num_agents": len(agents_list), "agent_names": agents_list}
return server.list_agents(user_id=user_id)
@app.get("/agents/memory")
def get_agent_memory(user_id: str, agent_id: str):
interface.clear()
return server.get_agent_memory(user_id=user_id, agent_id=agent_id)
@app.put("/agents/memory")
def get_agent_memory(body: CoreMemory):
interface.clear()
new_memory_contents = {"persona": body.persona, "human": body.human}
return server.update_agent_core_memory(user_id=body.user_id, agent_id=body.agent_id, new_memory_contents=new_memory_contents)
@app.get("/agents/config")
def get_agent_config(user_id: str, agent_id: str):
interface.clear()
return server.get_agent_config(user_id=user_id, agent_id=agent_id)
@app.get("/config")
def get_server_config(user_id: str):
interface.clear()
return server.get_server_config(user_id=user_id)
# server.create_agent
@@ -88,6 +136,8 @@ async def user_message(body: UserMessage):
# Return the streaming response using the generator
return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@@ -95,6 +145,8 @@ async def user_message(body: UserMessage):
interface.clear()
try:
server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return {"messages": interface.to_list()}
@@ -106,6 +158,8 @@ def run_command(body: Command):
interface.clear()
try:
response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)

View File

@@ -1,9 +1,12 @@
from abc import abstractmethod
from typing import Union
from typing import Union, Callable
import json
from threading import Lock
from functools import wraps
from fastapi import HTTPException
from memgpt.system import package_user_message
from memgpt.config import AgentConfig
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.agent import Agent
import memgpt.system as system
import memgpt.constants as constants
@@ -11,6 +14,7 @@ from memgpt.cli.cli import attach
from memgpt.connectors.storage import StorageConnector
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
# TODO use custom interface
@@ -22,10 +26,30 @@ class Server(object):
"""Abstract server class that supports multi-agent multi-user"""
@abstractmethod
def list_agents(self, user_id: str, agent_id: str) -> str:
def list_agents(self, user_id: str) -> dict:
"""List all available agents to a user"""
raise NotImplementedError
@abstractmethod
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
raise NotImplementedError
@abstractmethod
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
"""Return the config of an agent"""
raise NotImplementedError
@abstractmethod
def get_server_config(self, user_id: str) -> dict:
"""Return the base config"""
raise NotImplementedError
@abstractmethod
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
"""Update the agents core memory block, return the new state"""
raise NotImplementedError
@abstractmethod
def create_agent(
self,
@@ -51,8 +75,50 @@ class Server(object):
raise NotImplementedError
class LockingServer(Server):
"""Basic support for concurrency protections (all requests that modify an agent lock the agent until the operation is complete)"""
# Locks for each agent
_agent_locks = {}
@staticmethod
def agent_lock_decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(self, user_id: str, agent_id: str, *args, **kwargs):
# print("Locking check")
# Initialize the lock for the agent_id if it doesn't exist
if agent_id not in self._agent_locks:
# print(f"Creating lock for agent_id = {agent_id}")
self._agent_locks[agent_id] = Lock()
# Check if the agent is currently locked
if not self._agent_locks[agent_id].acquire(blocking=False):
# print(f"agent_id = {agent_id} is busy")
raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.")
try:
# Execute the function
# print(f"running function on agent_id = {agent_id}")
return func(self, user_id, agent_id, *args, **kwargs)
finally:
# Release the lock
# print(f"releasing lock on agent_id = {agent_id}")
self._agent_locks[agent_id].release()
return wrapper
@agent_lock_decorator
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
raise NotImplementedError
@agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
raise NotImplementedError
# TODO actually use "user_id" for something
class SyncServer(Server):
class SyncServer(LockingServer):
"""Simple single-threaded / blocking server process"""
def __init__(
@@ -82,6 +148,14 @@ class SyncServer(Server):
# The default persistence manager that will get assigned to agents ON CREATION
self.default_persistence_manager_cls = default_persistence_manager_cls
def save_agents(self):
for agent_d in self.active_agents:
try:
agent_d["agent"].save()
print(f"Saved agent {agent_d['agent_id']}")
except Exception as e:
print(f"Error occured while trying to save agent {agent_d['agent_id']}:\n{e}")
def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
"""Get the agent object from the in-memory object store"""
for d in self.active_agents:
@@ -302,6 +376,7 @@ class SyncServer(Server):
input_message = system.get_token_limit_warning()
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
@LockingServer.agent_lock_decorator
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
"""Process an incoming user message and feed it through the MemGPT agent"""
from memgpt.utils import printd
@@ -321,6 +396,7 @@ class SyncServer(Server):
# Run the agent state forward
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
@LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
"""Run a command on the agent"""
# If the input begins with a command prefix, attempt to process it as a command
@@ -363,3 +439,83 @@ class SyncServer(Server):
print(f"Created new agent from config: {agent}")
return agent.config.name
def list_agents(self, user_id: str) -> dict:
"""List all available agents to a user"""
agents_list = utils.list_agent_config_files()
return {"num_agents": len(agents_list), "agent_names": agents_list}
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
core_memory = memgpt_agent.memory
recall_memory = memgpt_agent.persistence_manager.recall_memory
archival_memory = memgpt_agent.persistence_manager.archival_memory
memory_obj = {
"core_memory": {
"persona": core_memory.persona,
"human": core_memory.human,
},
"recall_memory": len(recall_memory) if recall_memory is not None else None,
"archival_memory": len(archival_memory) if archival_memory is not None else None,
}
return memory_obj
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
"""Return the config of an agent"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
agent_config = vars(memgpt_agent.config)
return agent_config
def get_server_config(self, user_id: str) -> dict:
"""Return the base config"""
base_config = vars(MemGPTConfig.load())
def clean_keys(config):
config_copy = config.copy()
for k, v in config.items():
if k == "key" or "_key" in k:
config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
return config_copy
clean_base_config = clean_keys(base_config)
return clean_base_config
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
"""Update the agents core memory block, return the new state"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
old_core_memory = self.get_agent_memory(user_id=user_id, agent_id=agent_id)["core_memory"]
new_core_memory = old_core_memory.copy()
modified = False
if "persona" in new_memory_contents and new_memory_contents["persona"] is not None:
new_persona = new_memory_contents["persona"]
if old_core_memory["persona"] != new_persona:
new_core_memory["persona"] = new_persona
memgpt_agent.memory.edit_persona(new_persona)
modified = True
elif "human" in new_memory_contents and new_memory_contents["human"] is not None:
new_human = new_memory_contents["human"]
if old_core_memory["human"] != new_human:
new_core_memory["human"] = new_human
memgpt_agent.memory.edit_human(new_human)
modified = True
# If we modified the memory contents, we need to rebuild the memory block inside the system message
if modified:
memgpt_agent.rebuild_memory()
return {
"old_core_memory": old_core_memory,
"new_core_memory": new_core_memory,
"modified": modified,
}

View File

@@ -24,3 +24,21 @@ def print_server_response(response):
print(response)
else:
print(response)
def shorten_key_middle(key_string, chars_each_side=3):
"""
Shortens a key string by showing a specified number of characters on each side and adding an ellipsis in the middle.
Args:
key_string (str): The key string to be shortened.
chars_each_side (int): The number of characters to show on each side of the ellipsis.
Returns:
str: The shortened key string with an ellipsis in the middle.
"""
key_length = len(key_string)
if key_length <= 2 * chars_each_side:
return "..." # Return ellipsis if the key is too short
else:
return key_string[:chars_each_side] + "..." + key_string[-chars_each_side:]

View File

@@ -4,7 +4,7 @@ import json
import websockets
import memgpt.server.ws_api.protocol as protocol
from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT
from memgpt.server.constants import WS_DEFAULT_PORT, WS_CLIENT_TIMEOUT
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
@@ -26,7 +26,7 @@ async def send_message_and_print_replies(websocket, user_message, agent_id):
# Wait for messages in a loop, since the server may send a few
while True:
response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT)
response = await asyncio.wait_for(websocket.recv(), WS_CLIENT_TIMEOUT)
response = json.loads(response)
if CLEAN_RESPONSES:
@@ -44,7 +44,7 @@ async def basic_cli_client():
Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages)
"""
uri = f"ws://localhost:{DEFAULT_PORT}"
uri = f"ws://localhost:{WS_DEFAULT_PORT}"
closed_on_message = False
retry_attempts = 0

View File

@@ -1,26 +1,37 @@
import asyncio
import json
import signal
import sys
import traceback
import websockets
from memgpt.server.server import SyncServer
from memgpt.server.ws_api.interface import SyncWebSocketInterface
from memgpt.server.constants import DEFAULT_PORT
from memgpt.server.constants import WS_DEFAULT_PORT
import memgpt.server.ws_api.protocol as protocol
import memgpt.system as system
import memgpt.constants as memgpt_constants
class WebSocketServer:
def __init__(self, host="localhost", port=DEFAULT_PORT):
def __init__(self, host="localhost", port=WS_DEFAULT_PORT):
self.host = host
self.port = port
self.interface = SyncWebSocketInterface()
self.server = SyncServer(default_interface=self.interface)
def __del__(self):
self.interface.close()
def shutdown_server(self):
try:
self.server.save_agents()
print(f"Saved agents")
except Exception as e:
print(f"Saving agents failed with: {e}")
try:
self.interface.close()
print(f"Closed the WS interface")
except Exception as e:
print(f"Closing the WS interface failed with: {e}")
def initialize_server(self):
print("Server is initializing...")
@@ -102,6 +113,36 @@ class WebSocketServer:
self.interface.unregister_client(websocket)
def start_server():
# Check if a port argument is provided
port = WS_DEFAULT_PORT
if len(sys.argv) > 1:
try:
port = int(sys.argv[1])
except ValueError:
print(f"Invalid port number. Using default port {port}.")
server = WebSocketServer(port=port)
def handle_sigterm(*args):
# Perform necessary cleanup
print("SIGTERM received, shutting down...")
# Note: This should be quick and not involve asynchronous calls
print("Shutting down the server...")
server.shutdown_server()
print("Server has been shut down.")
sys.exit(0)
signal.signal(signal.SIGTERM, handle_sigterm)
try:
asyncio.run(server.run())
except KeyboardInterrupt:
print("Shutting down the server...")
finally:
server.shutdown_server()
print("Server has been shut down.")
if __name__ == "__main__":
server = WebSocketServer()
asyncio.run(server.run())
start_server()

View File

@@ -4,7 +4,7 @@ import json
import websockets
import pytest
from memgpt.server.constants import DEFAULT_PORT
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.ws_api.server import WebSocketServer
from memgpt.config import AgentConfig
@@ -30,7 +30,7 @@ async def test_websocket_server():
# )
test_config = {}
uri = f"ws://{host}:{DEFAULT_PORT}"
uri = f"ws://{host}:{WS_DEFAULT_PORT}"
try:
async with websockets.connect(uri) as websocket:
# Initialize the server with a test config