diff --git a/memgpt/server/utils.py b/memgpt/server/utils.py index e6e4371c..cc444166 100644 --- a/memgpt/server/utils.py +++ b/memgpt/server/utils.py @@ -1,6 +1,9 @@ def condition_to_stop_receiving(response): """Determines when to stop listening to the server""" - return response.get("type") == "agent_response_end" + if response.get("type") in ["agent_response_end", "agent_response_error", "command_response", "server_error"]: + return True + else: + return False def print_server_response(response): diff --git a/memgpt/server/websocket_protocol.py b/memgpt/server/websocket_protocol.py index 7c39f810..8c8d3ecb 100644 --- a/memgpt/server/websocket_protocol.py +++ b/memgpt/server/websocket_protocol.py @@ -80,11 +80,12 @@ def server_agent_function_message(msg): # Client -> server -def client_user_message(msg): +def client_user_message(msg, agent_name=None): return json.dumps( { "type": "user_message", "message": msg, + "agent_name": agent_name, } ) diff --git a/memgpt/server/websocket_server.py b/memgpt/server/websocket_server.py index 07f6b8c6..4c92dc74 100644 --- a/memgpt/server/websocket_server.py +++ b/memgpt/server/websocket_server.py @@ -1,5 +1,6 @@ import asyncio import json +import traceback import websockets @@ -15,7 +16,9 @@ class WebSocketServer: self.host = host self.port = port self.interface = SyncWebSocketInterface() + self.agent = None + self.agent_name = None def run_step(self, user_message, first_message=False, no_verify=False): while True: @@ -41,9 +44,18 @@ class WebSocketServer: message = await websocket.recv() # Assuming the message is a JSON string - data = json.loads(message) + try: + data = json.loads(message) + except: + print(f"[server] bad data from client:\n{data}") + await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) + continue - if data["type"] == "command": + if "type" not in data: + print(f"[server] bad data from client (JSON but no type):\n{data}") + await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) + + elif data["type"] == "command": # Create a new agent if data["command"] == "create_agent": try: @@ -51,6 +63,8 @@ class WebSocketServer: await websocket.send(protocol.server_command_response("OK: Agent initialized")) except Exception as e: self.agent = None + print(f"[server] self.create_new_agent failed with:\n{e}") + print(f"{traceback.format_exc()}") await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}")) # Load an existing agent @@ -59,9 +73,11 @@ class WebSocketServer: if agent_name is not None: try: self.agent = self.load_agent(agent_name) + self.agent_name = agent_name await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded")) except Exception as e: print(f"[server] self.load_agent failed with:\n{e}") + print(f"{traceback.format_exc()}") self.agent = None await websocket.send( protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}") @@ -76,6 +92,26 @@ class WebSocketServer: elif data["type"] == "user_message": user_message = data["message"] + if "agent_name" in data: + agent_name = data["agent_name"] + # If the agent requested the same one that's already loading? + if self.agent_name is None or self.agent_name != data["agent_name"]: + try: + print(f"[server] loading agent {agent_name}") + self.agent = self.load_agent(agent_name) + self.agent_name = agent_name + # await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded")) + except Exception as e: + print(f"[server] self.load_agent failed with:\n{e}") + print(f"{traceback.format_exc()}") + self.agent = None + await websocket.send( + protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}") + ) + else: + await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request")) + continue + if self.agent is None: await websocket.send(protocol.server_agent_response_error("No agent has been initialized")) else: @@ -84,6 +120,7 @@ class WebSocketServer: self.run_step(user_message) except Exception as e: print(f"[server] self.run_step failed with:\n{e}") + print(f"{traceback.format_exc()}") await websocket.send(protocol.server_agent_response_error(f"self.run_step failed with: {e}")) await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed