diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 43bcbeea..21896994 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -10,7 +10,7 @@ import openai from llama_index import set_global_service_context from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext -import memgpt.interface # for printing to terminal +from memgpt.interface import CLIInterface as interface # for printing to terminal from memgpt.cli.cli_config import configure import memgpt.agent as agent import memgpt.system as system @@ -128,7 +128,7 @@ def run( agent_config.save() # load existing agent - memgpt_agent = Agent.load_agent(memgpt.interface, agent_config) + memgpt_agent = Agent.load_agent(interface, agent_config) else: # create new agent # create new agent config: override defaults with args if provided typer.secho("Creating new agent...", fg=typer.colors.GREEN) @@ -158,7 +158,7 @@ def run( agent_config.model, utils.get_persona_text(agent_config.persona), utils.get_human_text(agent_config.human), - memgpt.interface, + interface, persistence_manager, ) diff --git a/memgpt/config.py b/memgpt/config.py index c57ed4b6..4ca22a74 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -17,7 +17,7 @@ from colorama import Fore, Style from typing import List, Type import memgpt.utils as utils -import memgpt.interface as interface +from memgpt.interface import CLIInterface as interface from memgpt.personas.personas import get_persona_text from memgpt.humans.humans import get_human_text from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS @@ -109,7 +109,9 @@ class MemGPTConfig: # read config values model = config.get("defaults", "model") context_window = ( - config.get("defaults", "context_window") if config.has_option("defaults", "context_window") else LLM_MAX_TOKENS["DEFAULT"] + int(config.get("defaults", "context_window")) + if config.has_option("defaults", "context_window") + else LLM_MAX_TOKENS["DEFAULT"] ) preset = config.get("defaults", "preset") 
model_endpoint = config.get("defaults", "model_endpoint") diff --git a/memgpt/interface.py b/memgpt/interface.py index c78f7d73..1639455f 100644 --- a/memgpt/interface.py +++ b/memgpt/interface.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod import json import re @@ -13,210 +14,238 @@ DEBUG = False # only dumps important messages in the terminal STRIP_UI = False -def important_message(msg): - fstr = f"{Fore.MAGENTA}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}" - if STRIP_UI: - fstr = "{msg}" - print(fstr.format(msg=msg)) +class AgentInterface(ABC): + """Interfaces handle MemGPT-related events (observer pattern)""" + + @abstractmethod + def user_message(self, msg): + """MemGPT receives a user message""" + raise NotImplementedError + + @abstractmethod + def internal_monologue(self, msg): + """MemGPT generates some internal monologue""" + raise NotImplementedError + + @abstractmethod + def assistant_message(self, msg): + """MemGPT uses send_message""" + raise NotImplementedError + + @abstractmethod + def function_message(self, msg): + """MemGPT calls a function""" + raise NotImplementedError -def warning_message(msg): - fstr = f"{Fore.RED}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}" - if STRIP_UI: - fstr = "{msg}" - else: +class CLIInterface(AgentInterface): + """Basic interface for dumping agent events to the command-line""" + + @staticmethod + def important_message(msg): + fstr = f"{Fore.MAGENTA}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}" + if STRIP_UI: + fstr = "{msg}" print(fstr.format(msg=msg)) - -def internal_monologue(msg): - # ANSI escape code for italic is '\x1B[3m' - fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {{msg}}{Style.RESET_ALL}" - if STRIP_UI: - fstr = "{msg}" - print(fstr.format(msg=msg)) - - -def assistant_message(msg): - fstr = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{{msg}}{Style.RESET_ALL}" - if STRIP_UI: - fstr = "{msg}" - print(fstr.format(msg=msg)) - - -def memory_message(msg): - fstr = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 
{Fore.LIGHTMAGENTA_EX}{{msg}}{Style.RESET_ALL}" - if STRIP_UI: - fstr = "{msg}" - print(fstr.format(msg=msg)) - - -def system_message(msg): - fstr = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" - if STRIP_UI: - fstr = "{msg}" - print(fstr.format(msg=msg)) - - -def user_message(msg, raw=False, dump=False, debug=DEBUG): - def print_user_message(icon, msg, printf=print): + @staticmethod + def warning_message(msg): + fstr = f"{Fore.RED}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}" if STRIP_UI: - printf(f"{icon} {msg}") + fstr = "{msg}" else: - printf(f"{Fore.GREEN}{Style.BRIGHT}{icon} {Fore.GREEN}{msg}{Style.RESET_ALL}") + print(fstr.format(msg=msg)) - def printd_user_message(icon, msg): - return print_user_message(icon, msg) + @staticmethod + def internal_monologue(msg): + # ANSI escape code for italic is '\x1B[3m' + fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {{msg}}{Style.RESET_ALL}" + if STRIP_UI: + fstr = "{msg}" + print(fstr.format(msg=msg)) - if not (raw or dump or debug): - # we do not want to repeat the message in normal use - return + @staticmethod + def assistant_message(msg): + fstr = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{{msg}}{Style.RESET_ALL}" + if STRIP_UI: + fstr = "{msg}" + print(fstr.format(msg=msg)) - if isinstance(msg, str): - if raw: - printd_user_message("🧑", msg) + @staticmethod + def memory_message(msg): + fstr = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{{msg}}{Style.RESET_ALL}" + if STRIP_UI: + fstr = "{msg}" + print(fstr.format(msg=msg)) + + @staticmethod + def system_message(msg): + fstr = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" + if STRIP_UI: + fstr = "{msg}" + print(fstr.format(msg=msg)) + + @staticmethod + def user_message(msg, raw=False, dump=False, debug=DEBUG): + def print_user_message(icon, msg, printf=print): + if STRIP_UI: + printf(f"{icon} {msg}") + else: + printf(f"{Fore.GREEN}{Style.BRIGHT}{icon} {Fore.GREEN}{msg}{Style.RESET_ALL}") + + def 
printd_user_message(icon, msg): + return print_user_message(icon, msg) + + if not (raw or dump or debug): + # we do not want to repeat the message in normal use return - else: - try: - msg_json = json.loads(msg) - except: - printd(f"Warning: failed to parse user message into json") + + if isinstance(msg, str): + if raw: printd_user_message("🧑", msg) return - if msg_json["type"] == "user_message": - if dump: - print_user_message("🧑", msg_json["message"]) - return - msg_json.pop("type") - printd_user_message("🧑", msg_json) - elif msg_json["type"] == "heartbeat": - if debug: + else: + try: + msg_json = json.loads(msg) + except: + printd(f"Warning: failed to parse user message into json") + printd_user_message("🧑", msg) + return + if msg_json["type"] == "user_message": + if dump: + print_user_message("🧑", msg_json["message"]) + return msg_json.pop("type") - printd_user_message("💓", msg_json) - elif dump: - print_user_message("💓", msg_json) + printd_user_message("🧑", msg_json) + elif msg_json["type"] == "heartbeat": + if debug: + msg_json.pop("type") + printd_user_message("💓", msg_json) + elif dump: + print_user_message("💓", msg_json) + return + + elif msg_json["type"] == "system_message": + msg_json.pop("type") + printd_user_message("🖥️", msg_json) + else: + printd_user_message("🧑", msg_json) + + @staticmethod + def function_message(msg, debug=DEBUG): + def print_function_message(icon, msg, color=Fore.RED, printf=print): + if STRIP_UI: + printf(f"⚡{icon} [function] {msg}") + else: + printf(f"{color}{Style.BRIGHT}⚡{icon} [function] {color}{msg}{Style.RESET_ALL}") + + def printd_function_message(icon, msg, color=Fore.RED): + return print_function_message(icon, msg, color, printf=(print if debug else printd)) + + if isinstance(msg, dict): + printd_function_message("", msg) return - elif msg_json["type"] == "system_message": - msg_json.pop("type") - printd_user_message("🖥️", msg_json) - else: - printd_user_message("🧑", msg_json) - - -def function_message(msg, debug=DEBUG): 
- def print_function_message(icon, msg, color=Fore.RED, printf=print): - if STRIP_UI: - printf(f"⚡{icon} [function] {msg}") - else: - printf(f"{color}{Style.BRIGHT}⚡{icon} [function] {color}{msg}{Style.RESET_ALL}") - - def printd_function_message(icon, msg, color=Fore.RED): - return print_function_message(icon, msg, color, printf=(print if debug else printd)) - - if isinstance(msg, dict): - printd_function_message("", msg) - return - - if msg.startswith("Success"): - printd_function_message("🟢", msg) - elif msg.startswith("Error: "): - printd_function_message("🔴", msg) - elif msg.startswith("Running "): - if debug: - printd_function_message("", msg) - else: - match = re.search(r"Running (\w+)\((.*)\)", msg) - if match: - function_name = match.group(1) - function_args = match.group(2) - if "memory" in function_name: - print_function_message("🧠", f"updating memory with {function_name}") - try: - msg_dict = eval(function_args) - if function_name == "archival_memory_search": - output = f'\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}' - if STRIP_UI: - print(output) - else: - print(f"{Fore.RED}{output}{Style.RESET_ALL}") - elif function_name == "archival_memory_insert": - output = f'\t→ {msg_dict["content"]}' - if STRIP_UI: - print(output) - else: - print(f"{Style.BRIGHT}{Fore.RED}{output}{Style.RESET_ALL}") - else: - if STRIP_UI: - print(f'\t {msg_dict["old_content"]}\n\t→ {msg_dict["new_content"]}') - else: - print( - f'{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}{Style.RESET_ALL}' - ) - except Exception as e: - printd(str(e)) - printd(msg_dict) - pass - else: - printd(f"Warning: did not recognize function message") + if msg.startswith("Success"): + printd_function_message("🟢", msg) + elif msg.startswith("Error: "): + printd_function_message("🔴", msg) + elif msg.startswith("Running "): + if debug: printd_function_message("", msg) - else: - try: - msg_dict = json.loads(msg) - if "status" in msg_dict and 
msg_dict["status"] == "OK": - printd_function_message("", str(msg), color=Fore.GREEN) else: - printd_function_message("", str(msg), color=Fore.RED) - except Exception: - print(f"Warning: did not recognize function message {type(msg)} {msg}") - printd_function_message("", msg) + match = re.search(r"Running (\w+)\((.*)\)", msg) + if match: + function_name = match.group(1) + function_args = match.group(2) + if "memory" in function_name: + print_function_message("🧠", f"updating memory with {function_name}") + try: + msg_dict = eval(function_args) + if function_name == "archival_memory_search": + output = f'\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}' + if STRIP_UI: + print(output) + else: + print(f"{Fore.RED}{output}{Style.RESET_ALL}") + elif function_name == "archival_memory_insert": + output = f'\t→ {msg_dict["content"]}' + if STRIP_UI: + print(output) + else: + print(f"{Style.BRIGHT}{Fore.RED}{output}{Style.RESET_ALL}") + else: + if STRIP_UI: + print(f'\t {msg_dict["old_content"]}\n\t→ {msg_dict["new_content"]}') + else: + print( + f'{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}{Style.RESET_ALL}' + ) + except Exception as e: + printd(str(e)) + printd(msg_dict) + pass + else: + printd(f"Warning: did not recognize function message") + printd_function_message("", msg) + else: + try: + msg_dict = json.loads(msg) + if "status" in msg_dict and msg_dict["status"] == "OK": + printd_function_message("", str(msg), color=Fore.GREEN) + else: + printd_function_message("", str(msg), color=Fore.RED) + except Exception: + print(f"Warning: did not recognize function message {type(msg)} {msg}") + printd_function_message("", msg) + @staticmethod + def print_messages(message_sequence, dump=False): + idx = len(message_sequence) + for msg in message_sequence: + if dump: + print(f"[{idx}] ", end="") + idx -= 1 + role = msg["role"] + content = msg["content"] -def print_messages(message_sequence, dump=False): - idx = 
len(message_sequence) - for msg in message_sequence: - if dump: - print(f"[{idx}] ", end="") - idx -= 1 - role = msg["role"] - content = msg["content"] - - if role == "system": - system_message(content) - elif role == "assistant": - # Differentiate between internal monologue, function calls, and messages - if msg.get("function_call"): - if content is not None: - internal_monologue(content) - # I think the next one is not up to date - # function_message(msg["function_call"]) - args = json.loads(msg["function_call"].get("arguments")) - assistant_message(args.get("message")) - # assistant_message(content) + if role == "system": + CLIInterface.system_message(content) + elif role == "assistant": + # Differentiate between internal monologue, function calls, and messages + if msg.get("function_call"): + if content is not None: + CLIInterface.internal_monologue(content) + # I think the next one is not up to date + # function_message(msg["function_call"]) + args = json.loads(msg["function_call"].get("arguments")) + CLIInterface.assistant_message(args.get("message")) + # assistant_message(content) + else: + CLIInterface.internal_monologue(content) + elif role == "user": + CLIInterface.user_message(content, dump=dump) + elif role == "function": + CLIInterface.function_message(content, debug=dump) else: - internal_monologue(content) - elif role == "user": - user_message(content, dump=dump) - elif role == "function": - function_message(content, debug=dump) - else: - print(f"Unknown role: {content}") + print(f"Unknown role: {content}") + @staticmethod + def print_messages_simple(message_sequence): + for msg in message_sequence: + role = msg["role"] + content = msg["content"] -def print_messages_simple(message_sequence): - for msg in message_sequence: - role = msg["role"] - content = msg["content"] + if role == "system": + CLIInterface.system_message(content) + elif role == "assistant": + CLIInterface.assistant_message(content) + elif role == "user": + 
CLIInterface.user_message(content, raw=True) + else: + print(f"Unknown role: {content}") - if role == "system": - system_message(content) - elif role == "assistant": - assistant_message(content) - elif role == "user": - user_message(content, raw=True) - else: - print(f"Unknown role: {content}") - - -def print_messages_raw(message_sequence): - for msg in message_sequence: - print(msg) + @staticmethod + def print_messages_raw(message_sequence): + for msg in message_sequence: + print(msg) diff --git a/memgpt/main.py b/memgpt/main.py index d6b1b53a..eb1e701a 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -14,11 +14,10 @@ import typer from rich.console import Console from prettytable import PrettyTable -from .interface import print_messages console = Console() -import memgpt.interface # for printing to terminal +from memgpt.interface import CLIInterface as interface # for printing to terminal import memgpt.agent as agent import memgpt.system as system import memgpt.utils as utils @@ -208,7 +207,7 @@ def main( use_azure_openai, strip_ui, ): - memgpt.interface.STRIP_UI = strip_ui + interface.STRIP_UI = strip_ui utils.DEBUG = debug logging.getLogger().setLevel(logging.CRITICAL) if debug: @@ -235,7 +234,7 @@ def main( archival_storage_sqldb, ) ): - memgpt.interface.important_message("⚙️ Using legacy command line arguments.") + interface.important_message("⚙️ Using legacy command line arguments.") model = model if model is None: model = constants.DEFAULT_MEMGPT_MODEL @@ -314,9 +313,9 @@ def main( else: cfg = Config.config_init() - memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']") + interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']") if cfg.model != constants.DEFAULT_MEMGPT_MODEL: - memgpt.interface.warning_message( + interface.warning_message( f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!" 
) @@ -329,7 +328,7 @@ def main( persistence_manager = InMemoryStateManager() if archival_storage_files_compute_embeddings: - memgpt.interface.important_message( + interface.important_message( f"(legacy) To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={cfg.archival_storage_index} (if your files haven't changed)." ) @@ -343,10 +342,11 @@ def main( cfg.model, personas.get_persona_text(*chosen_persona), humans.get_human_text(*chosen_human), - memgpt.interface, + interface, persistence_manager, ) - print_messages = memgpt.interface.print_messages + + print_messages = interface.print_messages print_messages(memgpt_agent.messages) if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner @@ -477,13 +477,13 @@ def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=Fals command = user_input.strip().split() amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 if amount == 0: - memgpt.interface.print_messages(memgpt_agent.messages, dump=True) + interface.print_messages(memgpt_agent.messages, dump=True) else: - memgpt.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) + interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) continue elif user_input.lower() == "/dumpraw": - memgpt.interface.print_messages_raw(memgpt_agent.messages) + interface.print_messages_raw(memgpt_agent.messages) continue elif user_input.lower() == "/memory": @@ -549,7 +549,7 @@ def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=Fals # No skip options elif user_input.lower() == "/wipe": - memgpt_agent = agent.Agent(memgpt.interface) + memgpt_agent = agent.Agent(interface) user_message = None elif user_input.lower() == "/heartbeat": diff --git a/memgpt/server/__init__.py 
b/memgpt/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/server/constants.py b/memgpt/server/constants.py new file mode 100644 index 00000000..7447a554 --- /dev/null +++ b/memgpt/server/constants.py @@ -0,0 +1,3 @@ +DEFAULT_PORT = 8282 + +CLIENT_TIMEOUT = 30 diff --git a/memgpt/server/utils.py b/memgpt/server/utils.py new file mode 100644 index 00000000..e6e4371c --- /dev/null +++ b/memgpt/server/utils.py @@ -0,0 +1,23 @@ +def condition_to_stop_receiving(response): + """Determines when to stop listening to the server""" + return response.get("type") == "agent_response_end" + + +def print_server_response(response): + """Turn response json into a nice print""" + if response["type"] == "agent_response_start": + print("[agent.step start]") + elif response["type"] == "agent_response_end": + print("[agent.step end]") + elif response["type"] == "agent_response": + msg = response["message"] + if response["message_type"] == "internal_monologue": + print(f"[inner thoughts] {msg}") + elif response["message_type"] == "assistant_message": + print(f"{msg}") + elif response["message_type"] == "function_message": + pass + else: + print(response) + else: + print(response) diff --git a/memgpt/server/websocket_client.py b/memgpt/server/websocket_client.py new file mode 100644 index 00000000..32020470 --- /dev/null +++ b/memgpt/server/websocket_client.py @@ -0,0 +1,80 @@ +import asyncio +import json + +import websockets + +import memgpt.server.websocket_protocol as protocol +from memgpt.server.websocket_server import WebSocketServer +from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT +from memgpt.server.utils import condition_to_stop_receiving, print_server_response + + +# CLEAN_RESPONSES = False # print the raw server responses (JSON) +CLEAN_RESPONSES = True # make the server responses cleaner + +# LOAD_AGENT = None # create a brand new agent +LOAD_AGENT = "agent_26" # load an existing agent + + +async def basic_cli_client(): + 
"""Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets + + Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages) + """ + uri = f"ws://localhost:{DEFAULT_PORT}" + + async with websockets.connect(uri) as websocket: + if LOAD_AGENT is not None: + # Load existing agent + print("Sending load message to server...") + await websocket.send(protocol.client_command_load(LOAD_AGENT)) + + else: + # Initialize new agent + print("Sending config to server...") + example_config = { + "persona": "sam_pov", + "human": "cs_phd", + "model": "gpt-4-1106-preview", # gpt-4-turbo + } + await websocket.send(protocol.client_command_create(example_config)) + # Wait for the response + response = await websocket.recv() + response = json.loads(response) + print(f"Server response:\n{json.dumps(response, indent=2)}") + + await asyncio.sleep(1) + + while True: + user_input = input("\nEnter your message: ") + print("\n") + + # Send a message to the agent + await websocket.send(protocol.client_user_message(str(user_input))) + + # Wait for messages in a loop, since the server may send a few + while True: + try: + response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT) + response = json.loads(response) + + if CLEAN_RESPONSES: + print_server_response(response) + else: + print(f"Server response:\n{json.dumps(response, indent=2)}") + + # Check for a specific condition to break the loop + if condition_to_stop_receiving(response): + break + except asyncio.TimeoutError: + print("Timeout waiting for the server response.") + break + except websockets.exceptions.ConnectionClosedError: + print("Connection to server was lost.") + break + except Exception as e: + print(f"An error occurred: {e}") + break + + +asyncio.run(basic_cli_client()) diff --git a/memgpt/server/websocket_interface.py b/memgpt/server/websocket_interface.py new file mode 100644 index 00000000..932b902c --- /dev/null +++ 
b/memgpt/server/websocket_interface.py @@ -0,0 +1,108 @@ +import asyncio +import threading + + +from memgpt.interface import AgentInterface +import memgpt.server.websocket_protocol as protocol + + +class BaseWebSocketInterface(AgentInterface): + """Interface for interacting with a MemGPT agent over a WebSocket""" + + def __init__(self): + self.clients = set() + + def register_client(self, websocket): + """Register a new client connection""" + self.clients.add(websocket) + + def unregister_client(self, websocket): + """Unregister a client connection""" + self.clients.remove(websocket) + + +class AsyncWebSocketInterface(BaseWebSocketInterface): + """WebSocket calls are async""" + + async def user_message(self, msg): + """Handle reception of a user message""" + # Logic to process the user message and possibly trigger agent's response + pass + + async def internal_monologue(self, msg): + """Handle the agent's internal monologue""" + print(msg) + # Send the internal monologue to all clients + if self.clients: # Check if there are any clients connected + await asyncio.gather(*[client.send(protocol.server_agent_internal_monologue(msg)) for client in self.clients]) + + async def assistant_message(self, msg): + """Handle the agent sending a message""" + print(msg) + # Send the assistant's message to all clients + if self.clients: + await asyncio.gather(*[client.send(protocol.server_agent_assistant_message(msg)) for client in self.clients]) + + async def function_message(self, msg): + """Handle the agent calling a function""" + print(msg) + # Send the function call message to all clients + if self.clients: + await asyncio.gather(*[client.send(protocol.server_agent_function_message(msg)) for client in self.clients]) + + +class SyncWebSocketInterface(BaseWebSocketInterface): + def __init__(self): + super().__init__() + self.clients = set() + self.loop = asyncio.new_event_loop() # Create a new event loop + self.thread = threading.Thread(target=self._run_event_loop, daemon=True) 
+ self.thread.start() + + def _run_event_loop(self): + """Run the dedicated event loop and handle its closure.""" + asyncio.set_event_loop(self.loop) + try: + self.loop.run_forever() + finally: + # Run the cleanup tasks in the event loop + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() + + def _run_async(self, coroutine): + """Schedule coroutine to be run in the dedicated event loop.""" + if not self.loop.is_closed(): + asyncio.run_coroutine_threadsafe(coroutine, self.loop) + + async def _send_to_all_clients(self, clients, msg): + """Asynchronously sends a message to all clients.""" + if clients: + await asyncio.gather(*(client.send(msg) for client in clients)) + + def user_message(self, msg): + """Handle reception of a user message""" + # Logic to process the user message and possibly trigger agent's response + pass + + def internal_monologue(self, msg): + """Handle the agent's internal monologue""" + print(msg) + if self.clients: + self._run_async(self._send_to_all_clients(self.clients, protocol.server_agent_internal_monologue(msg))) + + def assistant_message(self, msg): + """Handle the agent sending a message""" + print(msg) + if self.clients: + self._run_async(self._send_to_all_clients(self.clients, protocol.server_agent_assistant_message(msg))) + + def function_message(self, msg): + """Handle the agent calling a function""" + print(msg) + if self.clients: + self._run_async(self._send_to_all_clients(self.clients, protocol.server_agent_function_message(msg))) + + def close(self): + """Shut down the WebSocket interface and its event loop.""" + self.loop.call_soon_threadsafe(self.loop.stop) # Signal the loop to stop + self.thread.join() # Wait for the thread to finish diff --git a/memgpt/server/websocket_protocol.py b/memgpt/server/websocket_protocol.py new file mode 100644 index 00000000..7c39f810 --- /dev/null +++ b/memgpt/server/websocket_protocol.py @@ -0,0 +1,109 @@ +import json + +# Server -> client + + +def 
server_error(msg): + """General server error""" + return json.dumps( + { + "type": "server_error", + "message": msg, + } + ) + + +def server_command_response(status): + return json.dumps( + { + "type": "command_response", + "status": status, + } + ) + + +def server_agent_response_error(msg): + return json.dumps( + { + "type": "agent_response_error", + "message": msg, + } + ) + + +def server_agent_response_start(): + return json.dumps( + { + "type": "agent_response_start", + } + ) + + +def server_agent_response_end(): + return json.dumps( + { + "type": "agent_response_end", + } + ) + + +def server_agent_internal_monologue(msg): + return json.dumps( + { + "type": "agent_response", + "message_type": "internal_monologue", + "message": msg, + } + ) + + +def server_agent_assistant_message(msg): + return json.dumps( + { + "type": "agent_response", + "message_type": "assistant_message", + "message": msg, + } + ) + + +def server_agent_function_message(msg): + return json.dumps( + { + "type": "agent_response", + "message_type": "function_message", + "message": msg, + } + ) + + +# Client -> server + + +def client_user_message(msg): + return json.dumps( + { + "type": "user_message", + "message": msg, + } + ) + + +def client_command_create(config): + return json.dumps( + { + "type": "command", + "command": "create_agent", + "config": config, + } + ) + + +def client_command_load(agent_name): + return json.dumps( + { + "type": "command", + "command": "load_agent", + "name": agent_name, + } + ) diff --git a/memgpt/server/websocket_server.py b/memgpt/server/websocket_server.py new file mode 100644 index 00000000..8df0d04e --- /dev/null +++ b/memgpt/server/websocket_server.py @@ -0,0 +1,162 @@ +import asyncio +import json + +import websockets + +from memgpt.server.websocket_interface import SyncWebSocketInterface +from memgpt.server.constants import DEFAULT_PORT +import memgpt.server.websocket_protocol as protocol +import memgpt.system as system +import memgpt.constants as 
memgpt_constants + + +class WebSocketServer: + def __init__(self, host="localhost", port=DEFAULT_PORT): + self.host = host + self.port = port + self.interface = SyncWebSocketInterface() + self.agent = None + + def run_step(self, user_message, first_message=False, no_verify=False): + while True: + new_messages, heartbeat_request, function_failed, token_warning = self.agent.step( + user_message, first_message=first_message, skip_verify=no_verify + ) + + if token_warning: + user_message = system.get_token_limit_warning() + elif function_failed: + user_message = system.get_heartbeat(memgpt_constants.FUNC_FAILED_HEARTBEAT_MESSAGE) + elif heartbeat_request: + user_message = system.get_heartbeat(memgpt_constants.REQ_HEARTBEAT_MESSAGE) + else: + # return control + break + + async def handle_client(self, websocket, path): + self.interface.register_client(websocket) + try: + # async for message in websocket: + while True: + message = await websocket.recv() + + # Assuming the message is a JSON string + data = json.loads(message) + + if data["type"] == "command": + # Create a new agent + if data["command"] == "create_agent": + try: + self.agent = self.create_new_agent(data["config"]) + await websocket.send(protocol.server_command_response("OK: Agent initialized")) + except Exception as e: + self.agent = None + await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}")) + + # Load an existing agent + elif data["command"] == "load_agent": + agent_name = data.get("name") + if agent_name is not None: + try: + self.agent = self.load_agent(agent_name) + await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded")) + except Exception as e: + self.agent = None + await websocket.send( + protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}") + ) + else: + await websocket.send(protocol.server_command_response(f"Error: 'name' not provided")) + + else: + print(f"[server] unrecognized 
client command type: {data}") + await websocket.send(protocol.server_error(f"unrecognized client command type: {data}")) + + elif data["type"] == "user_message": + user_message = data["message"] + + if self.agent is None: + await websocket.send(protocol.server_agent_response_error("No agent has been initialized")) + + await websocket.send(protocol.server_agent_response_start()) + self.run_step(user_message) + await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed + await websocket.send(protocol.server_agent_response_end()) + + # ... handle other message types as needed ... + else: + print(f"[server] unrecognized client package data type: {data}") + await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}")) + + except websockets.exceptions.ConnectionClosed: + print(f"[server] connection with client was closed") + finally: + # TODO autosave the agent + + self.interface.unregister_client(websocket) + + def create_new_agent(self, config): + """Config is json that arrived over websocket, so we need to turn it into a config object""" + from memgpt.config import AgentConfig + import memgpt.presets as presets + import memgpt.utils as utils + from memgpt.persistence_manager import InMemoryStateManager + + print("Creating new agent...") + + # Initialize the agent based on the provided configuration + agent_config = AgentConfig(**config) + + # Use an in-state persistence manager + persistence_manager = InMemoryStateManager() + + # Create agent via preset from config + agent = presets.use_preset( + agent_config.preset, + agent_config, + agent_config.model, + utils.get_persona_text(agent_config.persona), + utils.get_human_text(agent_config.human), + self.interface, + persistence_manager, + ) + print("Created new agent from config") + + return agent + + def load_agent(self, agent_name): + """Load an agent from a directory""" + import memgpt.utils as utils + from memgpt.config import AgentConfig + 
from memgpt.agent import Agent + + print(f"Loading agent {agent_name}...") + + agent_files = utils.list_agent_config_files() + agent_names = [AgentConfig.load(f).name for f in agent_files] + + if agent_name not in agent_names: + raise ValueError(f"agent '{agent_name}' does not exist") + + agent_config = AgentConfig.load(agent_name) + agent = Agent.load_agent(self.interface, agent_config) + print("Created agent by loading existing config") + + return agent + + def initialize_server(self): + print("Server is initializing...") + print(f"Listening on {self.host}:{self.port}...") + + async def start_server(self): + self.initialize_server() + async with websockets.serve(self.handle_client, self.host, self.port): + await asyncio.Future() # Run forever + + def run(self): + return self.start_server() # Return the coroutine + + +if __name__ == "__main__": + server = WebSocketServer() + asyncio.run(server.run()) diff --git a/poetry.lock b/poetry.lock index 6220ab9c..f48f7354 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3221,6 +3221,87 @@ files = [ {file = "wcwidth-0.2.9.tar.gz", hash = "sha256:a675d1a4a2d24ef67096a04b85b02deeecd8e226f57b5e3a72dbb9ed99d27da8"}, ] +[[package]] +name = "websockets" +version = "12.0" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"}, + {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"}, + {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"}, + {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"}, + {file = 
"websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"}, + {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"}, + {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"}, + {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"}, + {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"}, + {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"}, + {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"}, + {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"}, + {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"}, + {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"}, + {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"}, + {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"}, + {file = 
"websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"}, + {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"}, + {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"}, + {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"}, + {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"}, + {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"}, + {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"}, + {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"}, + {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"}, + {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"}, + {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"}, + {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"}, + {file = 
"websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"}, + {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"}, + {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"}, + {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"}, + {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"}, + {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"}, + {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"}, + {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"}, + {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"}, + {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"}, + {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"}, + {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"}, + {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = 
"sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"}, + {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"}, + {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"}, + {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"}, + {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"}, + {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"}, + {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"}, + {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"}, + {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"}, + {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"}, + {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"}, + {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"}, + {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"}, + {file = "websockets-12.0-cp39-cp39-win32.whl", 
hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"}, + {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"}, + {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"}, + {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"}, + {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"}, + {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"}, + {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"}, + {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"}, + {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"}, + {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"}, + {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"}, + {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = 
"sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"}, + {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"}, + {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"}, + {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"}, + {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"}, + {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"}, + {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"}, + {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, +] + [[package]] name = "wheel" version = "0.41.3" @@ -3527,4 +3608,4 @@ postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary" [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "24e6c3cea1895441e07d362a5a2f9a07a045b92b5364531b8b6e3571904199fe" +content-hash = "0fa0b65ce00550c139abcf5b4134e9e5b19b277930782ffe8421afec9d2743e2" diff --git a/pyproject.toml b/pyproject.toml index 2b65902c..d831677d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ transformers = { version = "4.34.1", optional = true } pre-commit = {version = "^3.5.0", optional = true } pg8000 = {version = "^1.30.3", optional = true} torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true} +websockets = "^12.0" docstring-parser 
= "^0.15" [tool.poetry.extras] diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py new file mode 100644 index 00000000..30b7d9f8 --- /dev/null +++ b/tests/test_websocket_interface.py @@ -0,0 +1,111 @@ +import argparse +import os +import subprocess +import sys + +import pytest +from unittest.mock import Mock, AsyncMock, MagicMock + +from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.server.websocket_interface import SyncWebSocketInterface +import memgpt.presets as presets +import memgpt.personas.personas as personas +import memgpt.humans.humans as humans +import memgpt.system as system +from memgpt.persistence_manager import InMemoryStateManager + + +# def test_websockets(): +# # Create the websocket interface +# ws_interface = WebSocketInterface() + +# # Create a dummy persistence manager +# persistence_manager = InMemoryStateManager() + +# # Create an agent and hook it up to the WebSocket interface +# memgpt_agent = presets.use_preset( +# presets.DEFAULT_PRESET, +# None, # no agent config to provide +# "gpt-4-1106-preview", +# personas.get_persona_text("sam_pov"), +# humans.get_human_text("chad"), +# ws_interface, +# persistence_manager, +# ) + +# user_message = system.package_user_message("Hello, is anyone there?") + +# # This should trigger calls to interface user_message and others +# memgpt_agent.step(user_message=user_message) + +# # This should trigger the web socket to send over a +# ws_interface.print_messages(memgpt_agent.messages) + + +@pytest.mark.asyncio +async def test_dummy(): + assert True + + +@pytest.mark.asyncio +async def test_websockets(): + # Mock a WebSocket connection + mock_websocket = AsyncMock() + # mock_websocket = Mock() + + # Create the WebSocket interface with the mocked WebSocket + ws_interface = SyncWebSocketInterface() + + # Register the mock websocket as a client + ws_interface.register_client(mock_websocket) + + # Mock the persistence manager + persistence_manager = 
InMemoryStateManager() + + # Create an agent and hook it up to the WebSocket interface + config = MemGPTConfig() + memgpt_agent = presets.use_preset( + presets.DEFAULT_PRESET, + config, # no agent config to provide + "gpt-4-1106-preview", + personas.get_persona_text("sam_pov"), + humans.get_human_text("basic"), + ws_interface, + persistence_manager, + ) + + # Mock the user message packaging + user_message = system.package_user_message("Hello, is anyone there?") + + # Mock the agent's step method + # agent_step = AsyncMock() + # memgpt_agent.step = agent_step + + # Call the step method, which should trigger interface methods + ret = memgpt_agent.step(user_message=user_message, first_message=True, skip_verify=True) + print("ret\n") + print(ret) + + # Print what the WebSocket received + print("client\n") + for call in mock_websocket.send.mock_calls: + # print(call) + _, args, kwargs = call + # args will be a tuple of positional arguments sent to the send method + # kwargs will be a dictionary of keyword arguments sent to the send method + print(f"Sent data: {args[0] if args else None}") + # If you're using keyword arguments, you can print them out as well: + # print(f"Sent data with kwargs: {kwargs}") + + # This is required for the Sync wrapper version + ws_interface.close() + + # Assertions to ensure the step method was called + # agent_step.assert_called_once() + + # Assertions to ensure the WebSocket interface methods are called + # You would need to implement the logic to verify that methods like ws_interface.user_message are called + # This will require you to have some mechanism within your WebSocketInterface to track these calls + + +# await test_websockets() diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py new file mode 100644 index 00000000..c8da0186 --- /dev/null +++ b/tests/test_websocket_server.py @@ -0,0 +1,50 @@ +import asyncio +import json + +import websockets +import pytest + +from memgpt.server.constants import DEFAULT_PORT 
from memgpt.server.websocket_server import WebSocketServer
from memgpt.config import AgentConfig


async def _connect_when_ready(uri, attempts=20, delay=0.1):
    """Open a client connection to *uri*, retrying while the server is still binding.

    The server under test is started as a background task, so the first
    connection attempt can race with it and be refused. OSError (which covers
    ConnectionRefusedError) is retried up to *attempts* times, *delay* seconds
    apart; the last failure is re-raised.
    """
    for remaining in range(attempts - 1, -1, -1):
        try:
            return await websockets.connect(uri)
        except OSError:
            if not remaining:
                raise
            await asyncio.sleep(delay)


@pytest.mark.asyncio
async def test_dummy():
    """Sanity check that async tests are collected and executed."""
    assert True


@pytest.mark.asyncio
async def test_websocket_server():
    """Smoke-test one client/server exchange against a locally started server.

    Runs ``WebSocketServer`` as a background task, sends an "initialize"
    request followed by a message, and prints the first reply to each.
    The connection is closed and the server task is cancelled *and awaited*
    on every exit path, so a failing exchange cannot leave the server
    running behind the test.
    """
    server = WebSocketServer()
    server_task = asyncio.create_task(server.run())  # serve in the background

    # TODO: send a real agent config (persona="sam_pov", human="cs_phd",
    # preset="memgpt_chat", plus a model endpoint) instead of an empty one.
    test_config = {}

    uri = f"ws://localhost:{DEFAULT_PORT}"
    try:
        websocket = await _connect_when_ready(uri)
        try:
            # Initialize the server with a test config
            print("Sending config to server...")
            await websocket.send(json.dumps({"type": "initialize", "config": test_config}))
            # Wait for the response
            response = await websocket.recv()
            print(f"Response from the agent: {response}")

            await asyncio.sleep(1)  # just in case

            # Send a message to the agent
            # NOTE(review): the server's visible handler branch matches
            # {"type": "user_message", "message": ...}; this payload probably
            # lands in its "unrecognized" branch -- confirm the intended protocol.
            print("Sending message to server...")
            await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}))
            # Wait for the response
            # NOTE: we should be waiting for multiple responses
            response = await websocket.recv()
            print(f"Response from the agent: {response}")
        finally:
            await websocket.close()
    finally:
        # Cancel the server and wait for it to unwind; cancelling without
        # awaiting leaves the task pending when the test returns.
        server_task.cancel()
        try:
            await server_task
        except asyncio.CancelledError:
            pass