diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 71fea458..c1270606 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -4,6 +4,10 @@ import sys import io import logging import questionary +from pathlib import Path +import os +import subprocess +from enum import Enum from llama_index import set_global_service_context from llama_index import ServiceContext @@ -18,6 +22,76 @@ from memgpt.config import MemGPTConfig, AgentConfig from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX from memgpt.agent import Agent from memgpt.embeddings import embedding_model +from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT + + +class ServerChoice(Enum): + rest_api = "rest" + ws_api = "websocket" + + +def server( + type: ServerChoice = typer.Option("rest", help="Server to run"), port: int = typer.Option(None, help="Port to run the server on") +): + """Launch a MemGPT server process""" + + if type == ServerChoice.rest_api: + if port is None: + port = REST_DEFAULT_PORT + + # Change to the desired directory + script_path = Path(__file__).resolve() + script_dir = script_path.parent + + server_directory = os.path.join(script_dir.parent, "server", "rest_api") + command = f"uvicorn server:app --reload --port {port}" + + # Run the command + print(f"Running REST server: {command} (inside {server_directory})") + + try: + # Start the subprocess in a new session + process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory) + process.wait() + except KeyboardInterrupt: + # Handle CTRL-C + print("Terminating the server...") + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + print("Server terminated with kill()") + sys.exit(0) + + elif type == ServerChoice.ws_api: + if port is None: + port = WS_DEFAULT_PORT + + # Change to the desired directory + script_path = Path(__file__).resolve() + script_dir = script_path.parent + + server_directory = os.path.join(script_dir.parent, "server", "ws_api") + command = f"python server.py {port}" + + # Run the command + print(f"Running WS (websockets) server: {command} (inside {server_directory})") + + try: + # Start the subprocess in a new session + process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory) + process.wait() + except KeyboardInterrupt: + # Handle CTRL-C + print("Terminating the server...") + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + print("Server terminated with kill()") + sys.exit(0) def run( diff --git a/memgpt/main.py b/memgpt/main.py index 9bc7fdbc..01dd947c 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -21,7 +21,7 @@ from memgpt.interface import CLIInterface as interface # for printing to termin import memgpt.agent as agent import memgpt.system as system import memgpt.constants as constants -from memgpt.cli.cli import run, attach, version +from memgpt.cli.cli import run, attach, version, server from memgpt.cli.cli_config import configure, list, add from memgpt.cli.cli_load import app as load_app from memgpt.connectors.storage import StorageConnector @@ -33,6 +33,7 @@ app.command(name="attach")(attach) app.command(name="configure")(configure) app.command(name="list")(list) app.command(name="add")(add) +app.command(name="server")(server) # load data commands app.add_typer(load_app, name="load") diff --git a/memgpt/server/constants.py b/memgpt/server/constants.py index 7447a554..d02f7dfd 100644 --- a/memgpt/server/constants.py +++ b/memgpt/server/constants.py @@ -1,3 +1,6 @@ -DEFAULT_PORT = 8282 +# WebSockets +WS_DEFAULT_PORT = 8282 +WS_CLIENT_TIMEOUT = 30 -CLIENT_TIMEOUT = 30 +# REST +REST_DEFAULT_PORT = 8283 diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py index 55ccb525..c13e366e 100644 --- a/memgpt/server/rest_api/interface.py +++ b/memgpt/server/rest_api/interface.py @@ -18,6 +18,8 @@ class QueuingInterface(AgentInterface): items.append(self.buffer.get_nowait()) except queue.Empty: break + if len(items) > 1 and items[-1] == "STOP": + items.pop() return items def clear(self): diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index ca23e099..f51215a4 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager import json from typing import Union @@ -38,17 +39,64 @@ class Command(BaseModel): command: str -app = FastAPI() -interface = QueuingInterface() -server = SyncServer(default_interface=interface) +class CoreMemory(BaseModel): + user_id: str + agent_id: str + human: str | None = None + persona: str | None = None + + +server = None +interface = None + + +@asynccontextmanager +async def lifespan(application: FastAPI): + global server + global interface + interface = QueuingInterface() + server = SyncServer(default_interface=interface) + yield + server.save_agents() + server = None + + +app = FastAPI(lifespan=lifespan) + +# app = FastAPI() +# server = SyncServer(default_interface=interface) # server.list_agents @app.get("/agents") def list_agents(user_id: str): interface.clear() - agents_list = utils.list_agent_config_files() - return {"num_agents": len(agents_list), "agent_names": agents_list} + return server.list_agents(user_id=user_id) + + +@app.get("/agents/memory") +def get_agent_memory(user_id: str, agent_id: str): + interface.clear() + return server.get_agent_memory(user_id=user_id, agent_id=agent_id) + + +@app.put("/agents/memory") +def get_agent_memory(body: CoreMemory): + interface.clear() + new_memory_contents = {"persona": body.persona, "human": body.human} + return server.update_agent_core_memory(user_id=body.user_id, agent_id=body.agent_id, new_memory_contents=new_memory_contents) + + +@app.get("/agents/config") +def get_agent_config(user_id: str, agent_id: str): + interface.clear() + return server.get_agent_config(user_id=user_id, agent_id=agent_id) + + +@app.get("/config") +def get_server_config(user_id: str): + interface.clear() + return server.get_server_config(user_id=user_id) # server.create_agent @@ -88,6 +136,8 @@ async def user_message(body: UserMessage): # Return the streaming response using the generator return StreamingResponse(formatted_message_generator(), media_type="text/event-stream") + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") @@ -95,6 +145,8 @@ async def user_message(body: UserMessage): interface.clear() try: server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") return {"messages": interface.to_list()} @@ -106,6 +158,8 @@ def run_command(body: Command): interface.clear() try: response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 0de45e86..360b4794 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,9 +1,12 @@ from abc import abstractmethod -from typing import Union +from typing import Union, Callable import json +from threading import Lock +from functools import wraps +from fastapi import HTTPException from memgpt.system import package_user_message -from memgpt.config import AgentConfig +from memgpt.config import AgentConfig, MemGPTConfig from memgpt.agent import Agent import memgpt.system as system import memgpt.constants as constants @@ -11,6 +14,7 @@ from memgpt.cli.cli import attach from memgpt.connectors.storage import StorageConnector import memgpt.presets.presets as presets import memgpt.utils as utils +import memgpt.server.utils as server_utils from memgpt.persistence_manager import PersistenceManager, LocalStateManager # TODO use custom interface @@ -22,10 +26,30 @@ class Server(object): """Abstract server class that supports multi-agent multi-user""" @abstractmethod - def list_agents(self, user_id: str, agent_id: str) -> str: + def list_agents(self, user_id: str) -> dict: """List all available agents to a user""" raise NotImplementedError + @abstractmethod + def get_agent_memory(self, user_id: str, agent_id: str) -> dict: + """Return the memory of an agent (core memory + non-core statistics)""" + raise NotImplementedError + + @abstractmethod + def get_agent_config(self, user_id: str, agent_id: str) -> dict: + """Return the config of an agent""" + raise NotImplementedError + + @abstractmethod + def get_server_config(self, user_id: str) -> dict: + """Return the base config""" + raise NotImplementedError + + @abstractmethod + def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict: + """Update the agents core memory block, return the new state""" + raise NotImplementedError + @abstractmethod def create_agent( self, @@ -51,8 +75,50 @@ class Server(object): raise NotImplementedError +class LockingServer(Server): + """Basic support for concurrency protections (all requests that modify an agent lock the agent until the operation is complete)""" + + # Locks for each agent + _agent_locks = {} + + @staticmethod + def agent_lock_decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, user_id: str, agent_id: str, *args, **kwargs): + # print("Locking check") + + # Initialize the lock for the agent_id if it doesn't exist + if agent_id not in self._agent_locks: + # print(f"Creating lock for agent_id = {agent_id}") + self._agent_locks[agent_id] = Lock() + + # Check if the agent is currently locked + if not self._agent_locks[agent_id].acquire(blocking=False): + # print(f"agent_id = {agent_id} is busy") + raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.") + + try: + # Execute the function + # print(f"running function on agent_id = {agent_id}") + return func(self, user_id, agent_id, *args, **kwargs) + finally: + # Release the lock + # print(f"releasing lock on agent_id = {agent_id}") + self._agent_locks[agent_id].release() + + return wrapper + + @agent_lock_decorator + def user_message(self, user_id: str, agent_id: str, message: str) -> None: + raise NotImplementedError + + @agent_lock_decorator + def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + raise NotImplementedError + + # TODO actually use "user_id" for something -class SyncServer(Server): +class SyncServer(LockingServer): """Simple single-threaded / blocking server process""" def __init__( @@ -82,6 +148,14 @@ class SyncServer(Server): # The default persistence manager that will get assigned to agents ON CREATION self.default_persistence_manager_cls = default_persistence_manager_cls + def save_agents(self): + for agent_d in self.active_agents: + try: + agent_d["agent"].save() + print(f"Saved agent {agent_d['agent_id']}") + except Exception as e: + print(f"Error occured while trying to save agent {agent_d['agent_id']}:\n{e}") + def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]: """Get the agent object from the in-memory object store""" for d in self.active_agents: @@ -302,6 +376,7 @@ class SyncServer(Server): input_message = system.get_token_limit_warning() self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) + @LockingServer.agent_lock_decorator def user_message(self, user_id: str, agent_id: str, message: str) -> None: """Process an incoming user message and feed it through the MemGPT agent""" from memgpt.utils import printd @@ -321,6 +396,7 @@ class SyncServer(Server): # Run the agent state forward self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message) + @LockingServer.agent_lock_decorator def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: """Run a command on the agent""" # If the input begins with a command prefix, attempt to process it as a command @@ -363,3 +439,83 @@ class SyncServer(Server): print(f"Created new agent from config: {agent}") return agent.config.name + + def list_agents(self, user_id: str) -> dict: + """List all available agents to a user""" + agents_list = utils.list_agent_config_files() + return {"num_agents": len(agents_list), "agent_names": agents_list} + + def get_agent_memory(self, user_id: str, agent_id: str) -> dict: + """Return the memory of an agent (core memory + non-core statistics)""" + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + + core_memory = memgpt_agent.memory + recall_memory = memgpt_agent.persistence_manager.recall_memory + archival_memory = memgpt_agent.persistence_manager.archival_memory + + memory_obj = { + "core_memory": { + "persona": core_memory.persona, + "human": core_memory.human, + }, + "recall_memory": len(recall_memory) if recall_memory is not None else None, + "archival_memory": len(archival_memory) if archival_memory is not None else None, + } + + return memory_obj + + def get_agent_config(self, user_id: str, agent_id: str) -> dict: + """Return the config of an agent""" + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + agent_config = vars(memgpt_agent.config) + + return agent_config + + def get_server_config(self, user_id: str) -> dict: + """Return the base config""" + base_config = vars(MemGPTConfig.load()) + + def clean_keys(config): + config_copy = config.copy() + for k, v in config.items(): + if k == "key" or "_key" in k: + config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5) + return config_copy + + clean_base_config = clean_keys(base_config) + return clean_base_config + + def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict: + """Update the agents core memory block, return the new state""" + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + + old_core_memory = self.get_agent_memory(user_id=user_id, agent_id=agent_id)["core_memory"] + new_core_memory = old_core_memory.copy() + + modified = False + if "persona" in new_memory_contents and new_memory_contents["persona"] is not None: + new_persona = new_memory_contents["persona"] + if old_core_memory["persona"] != new_persona: + new_core_memory["persona"] = new_persona + memgpt_agent.memory.edit_persona(new_persona) + modified = True + + elif "human" in new_memory_contents and new_memory_contents["human"] is not None: + new_human = new_memory_contents["human"] + if old_core_memory["human"] != new_human: + new_core_memory["human"] = new_human + memgpt_agent.memory.edit_human(new_human) + modified = True + + # If we modified the memory contents, we need to rebuild the memory block inside the system message + if modified: + memgpt_agent.rebuild_memory() + + return { + "old_core_memory": old_core_memory, + "new_core_memory": new_core_memory, + "modified": modified, + } diff --git a/memgpt/server/utils.py b/memgpt/server/utils.py index cc444166..14f684f0 100644 --- a/memgpt/server/utils.py +++ b/memgpt/server/utils.py @@ -24,3 +24,21 @@ def print_server_response(response): print(response) else: print(response) + + +def shorten_key_middle(key_string, chars_each_side=3): + """ + Shortens a key string by showing a specified number of characters on each side and adding an ellipsis in the middle. + + Args: + key_string (str): The key string to be shortened. + chars_each_side (int): The number of characters to show on each side of the ellipsis. + + Returns: + str: The shortened key string with an ellipsis in the middle. + """ + key_length = len(key_string) + if key_length <= 2 * chars_each_side: + return "..." # Return ellipsis if the key is too short + else: + return key_string[:chars_each_side] + "..." + key_string[-chars_each_side:] diff --git a/memgpt/server/ws_api/example_client.py b/memgpt/server/ws_api/example_client.py index e8b4d393..e49d226b 100644 --- a/memgpt/server/ws_api/example_client.py +++ b/memgpt/server/ws_api/example_client.py @@ -4,7 +4,7 @@ import json import websockets import memgpt.server.ws_api.protocol as protocol -from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT +from memgpt.server.constants import WS_DEFAULT_PORT, WS_CLIENT_TIMEOUT from memgpt.server.utils import condition_to_stop_receiving, print_server_response @@ -26,7 +26,7 @@ async def send_message_and_print_replies(websocket, user_message, agent_id): # Wait for messages in a loop, since the server may send a few while True: - response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT) + response = await asyncio.wait_for(websocket.recv(), WS_CLIENT_TIMEOUT) response = json.loads(response) if CLEAN_RESPONSES: @@ -44,7 +44,7 @@ async def basic_cli_client(): Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages) """ - uri = f"ws://localhost:{DEFAULT_PORT}" + uri = f"ws://localhost:{WS_DEFAULT_PORT}" closed_on_message = False retry_attempts = 0 diff --git a/memgpt/server/ws_api/server.py b/memgpt/server/ws_api/server.py index 3cecc1c8..480e00f9 100644 --- a/memgpt/server/ws_api/server.py +++ b/memgpt/server/ws_api/server.py @@ -1,26 +1,37 @@ import asyncio import json +import signal +import sys import traceback import websockets from memgpt.server.server import SyncServer from memgpt.server.ws_api.interface import SyncWebSocketInterface -from memgpt.server.constants import DEFAULT_PORT +from memgpt.server.constants import WS_DEFAULT_PORT import memgpt.server.ws_api.protocol as protocol import memgpt.system as system import memgpt.constants as memgpt_constants class WebSocketServer: - def __init__(self, host="localhost", port=DEFAULT_PORT): + def __init__(self, host="localhost", port=WS_DEFAULT_PORT): self.host = host self.port = port self.interface = SyncWebSocketInterface() self.server = SyncServer(default_interface=self.interface) - def __del__(self): - self.interface.close() + def shutdown_server(self): + try: + self.server.save_agents() + print(f"Saved agents") + except Exception as e: + print(f"Saving agents failed with: {e}") + try: + self.interface.close() + print(f"Closed the WS interface") + except Exception as e: + print(f"Closing the WS interface failed with: {e}") def initialize_server(self): print("Server is initializing...") @@ -102,6 +113,36 @@ class WebSocketServer: self.interface.unregister_client(websocket) +def start_server(): + # Check if a port argument is provided + port = WS_DEFAULT_PORT + if len(sys.argv) > 1: + try: + port = int(sys.argv[1]) + except ValueError: + print(f"Invalid port number. Using default port {port}.") + + server = WebSocketServer(port=port) + + def handle_sigterm(*args): + # Perform necessary cleanup + print("SIGTERM received, shutting down...") + # Note: This should be quick and not involve asynchronous calls + print("Shutting down the server...") + server.shutdown_server() + print("Server has been shut down.") + sys.exit(0) + + signal.signal(signal.SIGTERM, handle_sigterm) + + try: + asyncio.run(server.run()) + except KeyboardInterrupt: + print("Shutting down the server...") + finally: + server.shutdown_server() + print("Server has been shut down.") + + if __name__ == "__main__": - server = WebSocketServer() - asyncio.run(server.run()) + start_server() diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py index a2583f86..4a530e4d 100644 --- a/tests/test_websocket_server.py +++ b/tests/test_websocket_server.py @@ -4,7 +4,7 @@ import json import websockets import pytest -from memgpt.server.constants import DEFAULT_PORT +from memgpt.server.constants import WS_DEFAULT_PORT from memgpt.server.ws_api.server import WebSocketServer from memgpt.config import AgentConfig @@ -30,7 +30,7 @@ async def test_websocket_server(): # ) test_config = {} - uri = f"ws://{host}:{DEFAULT_PORT}" + uri = f"ws://{host}:{WS_DEFAULT_PORT}" try: async with websockets.connect(uri) as websocket: # Initialize the server with a test config