added memgpt server command (#611)
* added memgpt server command * added the option to specify a port (rest default 8283, ws default 8282) * fixed import in test * added agent saving on shutdown * added basic locking mechanism (assumes only one server.py is running at the same time) * remove 'STOP' from buffer when converting to list for the non-streaming POST resposne * removed duplicate on_event (redundant to lifespan) * added GET agents/memory route * added GET agent config * added GET server config * added PUT route for modifying agent core memory * refactored to put server loop in separate function called via main
This commit is contained in:
@@ -4,6 +4,10 @@ import sys
|
||||
import io
|
||||
import logging
|
||||
import questionary
|
||||
from pathlib import Path
|
||||
import os
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
|
||||
from llama_index import set_global_service_context
|
||||
from llama_index import ServiceContext
|
||||
@@ -18,6 +22,76 @@ from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
|
||||
|
||||
|
||||
class ServerChoice(Enum):
|
||||
rest_api = "rest"
|
||||
ws_api = "websocket"
|
||||
|
||||
|
||||
def server(
|
||||
type: ServerChoice = typer.Option("rest", help="Server to run"), port: int = typer.Option(None, help="Port to run the server on")
|
||||
):
|
||||
"""Launch a MemGPT server process"""
|
||||
|
||||
if type == ServerChoice.rest_api:
|
||||
if port is None:
|
||||
port = REST_DEFAULT_PORT
|
||||
|
||||
# Change to the desired directory
|
||||
script_path = Path(__file__).resolve()
|
||||
script_dir = script_path.parent
|
||||
|
||||
server_directory = os.path.join(script_dir.parent, "server", "rest_api")
|
||||
command = f"uvicorn server:app --reload --port {port}"
|
||||
|
||||
# Run the command
|
||||
print(f"Running REST server: {command} (inside {server_directory})")
|
||||
|
||||
try:
|
||||
# Start the subprocess in a new session
|
||||
process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory)
|
||||
process.wait()
|
||||
except KeyboardInterrupt:
|
||||
# Handle CTRL-C
|
||||
print("Terminating the server...")
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
print("Server terminated with kill()")
|
||||
sys.exit(0)
|
||||
|
||||
elif type == ServerChoice.ws_api:
|
||||
if port is None:
|
||||
port = WS_DEFAULT_PORT
|
||||
|
||||
# Change to the desired directory
|
||||
script_path = Path(__file__).resolve()
|
||||
script_dir = script_path.parent
|
||||
|
||||
server_directory = os.path.join(script_dir.parent, "server", "ws_api")
|
||||
command = f"python server.py {port}"
|
||||
|
||||
# Run the command
|
||||
print(f"Running WS (websockets) server: {command} (inside {server_directory})")
|
||||
|
||||
try:
|
||||
# Start the subprocess in a new session
|
||||
process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory)
|
||||
process.wait()
|
||||
except KeyboardInterrupt:
|
||||
# Handle CTRL-C
|
||||
print("Terminating the server...")
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
print("Server terminated with kill()")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def run(
|
||||
|
||||
@@ -21,7 +21,7 @@ from memgpt.interface import CLIInterface as interface # for printing to termin
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as constants
|
||||
from memgpt.cli.cli import run, attach, version
|
||||
from memgpt.cli.cli import run, attach, version, server
|
||||
from memgpt.cli.cli_config import configure, list, add
|
||||
from memgpt.cli.cli_load import app as load_app
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
@@ -33,6 +33,7 @@ app.command(name="attach")(attach)
|
||||
app.command(name="configure")(configure)
|
||||
app.command(name="list")(list)
|
||||
app.command(name="add")(add)
|
||||
app.command(name="server")(server)
|
||||
# load data commands
|
||||
app.add_typer(load_app, name="load")
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
DEFAULT_PORT = 8282
|
||||
# WebSockets
|
||||
WS_DEFAULT_PORT = 8282
|
||||
WS_CLIENT_TIMEOUT = 30
|
||||
|
||||
CLIENT_TIMEOUT = 30
|
||||
# REST
|
||||
REST_DEFAULT_PORT = 8283
|
||||
|
||||
@@ -18,6 +18,8 @@ class QueuingInterface(AgentInterface):
|
||||
items.append(self.buffer.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
if len(items) > 1 and items[-1] == "STOP":
|
||||
items.pop()
|
||||
return items
|
||||
|
||||
def clear(self):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
@@ -38,17 +39,64 @@ class Command(BaseModel):
|
||||
command: str
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
interface = QueuingInterface()
|
||||
server = SyncServer(default_interface=interface)
|
||||
class CoreMemory(BaseModel):
|
||||
user_id: str
|
||||
agent_id: str
|
||||
human: str | None = None
|
||||
persona: str | None = None
|
||||
|
||||
|
||||
server = None
|
||||
interface = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(application: FastAPI):
|
||||
global server
|
||||
global interface
|
||||
interface = QueuingInterface()
|
||||
server = SyncServer(default_interface=interface)
|
||||
yield
|
||||
server.save_agents()
|
||||
server = None
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# app = FastAPI()
|
||||
# server = SyncServer(default_interface=interface)
|
||||
|
||||
|
||||
# server.list_agents
|
||||
@app.get("/agents")
|
||||
def list_agents(user_id: str):
|
||||
interface.clear()
|
||||
agents_list = utils.list_agent_config_files()
|
||||
return {"num_agents": len(agents_list), "agent_names": agents_list}
|
||||
return server.list_agents(user_id=user_id)
|
||||
|
||||
|
||||
@app.get("/agents/memory")
|
||||
def get_agent_memory(user_id: str, agent_id: str):
|
||||
interface.clear()
|
||||
return server.get_agent_memory(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
|
||||
@app.put("/agents/memory")
|
||||
def get_agent_memory(body: CoreMemory):
|
||||
interface.clear()
|
||||
new_memory_contents = {"persona": body.persona, "human": body.human}
|
||||
return server.update_agent_core_memory(user_id=body.user_id, agent_id=body.agent_id, new_memory_contents=new_memory_contents)
|
||||
|
||||
|
||||
@app.get("/agents/config")
|
||||
def get_agent_config(user_id: str, agent_id: str):
|
||||
interface.clear()
|
||||
return server.get_agent_config(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
def get_server_config(user_id: str):
|
||||
interface.clear()
|
||||
return server.get_server_config(user_id=user_id)
|
||||
|
||||
|
||||
# server.create_agent
|
||||
@@ -88,6 +136,8 @@ async def user_message(body: UserMessage):
|
||||
|
||||
# Return the streaming response using the generator
|
||||
return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@@ -95,6 +145,8 @@ async def user_message(body: UserMessage):
|
||||
interface.clear()
|
||||
try:
|
||||
server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return {"messages": interface.to_list()}
|
||||
@@ -106,6 +158,8 @@ def run_command(body: Command):
|
||||
interface.clear()
|
||||
try:
|
||||
response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
from typing import Union, Callable
|
||||
import json
|
||||
from threading import Lock
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
|
||||
from memgpt.system import package_user_message
|
||||
from memgpt.config import AgentConfig
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.agent import Agent
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as constants
|
||||
@@ -11,6 +14,7 @@ from memgpt.cli.cli import attach
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.utils as utils
|
||||
import memgpt.server.utils as server_utils
|
||||
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
|
||||
|
||||
# TODO use custom interface
|
||||
@@ -22,10 +26,30 @@ class Server(object):
|
||||
"""Abstract server class that supports multi-agent multi-user"""
|
||||
|
||||
@abstractmethod
|
||||
def list_agents(self, user_id: str, agent_id: str) -> str:
|
||||
def list_agents(self, user_id: str) -> dict:
|
||||
"""List all available agents to a user"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
|
||||
"""Return the memory of an agent (core memory + non-core statistics)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
|
||||
"""Return the config of an agent"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_server_config(self, user_id: str) -> dict:
|
||||
"""Return the base config"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_agent(
|
||||
self,
|
||||
@@ -51,8 +75,50 @@ class Server(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LockingServer(Server):
|
||||
"""Basic support for concurrency protections (all requests that modify an agent lock the agent until the operation is complete)"""
|
||||
|
||||
# Locks for each agent
|
||||
_agent_locks = {}
|
||||
|
||||
@staticmethod
|
||||
def agent_lock_decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
def wrapper(self, user_id: str, agent_id: str, *args, **kwargs):
|
||||
# print("Locking check")
|
||||
|
||||
# Initialize the lock for the agent_id if it doesn't exist
|
||||
if agent_id not in self._agent_locks:
|
||||
# print(f"Creating lock for agent_id = {agent_id}")
|
||||
self._agent_locks[agent_id] = Lock()
|
||||
|
||||
# Check if the agent is currently locked
|
||||
if not self._agent_locks[agent_id].acquire(blocking=False):
|
||||
# print(f"agent_id = {agent_id} is busy")
|
||||
raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.")
|
||||
|
||||
try:
|
||||
# Execute the function
|
||||
# print(f"running function on agent_id = {agent_id}")
|
||||
return func(self, user_id, agent_id, *args, **kwargs)
|
||||
finally:
|
||||
# Release the lock
|
||||
# print(f"releasing lock on agent_id = {agent_id}")
|
||||
self._agent_locks[agent_id].release()
|
||||
|
||||
return wrapper
|
||||
|
||||
@agent_lock_decorator
|
||||
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@agent_lock_decorator
|
||||
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO actually use "user_id" for something
|
||||
class SyncServer(Server):
|
||||
class SyncServer(LockingServer):
|
||||
"""Simple single-threaded / blocking server process"""
|
||||
|
||||
def __init__(
|
||||
@@ -82,6 +148,14 @@ class SyncServer(Server):
|
||||
# The default persistence manager that will get assigned to agents ON CREATION
|
||||
self.default_persistence_manager_cls = default_persistence_manager_cls
|
||||
|
||||
def save_agents(self):
|
||||
for agent_d in self.active_agents:
|
||||
try:
|
||||
agent_d["agent"].save()
|
||||
print(f"Saved agent {agent_d['agent_id']}")
|
||||
except Exception as e:
|
||||
print(f"Error occured while trying to save agent {agent_d['agent_id']}:\n{e}")
|
||||
|
||||
def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
|
||||
"""Get the agent object from the in-memory object store"""
|
||||
for d in self.active_agents:
|
||||
@@ -302,6 +376,7 @@ class SyncServer(Server):
|
||||
input_message = system.get_token_limit_warning()
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
|
||||
|
||||
@LockingServer.agent_lock_decorator
|
||||
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
|
||||
"""Process an incoming user message and feed it through the MemGPT agent"""
|
||||
from memgpt.utils import printd
|
||||
@@ -321,6 +396,7 @@ class SyncServer(Server):
|
||||
# Run the agent state forward
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
|
||||
|
||||
@LockingServer.agent_lock_decorator
|
||||
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
||||
"""Run a command on the agent"""
|
||||
# If the input begins with a command prefix, attempt to process it as a command
|
||||
@@ -363,3 +439,83 @@ class SyncServer(Server):
|
||||
print(f"Created new agent from config: {agent}")
|
||||
|
||||
return agent.config.name
|
||||
|
||||
def list_agents(self, user_id: str) -> dict:
|
||||
"""List all available agents to a user"""
|
||||
agents_list = utils.list_agent_config_files()
|
||||
return {"num_agents": len(agents_list), "agent_names": agents_list}
|
||||
|
||||
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
|
||||
"""Return the memory of an agent (core memory + non-core statistics)"""
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
core_memory = memgpt_agent.memory
|
||||
recall_memory = memgpt_agent.persistence_manager.recall_memory
|
||||
archival_memory = memgpt_agent.persistence_manager.archival_memory
|
||||
|
||||
memory_obj = {
|
||||
"core_memory": {
|
||||
"persona": core_memory.persona,
|
||||
"human": core_memory.human,
|
||||
},
|
||||
"recall_memory": len(recall_memory) if recall_memory is not None else None,
|
||||
"archival_memory": len(archival_memory) if archival_memory is not None else None,
|
||||
}
|
||||
|
||||
return memory_obj
|
||||
|
||||
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
|
||||
"""Return the config of an agent"""
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
agent_config = vars(memgpt_agent.config)
|
||||
|
||||
return agent_config
|
||||
|
||||
def get_server_config(self, user_id: str) -> dict:
|
||||
"""Return the base config"""
|
||||
base_config = vars(MemGPTConfig.load())
|
||||
|
||||
def clean_keys(config):
|
||||
config_copy = config.copy()
|
||||
for k, v in config.items():
|
||||
if k == "key" or "_key" in k:
|
||||
config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
|
||||
return config_copy
|
||||
|
||||
clean_base_config = clean_keys(base_config)
|
||||
return clean_base_config
|
||||
|
||||
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
old_core_memory = self.get_agent_memory(user_id=user_id, agent_id=agent_id)["core_memory"]
|
||||
new_core_memory = old_core_memory.copy()
|
||||
|
||||
modified = False
|
||||
if "persona" in new_memory_contents and new_memory_contents["persona"] is not None:
|
||||
new_persona = new_memory_contents["persona"]
|
||||
if old_core_memory["persona"] != new_persona:
|
||||
new_core_memory["persona"] = new_persona
|
||||
memgpt_agent.memory.edit_persona(new_persona)
|
||||
modified = True
|
||||
|
||||
elif "human" in new_memory_contents and new_memory_contents["human"] is not None:
|
||||
new_human = new_memory_contents["human"]
|
||||
if old_core_memory["human"] != new_human:
|
||||
new_core_memory["human"] = new_human
|
||||
memgpt_agent.memory.edit_human(new_human)
|
||||
modified = True
|
||||
|
||||
# If we modified the memory contents, we need to rebuild the memory block inside the system message
|
||||
if modified:
|
||||
memgpt_agent.rebuild_memory()
|
||||
|
||||
return {
|
||||
"old_core_memory": old_core_memory,
|
||||
"new_core_memory": new_core_memory,
|
||||
"modified": modified,
|
||||
}
|
||||
|
||||
@@ -24,3 +24,21 @@ def print_server_response(response):
|
||||
print(response)
|
||||
else:
|
||||
print(response)
|
||||
|
||||
|
||||
def shorten_key_middle(key_string, chars_each_side=3):
|
||||
"""
|
||||
Shortens a key string by showing a specified number of characters on each side and adding an ellipsis in the middle.
|
||||
|
||||
Args:
|
||||
key_string (str): The key string to be shortened.
|
||||
chars_each_side (int): The number of characters to show on each side of the ellipsis.
|
||||
|
||||
Returns:
|
||||
str: The shortened key string with an ellipsis in the middle.
|
||||
"""
|
||||
key_length = len(key_string)
|
||||
if key_length <= 2 * chars_each_side:
|
||||
return "..." # Return ellipsis if the key is too short
|
||||
else:
|
||||
return key_string[:chars_each_side] + "..." + key_string[-chars_each_side:]
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import websockets
|
||||
|
||||
import memgpt.server.ws_api.protocol as protocol
|
||||
from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT, WS_CLIENT_TIMEOUT
|
||||
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ async def send_message_and_print_replies(websocket, user_message, agent_id):
|
||||
|
||||
# Wait for messages in a loop, since the server may send a few
|
||||
while True:
|
||||
response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT)
|
||||
response = await asyncio.wait_for(websocket.recv(), WS_CLIENT_TIMEOUT)
|
||||
response = json.loads(response)
|
||||
|
||||
if CLEAN_RESPONSES:
|
||||
@@ -44,7 +44,7 @@ async def basic_cli_client():
|
||||
|
||||
Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages)
|
||||
"""
|
||||
uri = f"ws://localhost:{DEFAULT_PORT}"
|
||||
uri = f"ws://localhost:{WS_DEFAULT_PORT}"
|
||||
|
||||
closed_on_message = False
|
||||
retry_attempts = 0
|
||||
|
||||
@@ -1,26 +1,37 @@
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import websockets
|
||||
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.ws_api.interface import SyncWebSocketInterface
|
||||
from memgpt.server.constants import DEFAULT_PORT
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT
|
||||
import memgpt.server.ws_api.protocol as protocol
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as memgpt_constants
|
||||
|
||||
|
||||
class WebSocketServer:
|
||||
def __init__(self, host="localhost", port=DEFAULT_PORT):
|
||||
def __init__(self, host="localhost", port=WS_DEFAULT_PORT):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.interface = SyncWebSocketInterface()
|
||||
self.server = SyncServer(default_interface=self.interface)
|
||||
|
||||
def __del__(self):
|
||||
self.interface.close()
|
||||
def shutdown_server(self):
|
||||
try:
|
||||
self.server.save_agents()
|
||||
print(f"Saved agents")
|
||||
except Exception as e:
|
||||
print(f"Saving agents failed with: {e}")
|
||||
try:
|
||||
self.interface.close()
|
||||
print(f"Closed the WS interface")
|
||||
except Exception as e:
|
||||
print(f"Closing the WS interface failed with: {e}")
|
||||
|
||||
def initialize_server(self):
|
||||
print("Server is initializing...")
|
||||
@@ -102,6 +113,36 @@ class WebSocketServer:
|
||||
self.interface.unregister_client(websocket)
|
||||
|
||||
|
||||
def start_server():
|
||||
# Check if a port argument is provided
|
||||
port = WS_DEFAULT_PORT
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
port = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print(f"Invalid port number. Using default port {port}.")
|
||||
|
||||
server = WebSocketServer(port=port)
|
||||
|
||||
def handle_sigterm(*args):
|
||||
# Perform necessary cleanup
|
||||
print("SIGTERM received, shutting down...")
|
||||
# Note: This should be quick and not involve asynchronous calls
|
||||
print("Shutting down the server...")
|
||||
server.shutdown_server()
|
||||
print("Server has been shut down.")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, handle_sigterm)
|
||||
|
||||
try:
|
||||
asyncio.run(server.run())
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down the server...")
|
||||
finally:
|
||||
server.shutdown_server()
|
||||
print("Server has been shut down.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
server = WebSocketServer()
|
||||
asyncio.run(server.run())
|
||||
start_server()
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import websockets
|
||||
import pytest
|
||||
|
||||
from memgpt.server.constants import DEFAULT_PORT
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT
|
||||
from memgpt.server.ws_api.server import WebSocketServer
|
||||
from memgpt.config import AgentConfig
|
||||
|
||||
@@ -30,7 +30,7 @@ async def test_websocket_server():
|
||||
# )
|
||||
test_config = {}
|
||||
|
||||
uri = f"ws://{host}:{DEFAULT_PORT}"
|
||||
uri = f"ws://{host}:{WS_DEFAULT_PORT}"
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Initialize the server with a test config
|
||||
|
||||
Reference in New Issue
Block a user