diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 15107b51..01967b26 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,7 +42,7 @@ jobs: PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - poetry install -E dev -E postgres -E local + poetry install -E dev -E postgres -E local -E server - name: Set Poetry config env: diff --git a/memgpt/server/rest_api/__init__.py b/memgpt/server/rest_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py new file mode 100644 index 00000000..55ccb525 --- /dev/null +++ b/memgpt/server/rest_api/interface.py @@ -0,0 +1,75 @@ +import asyncio +import queue + +from memgpt.interface import AgentInterface + + +class QueuingInterface(AgentInterface): + """Messages are queued inside an internal buffer and manually flushed""" + + def __init__(self): + self.buffer = queue.Queue() + + def to_list(self): + """Convert queue to a list (empties it out at the same time)""" + items = [] + while not self.buffer.empty(): + try: + items.append(self.buffer.get_nowait()) + except queue.Empty: + break + return items + + def clear(self): + """Clear all messages from the queue.""" + with self.buffer.mutex: + # Empty the queue + self.buffer.queue.clear() + + async def message_generator(self): + while True: + if not self.buffer.empty(): + message = self.buffer.get() + if message == "STOP": + break + yield message + else: + await asyncio.sleep(0.1) # Small sleep to prevent a busy loop + + def step_yield(self): + """Enqueue a special stop message""" + self.buffer.put("STOP") + + def user_message(self, msg: str): + """Handle reception of a user message""" + pass + + def internal_monologue(self, msg: str) -> None: + """Handle the agent's internal monologue""" + print(msg) + self.buffer.put({"internal_monologue": msg}) + + def assistant_message(self, msg: str) -> None: + """Handle the agent sending a message""" + print(msg) + self.buffer.put({"assistant_message": msg}) + + def function_message(self, msg: str) -> None: + """Handle the agent calling a function""" + print(msg) + + if msg.startswith("Running "): + msg = msg.replace("Running ", "") + self.buffer.put({"function_call": msg}) + + elif msg.startswith("Success: "): + msg = msg.replace("Success: ", "") + self.buffer.put({"function_return": msg, "status": "success"}) + + elif msg.startswith("Error: "): + msg = msg.replace("Error: ", "") + self.buffer.put({"function_return": msg, "status": "error"}) + + else: + # NOTE: generic, should not happen + self.buffer.put({"function_message": msg}) diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py new file mode 100644 index 00000000..ca23e099 --- /dev/null +++ b/memgpt/server/rest_api/server.py @@ -0,0 +1,112 @@ +import asyncio +import json +from typing import Union + +from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from memgpt.server.server import SyncServer +from memgpt.server.rest_api.interface import QueuingInterface +import memgpt.utils as utils + + +""" +Basic REST API sitting on top of the internal MemGPT python server (SyncServer) + +Start the server with: + cd memgpt/server/rest_api + poetry run uvicorn server:app --reload +""" + + +class CreateAgentConfig(BaseModel): + user_id: str + config: dict + + +class UserMessage(BaseModel): + user_id: str + agent_id: str + message: str + stream: bool = False + + +class Command(BaseModel): + user_id: str + agent_id: str + command: str + + +app = FastAPI() +interface = QueuingInterface() +server = SyncServer(default_interface=interface) + + +# server.list_agents +@app.get("/agents") +def list_agents(user_id: str): + interface.clear() + agents_list = utils.list_agent_config_files() + return {"num_agents": len(agents_list), "agent_names": agents_list} + + +# server.create_agent +@app.post("/agents") +def create_agents(body: CreateAgentConfig): + interface.clear() + try: + agent_id = server.create_agent(user_id=body.user_id, agent_config=body.config) + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return {"agent_id": agent_id} + + +# server.user_message +@app.post("/agents/message") +async def user_message(body: UserMessage): + if body.stream: + # For streaming response + try: + # Start the generation process (similar to the non-streaming case) + # This should be a non-blocking call or run in a background task + + # Check if server.user_message is an async function + if asyncio.iscoroutinefunction(server.user_message): + # Start the async task + asyncio.create_task(server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)) + else: + # Run the synchronous function in a thread pool + loop = asyncio.get_event_loop() + loop.run_in_executor(None, server.user_message, body.user_id, body.agent_id, body.message) + + async def formatted_message_generator(): + async for message in interface.message_generator(): + formatted_message = f"data: {json.dumps(message)}\n\n" + yield formatted_message + await asyncio.sleep(1) + + # Return the streaming response using the generator + return StreamingResponse(formatted_message_generator(), media_type="text/event-stream") + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + else: + interface.clear() + try: + server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return {"messages": interface.to_list()} + + +# server.run_command +@app.post("/agents/command") +def run_command(body: Command): + interface.clear() + try: + response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command) + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command) + return {"response": response} diff --git a/memgpt/server/server.py b/memgpt/server/server.py new file mode 100644 index 00000000..0de45e86 --- /dev/null +++ b/memgpt/server/server.py @@ -0,0 +1,365 @@ +from abc import abstractmethod +from typing import Union +import json + +from memgpt.system import package_user_message +from memgpt.config import AgentConfig +from memgpt.agent import Agent +import memgpt.system as system +import memgpt.constants as constants +from memgpt.cli.cli import attach +from memgpt.connectors.storage import StorageConnector +import memgpt.presets.presets as presets +import memgpt.utils as utils +from memgpt.persistence_manager import PersistenceManager, LocalStateManager + +# TODO use custom interface +from memgpt.interface import CLIInterface # for printing to terminal +from memgpt.interface import AgentInterface # abstract + + +class Server(object): + """Abstract server class that supports multi-agent multi-user""" + + @abstractmethod + def list_agents(self, user_id: str, agent_id: str) -> str: + """List all available agents to a user""" + raise NotImplementedError + + @abstractmethod + def create_agent( + self, + user_id: str, + agent_config: Union[dict, AgentConfig], + interface: Union[AgentInterface, None], + persistence_manager: Union[PersistenceManager, None], + ) -> str: + """Create a new agent using a config""" + raise NotImplementedError + + @abstractmethod + def user_message(self, user_id: str, agent_id: str, message: str) -> None: + """Process a message from the user, internally calls step""" + raise NotImplementedError + + @abstractmethod + def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + """Run a command on the agent, e.g. /memory + + May return a string with a message generated by the command + """ + raise NotImplementedError + + +# TODO actually use "user_id" for something +class SyncServer(Server): + """Simple single-threaded / blocking server process""" + + def __init__( + self, + chaining: bool = True, + max_chaining_steps: bool = None, + # default_interface_cls: AgentInterface = CLIInterface, + default_interface: AgentInterface = CLIInterface(), + default_persistence_manager_cls: PersistenceManager = LocalStateManager, + ): + """Server process holds in-memory agents that are being run""" + + # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts + self.active_agents = [] + + # chaining = whether or not to run again if request_heartbeat=true + self.chaining = chaining + + # if chaining == true, what's the max number of times we'll chain before yielding? + # none = no limit, can go on forever + self.max_chaining_steps = max_chaining_steps + + # The default interface that will get assigned to agents ON LOAD + # self.default_interface_cls = default_interface_cls + self.default_interface = default_interface + + # The default persistence manager that will get assigned to agents ON CREATION + self.default_persistence_manager_cls = default_persistence_manager_cls + + def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]: + """Get the agent object from the in-memory object store""" + for d in self.active_agents: + if d["user_id"] == user_id and d["agent_id"] == agent_id: + return d["agent"] + return None + + def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None: + """Put an agent object inside the in-memory object store""" + # Make sure the agent doesn't already exist + if self._get_agent(user_id=user_id, agent_id=agent_id) is not None: + raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded") + # Add Agent instance to the in-memory list + self.active_agents.append( + { + "user_id": user_id, + "agent_id": agent_id, + "agent": agent_obj, + } + ) + + def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent: + """Loads a saved agent into memory (if it doesn't exist, throw an error)""" + from memgpt.utils import printd + + # If an interface isn't specified, use the default + if interface is None: + interface = self.default_interface + + # If the agent isn't load it, load it and put it into memory + if AgentConfig.exists(agent_id): + printd(f"(user={user_id}, agent={agent_id}) exists, loading into memory...") + agent_config = AgentConfig.load(agent_id) + memgpt_agent = Agent.load_agent(interface=interface, agent_config=agent_config) + self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent) + return memgpt_agent + + # If the agent doesn't exist, throw an error + else: + raise ValueError(f"agent_id {agent_id} does not exist") + + def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent: + """Check if the agent is in-memory, then load""" + memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id) + if not memgpt_agent: + memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id) + return memgpt_agent + + def _step(self, user_id: str, agent_id: str, input_message: str) -> None: + """Send the input message through the agent""" + from memgpt.utils import printd + + printd(f"Got input message: {input_message}") + + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + if memgpt_agent is None: + raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") + + printd(f"Starting agent step") + no_verify = True + next_input_message = input_message + counter = 0 + while True: + new_messages, heartbeat_request, function_failed, token_warning = memgpt_agent.step( + next_input_message, first_message=False, skip_verify=no_verify + ) + counter += 1 + + # Chain stops + if not self.chaining: + printd("No chaining, stopping after one step") + break + elif self.max_chaining_steps is not None and counter > self.max_chaining_steps: + printd(f"Hit max chaining steps, stopping after {counter} steps") + break + # Chain handlers + elif token_warning: + next_input_message = system.get_token_limit_warning() + continue # always chain + elif function_failed: + next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) + continue # always chain + elif heartbeat_request: + next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) + continue # always chain + # MemGPT no-op / yield + else: + break + + memgpt_agent.interface.step_yield() + printd(f"Finished agent step") + + def _command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + """Process a CLI command""" + from memgpt.utils import printd + + printd(f"Got command: {command}") + + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + + if command.lower() == "exit": + # exit not supported on server.py + raise ValueError(command) + + elif command.lower() == "save" or command.lower() == "savechat": + memgpt_agent.save() + + elif command.lower() == "attach": + # Different from CLI, we extract the data source name from the command + command = command.strip().split() + try: + data_source = int(command[1]) + except: + raise ValueError(command) + + # TODO: check if agent already has it + data_source_options = StorageConnector.list_loaded_data() + if len(data_source_options) == 0: + raise ValueError('No sources available. You must load a souce with "memgpt load ..." before running /attach.') + elif data_source not in data_source_options: + raise ValueError(f"Invalid data source name: {data_source} (options={data_source_options})") + else: + # attach new data + attach(memgpt_agent.config.name, data_source) + + # update agent config + memgpt_agent.config.attach_data_source(data_source) + + # reload agent with new data source + # TODO: maybe make this less ugly... + memgpt_agent.persistence_manager.archival_memory.storage = StorageConnector.get_storage_connector( + agent_config=memgpt_agent.config + ) + + elif command.lower() == "dump" or command.lower().startswith("dump "): + # Check if there's an additional argument that's an integer + command = command.strip().split() + amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 + if amount == 0: + memgpt_agent.interface.print_messages(memgpt_agent.messages, dump=True) + else: + memgpt_agent.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) + + elif command.lower() == "dumpraw": + memgpt_agent.interface.print_messages_raw(memgpt_agent.messages) + + elif command.lower() == "memory": + ret_str = ( + f"\nDumping memory contents:\n" + + f"\n{str(memgpt_agent.memory)}" + + f"\n{str(memgpt_agent.persistence_manager.archival_memory)}" + + f"\n{str(memgpt_agent.persistence_manager.recall_memory)}" + ) + return ret_str + + elif command.lower() == "pop" or command.lower().startswith("pop "): + # Check if there's an additional argument that's an integer + command = command.strip().split() + pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3 + n_messages = len(memgpt_agent.messages) + MIN_MESSAGES = 2 + if n_messages <= MIN_MESSAGES: + print(f"Agent only has {n_messages} messages in stack, none left to pop") + elif n_messages - pop_amount < MIN_MESSAGES: + print(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}") + else: + print(f"Popping last {pop_amount} messages from stack") + for _ in range(min(pop_amount, len(memgpt_agent.messages))): + memgpt_agent.messages.pop() + + elif command.lower() == "retry": + # TODO this needs to also modify the persistence manager + print(f"Retrying for another answer") + while len(memgpt_agent.messages) > 0: + if memgpt_agent.messages[-1].get("role") == "user": + # we want to pop up to the last user message and send it again + user_message = memgpt_agent.messages[-1].get("content") + memgpt_agent.messages.pop() + break + memgpt_agent.messages.pop() + + elif command.lower() == "rethink" or command.lower().startswith("rethink "): + # TODO this needs to also modify the persistence manager + if len(command) < len("rethink "): + print("Missing text after the command") + else: + for x in range(len(memgpt_agent.messages) - 1, 0, -1): + if memgpt_agent.messages[x].get("role") == "assistant": + text = command[len("rethink ") :].strip() + memgpt_agent.messages[x].update({"content": text}) + break + + elif command.lower() == "rewrite" or command.lower().startswith("rewrite "): + # TODO this needs to also modify the persistence manager + if len(command) < len("rewrite "): + print("Missing text after the command") + else: + for x in range(len(memgpt_agent.messages) - 1, 0, -1): + if memgpt_agent.messages[x].get("role") == "assistant": + text = command[len("rewrite ") :].strip() + args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments")) + args["message"] = text + memgpt_agent.messages[x].get("function_call").update({"arguments": json.dumps(args)}) + break + + # No skip options + elif command.lower() == "wipe": + # exit not supported on server.py + raise ValueError(command) + + elif command.lower() == "heartbeat": + input_message = system.get_heartbeat() + self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) + + elif command.lower() == "memorywarning": + input_message = system.get_token_limit_warning() + self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) + + def user_message(self, user_id: str, agent_id: str, message: str) -> None: + """Process an incoming user message and feed it through the MemGPT agent""" + from memgpt.utils import printd + + # Basic input sanitization + if not isinstance(message, str) or len(message) == 0: + raise ValueError(f"Invalid input: '{message}'") + + # If the input begins with a command prefix, reject + elif message.startswith("/"): + raise ValueError(f"Invalid input: '{message}'") + + # Else, process it as a user message to be fed to the agent + else: + # Package the user message first + packaged_user_message = package_user_message(user_message=message) + # Run the agent state forward + self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message) + + def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + """Run a command on the agent""" + # If the input begins with a command prefix, attempt to process it as a command + if command.startswith("/"): + if len(command) > 1: + command = command[1:] # strip the prefix + return self._command(user_id=user_id, agent_id=agent_id, command=command) + + def create_agent( + self, + user_id: str, + agent_config: Union[dict, AgentConfig], + interface: Union[AgentInterface, None] = None, + persistence_manager: Union[PersistenceManager, None] = None, + ) -> str: + """Create a new agent using a config""" + + # Initialize the agent based on the provided configuration + if isinstance(agent_config, dict): + agent_config = AgentConfig(**agent_config) + + if interface is None: + # interface = self.default_interface_cls() + interface = self.default_interface + + if persistence_manager is None: + persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config) + + # Create agent via preset from config + agent = presets.use_preset( + agent_config.preset, + agent_config, + agent_config.model, + utils.get_persona_text(agent_config.persona), + utils.get_human_text(agent_config.human), + interface, + persistence_manager, + ) + agent.save() + print(f"Created new agent from config: {agent}") + + return agent.config.name diff --git a/memgpt/server/websocket_client.py b/memgpt/server/websocket_client.py deleted file mode 100644 index 32020470..00000000 --- a/memgpt/server/websocket_client.py +++ /dev/null @@ -1,80 +0,0 @@ -import asyncio -import json - -import websockets - -import memgpt.server.websocket_protocol as protocol -from memgpt.server.websocket_server import WebSocketServer -from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT -from memgpt.server.utils import condition_to_stop_receiving, print_server_response - - -# CLEAN_RESPONSES = False # print the raw server responses (JSON) -CLEAN_RESPONSES = True # make the server responses cleaner - -# LOAD_AGENT = None # create a brand new agent -LOAD_AGENT = "agent_26" # load an existing agent - - -async def basic_cli_client(): - """Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets - - Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages) - """ - uri = f"ws://localhost:{DEFAULT_PORT}" - - async with websockets.connect(uri) as websocket: - if LOAD_AGENT is not None: - # Load existing agent - print("Sending load message to server...") - await websocket.send(protocol.client_command_load(LOAD_AGENT)) - - else: - # Initialize new agent - print("Sending config to server...") - example_config = { - "persona": "sam_pov", - "human": "cs_phd", - "model": "gpt-4-1106-preview", # gpt-4-turbo - } - await websocket.send(protocol.client_command_create(example_config)) - # Wait for the response - response = await websocket.recv() - response = json.loads(response) - print(f"Server response:\n{json.dumps(response, indent=2)}") - - await asyncio.sleep(1) - - while True: - user_input = input("\nEnter your message: ") - print("\n") - - # Send a message to the agent - await websocket.send(protocol.client_user_message(str(user_input))) - - # Wait for messages in a loop, since the server may send a few - while True: - try: - response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT) - response = json.loads(response) - - if CLEAN_RESPONSES: - print_server_response(response) - else: - print(f"Server response:\n{json.dumps(response, indent=2)}") - - # Check for a specific condition to break the loop - if condition_to_stop_receiving(response): - break - except asyncio.TimeoutError: - print("Timeout waiting for the server response.") - break - except websockets.exceptions.ConnectionClosedError: - print("Connection to server was lost.") - break - except Exception as e: - print(f"An error occurred: {e}") - break - - -asyncio.run(basic_cli_client()) diff --git a/memgpt/server/websocket_server.py b/memgpt/server/websocket_server.py deleted file mode 100644 index 4c92dc74..00000000 --- a/memgpt/server/websocket_server.py +++ /dev/null @@ -1,205 +0,0 @@ -import asyncio -import json -import traceback - -import websockets - -from memgpt.server.websocket_interface import SyncWebSocketInterface -from memgpt.server.constants import DEFAULT_PORT -import memgpt.server.websocket_protocol as protocol -import memgpt.system as system -import memgpt.constants as memgpt_constants - - -class WebSocketServer: - def __init__(self, host="localhost", port=DEFAULT_PORT): - self.host = host - self.port = port - self.interface = SyncWebSocketInterface() - - self.agent = None - self.agent_name = None - - def run_step(self, user_message, first_message=False, no_verify=False): - while True: - new_messages, heartbeat_request, function_failed, token_warning = self.agent.step( - user_message, first_message=first_message, skip_verify=no_verify - ) - - if token_warning: - user_message = system.get_token_limit_warning() - elif function_failed: - user_message = system.get_heartbeat(memgpt_constants.FUNC_FAILED_HEARTBEAT_MESSAGE) - elif heartbeat_request: - user_message = system.get_heartbeat(memgpt_constants.REQ_HEARTBEAT_MESSAGE) - else: - # return control - break - - async def handle_client(self, websocket, path): - self.interface.register_client(websocket) - try: - # async for message in websocket: - while True: - message = await websocket.recv() - - # Assuming the message is a JSON string - try: - data = json.loads(message) - except: - print(f"[server] bad data from client:\n{data}") - await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) - continue - - if "type" not in data: - print(f"[server] bad data from client (JSON but no type):\n{data}") - await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) - - elif data["type"] == "command": - # Create a new agent - if data["command"] == "create_agent": - try: - self.agent = self.create_new_agent(data["config"]) - await websocket.send(protocol.server_command_response("OK: Agent initialized")) - except Exception as e: - self.agent = None - print(f"[server] self.create_new_agent failed with:\n{e}") - print(f"{traceback.format_exc()}") - await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}")) - - # Load an existing agent - elif data["command"] == "load_agent": - agent_name = data.get("name") - if agent_name is not None: - try: - self.agent = self.load_agent(agent_name) - self.agent_name = agent_name - await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded")) - except Exception as e: - print(f"[server] self.load_agent failed with:\n{e}") - print(f"{traceback.format_exc()}") - self.agent = None - await websocket.send( - protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}") - ) - else: - await websocket.send(protocol.server_command_response(f"Error: 'name' not provided")) - - else: - print(f"[server] unrecognized client command type: {data}") - await websocket.send(protocol.server_error(f"unrecognized client command type: {data}")) - - elif data["type"] == "user_message": - user_message = data["message"] - - if "agent_name" in data: - agent_name = data["agent_name"] - # If the agent requested the same one that's already loading? - if self.agent_name is None or self.agent_name != data["agent_name"]: - try: - print(f"[server] loading agent {agent_name}") - self.agent = self.load_agent(agent_name) - self.agent_name = agent_name - # await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded")) - except Exception as e: - print(f"[server] self.load_agent failed with:\n{e}") - print(f"{traceback.format_exc()}") - self.agent = None - await websocket.send( - protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}") - ) - else: - await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request")) - continue - - if self.agent is None: - await websocket.send(protocol.server_agent_response_error("No agent has been initialized")) - else: - await websocket.send(protocol.server_agent_response_start()) - try: - self.run_step(user_message) - except Exception as e: - print(f"[server] self.run_step failed with:\n{e}") - print(f"{traceback.format_exc()}") - await websocket.send(protocol.server_agent_response_error(f"self.run_step failed with: {e}")) - - await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed - await websocket.send(protocol.server_agent_response_end()) - - # ... handle other message types as needed ... - else: - print(f"[server] unrecognized client package data type: {data}") - await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}")) - - except websockets.exceptions.ConnectionClosed: - print(f"[server] connection with client was closed") - finally: - # TODO autosave the agent - - self.interface.unregister_client(websocket) - - def create_new_agent(self, config): - """Config is json that arrived over websocket, so we need to turn it into a config object""" - from memgpt.config import AgentConfig - import memgpt.presets.presets as presets - import memgpt.utils as utils - from memgpt.persistence_manager import InMemoryStateManager - - print("Creating new agent...") - - # Initialize the agent based on the provided configuration - agent_config = AgentConfig(**config) - - # Use an in-state persistence manager - persistence_manager = InMemoryStateManager() - - # Create agent via preset from config - agent = presets.use_preset( - agent_config.preset, - agent_config, - agent_config.model, - utils.get_persona_text(agent_config.persona), - utils.get_human_text(agent_config.human), - self.interface, - persistence_manager, - ) - print("Created new agent from config") - - return agent - - def load_agent(self, agent_name): - """Load an agent from a directory""" - import memgpt.utils as utils - from memgpt.config import AgentConfig - from memgpt.agent import Agent - - print(f"Loading agent {agent_name}...") - - agent_files = utils.list_agent_config_files() - agent_names = [AgentConfig.load(f).name for f in agent_files] - - if agent_name not in agent_names: - raise ValueError(f"agent '{agent_name}' does not exist") - - agent_config = AgentConfig.load(agent_name) - agent = Agent.load_agent(self.interface, agent_config) - print("Created agent by loading existing config") - - return agent - - def initialize_server(self): - print("Server is initializing...") - print(f"Listening on {self.host}:{self.port}...") - - async def start_server(self): - self.initialize_server() - async with websockets.serve(self.handle_client, self.host, self.port): - await asyncio.Future() # Run forever - - def run(self): - return self.start_server() # Return the coroutine - - -if __name__ == "__main__": - server = WebSocketServer() - asyncio.run(server.run()) diff --git a/memgpt/server/ws_api/__init__.py b/memgpt/server/ws_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/server/ws_api/example_client.py b/memgpt/server/ws_api/example_client.py new file mode 100644 index 00000000..e8b4d393 --- /dev/null +++ b/memgpt/server/ws_api/example_client.py @@ -0,0 +1,106 @@ +import asyncio +import json + +import websockets + +import memgpt.server.ws_api.protocol as protocol +from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT +from memgpt.server.utils import condition_to_stop_receiving, print_server_response + + +# CLEAN_RESPONSES = False # print the raw server responses (JSON) +CLEAN_RESPONSES = True # make the server responses cleaner + +# LOAD_AGENT = None # create a brand new agent +AGENT_NAME = "agent_26" # load an existing agent +NEW_AGENT = False + +RECONNECT_DELAY = 1 +RECONNECT_MAX_TRIES = 5 + + +async def send_message_and_print_replies(websocket, user_message, agent_id): + """Send a message over websocket protocol and wait for the reply stream to end""" + # Send a message to the agent + await websocket.send(protocol.client_user_message(msg=str(user_message), agent_id=agent_id)) + + # Wait for messages in a loop, since the server may send a few + while True: + response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT) + response = json.loads(response) + + if CLEAN_RESPONSES: + print_server_response(response) + else: + print(f"Server response:\n{json.dumps(response, indent=2)}") + + # Check for a specific condition to break the loop + if condition_to_stop_receiving(response): + break + + +async def basic_cli_client(): + """Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets + + Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages) + """ + uri = f"ws://localhost:{DEFAULT_PORT}" + + closed_on_message = False + retry_attempts = 0 + while True: # Outer loop for reconnection attempts + try: + async with websockets.connect(uri) as websocket: + if NEW_AGENT: + # Initialize new agent + print("Sending config to server...") + example_config = { + "persona": "sam_pov", + "human": "cs_phd", + "model": "gpt-4-1106-preview", # gpt-4-turbo + } + await websocket.send(protocol.client_command_create(example_config)) + # Wait for the response + response = await websocket.recv() + response = json.loads(response) + print(f"Server response:\n{json.dumps(response, indent=2)}") + + await asyncio.sleep(1) + + while True: + if closed_on_message: + # If we're on a retry after a disconnect, don't ask for input again + closed_on_message = False + else: + user_input = input("\nEnter your message: ") + print("\n") + + # Send a message to the agent + try: + await send_message_and_print_replies(websocket=websocket, user_message=user_input, agent_id=AGENT_NAME) + retry_attempts = 0 + except websockets.exceptions.ConnectionClosedError: + print("Connection to server was lost. Attempting to reconnect...") + closed_on_message = True + raise + + except websockets.exceptions.ConnectionClosedError: + # Decide whether or not to retry the connection + if retry_attempts < RECONNECT_MAX_TRIES: + retry_attempts += 1 + await asyncio.sleep(RECONNECT_DELAY) # Wait for N seconds before reconnecting + continue + else: + print(f"Max attempts exceeded ({retry_attempts} > {RECONNECT_MAX_TRIES})") + break + + except asyncio.TimeoutError: + print("Timeout waiting for the server response.") + continue + + except Exception as e: + print(f"An error occurred: {e}") + continue + + +asyncio.run(basic_cli_client()) diff --git a/memgpt/server/websocket_interface.py b/memgpt/server/ws_api/interface.py similarity index 97% rename from memgpt/server/websocket_interface.py rename to memgpt/server/ws_api/interface.py index 932b902c..5e675a17 100644 --- a/memgpt/server/websocket_interface.py +++ b/memgpt/server/ws_api/interface.py @@ -3,7 +3,7 @@ import threading from memgpt.interface import AgentInterface -import memgpt.server.websocket_protocol as protocol +import memgpt.server.ws_api.protocol as protocol class BaseWebSocketInterface(AgentInterface): @@ -20,6 +20,9 @@ class BaseWebSocketInterface(AgentInterface): """Unregister a client connection""" self.clients.remove(websocket) + def step_yield(self): + pass + class AsyncWebSocketInterface(BaseWebSocketInterface): """WebSocket calls are async""" diff --git a/memgpt/server/websocket_protocol.py b/memgpt/server/ws_api/protocol.py similarity index 86% rename from memgpt/server/websocket_protocol.py rename to memgpt/server/ws_api/protocol.py index 8c8d3ecb..93a0a100 100644 --- a/memgpt/server/websocket_protocol.py +++ b/memgpt/server/ws_api/protocol.py @@ -80,12 +80,12 @@ def server_agent_function_message(msg): # Client -> server -def client_user_message(msg, agent_name=None): +def client_user_message(msg, agent_id=None): return json.dumps( { "type": "user_message", "message": msg, - "agent_name": agent_name, + "agent_id": agent_id, } ) @@ -98,13 +98,3 @@ def client_command_create(config): "config": config, } ) - - -def client_command_load(agent_name): - return json.dumps( - { - "type": "command", - "command": "load_agent", - "name": agent_name, - } - ) diff --git a/memgpt/server/ws_api/server.py b/memgpt/server/ws_api/server.py new file mode 100644 index 00000000..3cecc1c8 --- /dev/null +++ b/memgpt/server/ws_api/server.py @@ -0,0 +1,107 @@ +import asyncio +import json +import traceback + +import websockets + +from memgpt.server.server import SyncServer +from memgpt.server.ws_api.interface import SyncWebSocketInterface +from memgpt.server.constants import DEFAULT_PORT +import memgpt.server.ws_api.protocol as protocol +import memgpt.system as system +import memgpt.constants as memgpt_constants + + +class WebSocketServer: + def __init__(self, host="localhost", port=DEFAULT_PORT): + self.host = host + self.port = port + self.interface = SyncWebSocketInterface() + self.server = SyncServer(default_interface=self.interface) + + def __del__(self): + self.interface.close() + + def initialize_server(self): + print("Server is initializing...") + print(f"Listening on {self.host}:{self.port}...") + + async def start_server(self): + self.initialize_server() + # Can play with ping_interval and ping_timeout + # See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html + # and https://github.com/cpacker/MemGPT/issues/471 + async with websockets.serve(self.handle_client, self.host, self.port): + await asyncio.Future() # Run forever + + def run(self): + return self.start_server() # Return the coroutine + + async def handle_client(self, websocket, path): + self.interface.register_client(websocket) + try: + # async for message in websocket: + while True: + message = await websocket.recv() + + # Assuming the message is a JSON string + try: + data = json.loads(message) + except: + print(f"[server] bad data from client:\n{data}") + await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) + continue + + if "type" not in data: + print(f"[server] bad data from client (JSON but no type):\n{data}") + await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) + + elif data["type"] == "command": + # Create a new agent + if data["command"] == "create_agent": + try: + # self.agent = self.create_new_agent(data["config"]) + self.server.create_agent(user_id="NULL", agent_config=data["config"]) + await websocket.send(protocol.server_command_response("OK: Agent initialized")) + except Exception as e: + self.agent = None + print(f"[server] self.create_new_agent failed with:\n{e}") + print(f"{traceback.format_exc()}") + await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}")) + + else: + print(f"[server] unrecognized client command type: {data}") + await websocket.send(protocol.server_error(f"unrecognized client command type: {data}")) + + elif data["type"] == "user_message": + user_message = data["message"] + + if "agent_id" not in data or data["agent_id"] is None: + await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request")) + continue + + await websocket.send(protocol.server_agent_response_start()) + try: + # self.run_step(user_message) + self.server.user_message(user_id="NULL", agent_id=data["agent_id"], message=user_message) + except Exception as e: + print(f"[server] self.server.user_message failed with:\n{e}") + print(f"{traceback.format_exc()}") + await websocket.send(protocol.server_agent_response_error(f"server.user_message failed with: {e}")) + await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed + await websocket.send(protocol.server_agent_response_end()) + + # ... handle other message types as needed ... + else: + print(f"[server] unrecognized client package data type: {data}") + await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}")) + + except websockets.exceptions.ConnectionClosed: + print(f"[server] connection with client was closed") + finally: + self.interface.unregister_client(websocket) + + +if __name__ == "__main__": + server = WebSocketServer() + asyncio.run(server.run()) diff --git a/poetry.lock b/poetry.lock index 2e54e414..6bd13a33 100644 --- a/poetry.lock +++ b/poetry.lock @@ -137,24 +137,24 @@ files = [ [[package]] name = "anyio" -version = "4.1.0" +version = "3.7.1" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"}, - {file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"}, + {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"}, + {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"}, ] [package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.23)"] +doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"] +test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (<0.22)"] [[package]] name = "asgiref" @@ -737,19 +737,20 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.103.0" +version = "0.104.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "fastapi-0.103.0-py3-none-any.whl", hash = "sha256:61ab72c6c281205dd0cbaccf503e829a37e0be108d965ac223779a8479243665"}, - {file = "fastapi-0.103.0.tar.gz", hash = "sha256:4166732f5ddf61c33e9fa4664f73780872511e0598d4d5434b1816dc1e6d9421"}, + {file = "fastapi-0.104.1-py3-none-any.whl", hash = "sha256:752dc31160cdbd0436bb93bad51560b57e525cbb1d4bbf6f4904ceee75548241"}, + {file = "fastapi-0.104.1.tar.gz", hash = "sha256:e5e4540a7c5e1dcfbbcf5b903c234feddcdcd881f191977a1c5dfd917487e7ae"}, ] [package.dependencies] +anyio = ">=3.7.1,<4.0.0" pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" starlette = ">=0.27.0,<0.28.0" -typing-extensions = ">=4.5.0" +typing-extensions = ">=4.8.0" [package.extras] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] @@ -3074,6 +3075,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.2" +description = "Pytest support for asyncio" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"}, + {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-box" version = "7.1.1" @@ -4774,11 +4793,12 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] -dev = ["black", "datasets", "pre-commit", "pytest"] +dev = ["black", "datasets", "pre-commit", "pytest", "pytest-asyncio"] local = ["huggingface-hub", "torch", "transformers"] postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"] +server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "2d68f2515a73a9b2cafb445138c667f61153ac6feb23c124032fc2c2d56baf4a" \ No newline at end of file +content-hash = "4f675213d5a79f001bfb7441c9fba23ae114079ec61a30b0c88833c5427f152e" diff --git a/pyproject.toml b/pyproject.toml index 205068d5..00ac4d9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ pytz = "^2023.3.post1" tqdm = "^4.66.1" black = { version = "^23.10.1", optional = true } pytest = { version = "^7.4.3", optional = true } -llama-index = "0.9.13" +llama-index = "^0.9.13" setuptools = "^68.2.2" datasets = { version = "^2.14.6", optional = true} prettytable = "^3.9.0" @@ -39,7 +39,7 @@ transformers = { version = "4.34.1", optional = true } pre-commit = {version = "^3.5.0", optional = true } pg8000 = {version = "^1.30.3", optional = true} torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true} -websockets = "^12.0" +websockets = {version = "^12.0", optional = true} docstring-parser = "^0.15" lancedb = "^0.3.3" httpx = "^0.25.2" @@ -49,12 +49,17 @@ tiktoken = "^0.5.1" python-box = "^7.1.1" pypdf = "^3.17.1" pyyaml = "^6.0.1" +fastapi = {version = "^0.104.1", optional = true} +uvicorn = {version = "^0.24.0.post1", optional = true} chromadb = "^0.4.18" +pytest-asyncio = {version = "^0.23.2", optional = true} +pydantic = "^2.5.2" [tool.poetry.extras] local = ["torch", "huggingface-hub", "transformers"] postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"] -dev = ["pytest", "black", "pre-commit", "datasets"] +dev = ["pytest", "pytest-asyncio", "black", "pre-commit", "datasets"] +server = ["websockets", "fastapi", "uvicorn"] [build-system] requires = ["poetry-core"] diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 00000000..66732779 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,43 @@ +import memgpt.utils as utils + +utils.DEBUG = True +from memgpt.server.server import SyncServer + + +def test_server(): + user_id = "NULL" + agent_id = "agent_26" + + server = SyncServer() + + try: + server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") + except ValueError as e: + print(e) + except: + raise + + try: + server.user_message(user_id=user_id, agent_id=agent_id, message="/memory") + except ValueError as e: + print(e) + except: + raise + + try: + print(server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")) + except ValueError as e: + print(e) + except: + raise + + try: + server.user_message(user_id=user_id, agent_id="agent no-exist", message="Hello?") + except ValueError as e: + print(e) + except: + raise + + +if __name__ == "__main__": + test_server() diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py index e966cb05..e9971baa 100644 --- a/tests/test_websocket_interface.py +++ b/tests/test_websocket_interface.py @@ -1,9 +1,10 @@ +import os import pytest from unittest.mock import Mock, AsyncMock, MagicMock from memgpt.config import MemGPTConfig, AgentConfig -from memgpt.server.websocket_interface import SyncWebSocketInterface -import memgpt.presets as presets +from memgpt.server.ws_api.interface import SyncWebSocketInterface +import memgpt.presets.presets as presets import memgpt.utils as utils import memgpt.system as system from memgpt.persistence_manager import LocalStateManager @@ -54,19 +55,32 @@ async def test_websockets(): ws_interface.register_client(mock_websocket) # Create an agent and hook it up to the WebSocket interface - config = MemGPTConfig() + api_key = os.getenv("OPENAI_API_KEY") + if api_key is None: + ws_interface.close() + return + config = MemGPTConfig.load() + if config.openai_key is None: + config.openai_key = api_key + config.save() # Mock the persistence manager # create agents with defaults - agent_config = AgentConfig(persona="sam_pov", human="basic", model="gpt-4-1106-preview") + agent_config = AgentConfig( + persona="sam_pov", + human="basic", + model="gpt-4-1106-preview", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + ) persistence_manager = LocalStateManager(agent_config=agent_config) memgpt_agent = presets.use_preset( - presets.DEFAULT_PRESET, - config, # no agent config to provide - "gpt-4-1106-preview", - utils.get_persona_text("sam_pov"), - utils.get_human_text("basic"), + agent_config.preset, + agent_config, + agent_config.model, + agent_config.persona, # note: extracting the raw text, not pulling from a file + agent_config.human, # note: extracting raw text, not pulling from a file ws_interface, persistence_manager, ) diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py index c8da0186..a2583f86 100644 --- a/tests/test_websocket_server.py +++ b/tests/test_websocket_server.py @@ -5,7 +5,7 @@ import websockets import pytest from memgpt.server.constants import DEFAULT_PORT -from memgpt.server.websocket_server import WebSocketServer +from memgpt.server.ws_api.server import WebSocketServer from memgpt.config import AgentConfig @@ -16,7 +16,9 @@ async def test_dummy(): @pytest.mark.asyncio async def test_websocket_server(): - server = WebSocketServer() + # host = "127.0.0.1" + host = "localhost" + server = WebSocketServer(host=host) server_task = asyncio.create_task(server.run()) # Create a task for the server # the agent config we want to ask the server to instantiate with @@ -28,23 +30,26 @@ async def test_websocket_server(): # ) test_config = {} - uri = f"ws://localhost:{DEFAULT_PORT}" - async with websockets.connect(uri) as websocket: - # Initialize the server with a test config - print("Sending config to server...") - await websocket.send(json.dumps({"type": "initialize", "config": test_config})) - # Wait for the response - response = await websocket.recv() - print(f"Response from the agent: {response}") + uri = f"ws://{host}:{DEFAULT_PORT}" + try: + async with websockets.connect(uri) as websocket: + # Initialize the server with a test config + print("Sending config to server...") + await websocket.send(json.dumps({"type": "initialize", "config": test_config})) + # Wait for the response + response = await websocket.recv() + print(f"Response from the agent: {response}") - await asyncio.sleep(1) # just in case + await asyncio.sleep(1) # just in case - # Send a message to the agent - print("Sending message to server...") - await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"})) - # Wait for the response - # NOTE: we should be waiting for multiple responses - response = await websocket.recv() - print(f"Response from the agent: {response}") - - server_task.cancel() # Cancel the server task after the test + # Send a message to the agent + print("Sending message to server...") + await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"})) + # Wait for the response + # NOTE: we should be waiting for multiple responses + response = await websocket.recv() + print(f"Response from the agent: {response}") + except (OSError, ConnectionRefusedError) as e: + print(f"Was unable to connect: {e}") + finally: + server_task.cancel() # Cancel the server task after the test