API server refactor + REST API (#593)

* init server refactor

* refactored websockets server/client code to use internal server API

* added intentional fail on test

* update workflow to try and get test to pass remotely

* refactor to put websocket code in a separate subdirectory

* added fastapi rest server

* add error handling

* modified interface return style

* disabled certain tests on remote

* added SSE response option for user_message

* fix ws interface test

* fallback for oai key

* add soft fail for test when localhost is unavailable

* add step_yield for all server related interfaces

* extra catch

* update toml + lock with server add-ons (add uvicorn+fastapi, move websockets to server extra)

* regen lock file

* added pytest-asyncio as an extra in dev

* add pydantic to deps

* renamed CreateConfig to CreateAgentConfig

* fixed POST request for creating agent + tested it
This commit is contained in:
Charles Packer
2023-12-11 15:08:42 -08:00
committed by GitHub
parent 033d9d61f4
commit b7427e2de7
17 changed files with 906 additions and 346 deletions

View File

@@ -42,7 +42,7 @@ jobs:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry install -E dev -E postgres -E local
poetry install -E dev -E postgres -E local -E server
- name: Set Poetry config
env:

View File

View File

@@ -0,0 +1,75 @@
import asyncio
import queue
from memgpt.interface import AgentInterface
class QueuingInterface(AgentInterface):
    """Messages are queued inside an internal buffer and manually flushed"""

    def __init__(self):
        # FIFO buffer holding message dicts produced during an agent step
        self.buffer = queue.Queue()

    def to_list(self):
        """Convert queue to a list (empties it out at the same time)"""
        drained = []
        try:
            while True:
                drained.append(self.buffer.get_nowait())
        except queue.Empty:
            pass
        return drained

    def clear(self):
        """Clear all messages from the queue."""
        with self.buffer.mutex:
            # Empty the queue in place
            self.buffer.queue.clear()

    async def message_generator(self):
        """Async generator that yields queued messages until a STOP marker arrives."""
        while True:
            if self.buffer.empty():
                await asyncio.sleep(0.1)  # Small sleep to prevent a busy loop
            else:
                item = self.buffer.get()
                if item == "STOP":
                    break
                yield item

    def step_yield(self):
        """Enqueue a special stop message"""
        self.buffer.put("STOP")

    def user_message(self, msg: str):
        """Handle reception of a user message"""
        pass

    def internal_monologue(self, msg: str) -> None:
        """Handle the agent's internal monologue"""
        print(msg)
        self.buffer.put({"internal_monologue": msg})

    def assistant_message(self, msg: str) -> None:
        """Handle the agent sending a message"""
        print(msg)
        self.buffer.put({"assistant_message": msg})

    def function_message(self, msg: str) -> None:
        """Handle the agent calling a function"""
        print(msg)
        if msg.startswith("Running "):
            self.buffer.put({"function_call": msg.replace("Running ", "")})
        elif msg.startswith("Success: "):
            self.buffer.put({"function_return": msg.replace("Success: ", ""), "status": "success"})
        elif msg.startswith("Error: "):
            self.buffer.put({"function_return": msg.replace("Error: ", ""), "status": "error"})
        else:
            # NOTE: generic, should not happen
            self.buffer.put({"function_message": msg})

View File

@@ -0,0 +1,112 @@
import asyncio
import json
from typing import Union
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.interface import QueuingInterface
import memgpt.utils as utils
"""
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
Start the server with:
cd memgpt/server/rest_api
poetry run uvicorn server:app --reload
"""
class CreateAgentConfig(BaseModel):
    """Request body for POST /agents: create a new agent from a raw config dict."""

    # ID of the requesting user
    user_id: str
    # Raw agent configuration; converted to an AgentConfig by SyncServer.create_agent
    config: dict
class UserMessage(BaseModel):
    """Request body for POST /agents/message: a user message addressed to one agent."""

    user_id: str
    agent_id: str
    message: str
    # When true, agent output is streamed back as Server-Sent Events
    stream: bool = False
class Command(BaseModel):
    """Request body for POST /agents/command: a CLI-style command (e.g. /memory)."""

    user_id: str
    agent_id: str
    command: str
# The REST layer is a thin wrapper around one in-process SyncServer; all agent
# output is captured through a single shared QueuingInterface buffer.
app = FastAPI()
interface = QueuingInterface()
server = SyncServer(default_interface=interface)
# server.list_agents
@app.get("/agents")
def list_agents(user_id: str):
    """Return the count and names of all agents configured on this host."""
    interface.clear()
    available = utils.list_agent_config_files()
    return {"num_agents": len(available), "agent_names": available}
# server.create_agent
@app.post("/agents")
def create_agents(body: CreateAgentConfig):
    """Create a new agent from the posted config and return its ID."""
    interface.clear()
    try:
        new_agent_id = server.create_agent(user_id=body.user_id, agent_config=body.config)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"agent_id": new_agent_id}
# server.user_message
@app.post("/agents/message")
async def user_message(body: UserMessage):
    """Send a user message to an agent.

    If ``body.stream`` is true, the agent step is kicked off in the background
    and its output is streamed back as Server-Sent Events; otherwise the step
    runs to completion and the buffered messages are returned as JSON.
    """
    if body.stream:
        # For streaming response
        try:
            # BUG FIX: drop any messages left over from a previous request so
            # the stream only contains output generated for *this* message
            # (the non-streaming branch already did this).
            interface.clear()

            # Start the generation process without blocking the event loop
            if asyncio.iscoroutinefunction(server.user_message):
                # Start the async task
                asyncio.create_task(server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message))
            else:
                # Run the synchronous function in a thread pool
                # (get_running_loop is the non-deprecated call inside a coroutine)
                loop = asyncio.get_running_loop()
                loop.run_in_executor(None, server.user_message, body.user_id, body.agent_id, body.message)

            async def formatted_message_generator():
                # Re-emit queued interface messages as SSE "data:" frames
                async for message in interface.message_generator():
                    formatted_message = f"data: {json.dumps(message)}\n\n"
                    yield formatted_message
                    await asyncio.sleep(1)

            # Return the streaming response using the generator
            return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
    else:
        interface.clear()
        try:
            server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return {"messages": interface.to_list()}
# server.run_command
@app.post("/agents/command")
def run_command(body: Command):
    """Run a CLI-style command (e.g. /memory) on an agent and return its output."""
    interface.clear()
    try:
        response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    # BUG FIX: the command was previously executed a second time outside the
    # try block, re-running its side effects and bypassing the error handler.
    return {"response": response}

365
memgpt/server/server.py Normal file
View File

@@ -0,0 +1,365 @@
from abc import abstractmethod
from typing import Union
import json
from memgpt.system import package_user_message
from memgpt.config import AgentConfig
from memgpt.agent import Agent
import memgpt.system as system
import memgpt.constants as constants
from memgpt.cli.cli import attach
from memgpt.connectors.storage import StorageConnector
import memgpt.presets.presets as presets
import memgpt.utils as utils
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
# TODO use custom interface
from memgpt.interface import CLIInterface # for printing to terminal
from memgpt.interface import AgentInterface # abstract
class Server(object):
    """Abstract server class that supports multi-agent multi-user"""

    @abstractmethod
    def list_agents(self, user_id: str, agent_id: str) -> str:
        """List all available agents to a user"""
        # NOTE(review): the agent_id parameter looks out of place for a
        # listing operation — confirm against implementations/callers.
        raise NotImplementedError

    @abstractmethod
    def create_agent(
        self,
        user_id: str,
        agent_config: Union[dict, AgentConfig],
        interface: Union[AgentInterface, None],
        persistence_manager: Union[PersistenceManager, None],
    ) -> str:
        """Create a new agent using a config; returns the new agent's identifier"""
        raise NotImplementedError

    @abstractmethod
    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Process a message from the user, internally calls step"""
        raise NotImplementedError

    @abstractmethod
    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Run a command on the agent, e.g. /memory

        May return a string with a message generated by the command
        """
        raise NotImplementedError
# TODO actually use "user_id" for something
class SyncServer(Server):
    """Simple single-threaded / blocking server process.

    Agents are lazily loaded into an in-memory store keyed on
    (user_id, agent_id); each user message is fed through Agent.step,
    chaining heartbeat / token-warning follow-ups until the agent yields.
    """

    def __init__(
        self,
        chaining: bool = True,
        max_chaining_steps: Union[int, None] = None,
        # default_interface_cls: AgentInterface = CLIInterface,
        default_interface: AgentInterface = CLIInterface(),
        default_persistence_manager_cls: PersistenceManager = LocalStateManager,
    ):
        """Server process holds in-memory agents that are being run

        :param chaining: whether or not to run again if request_heartbeat=true
        :param max_chaining_steps: if chaining, the max number of chained steps
            before yielding (None = no limit); annotation fixed from ``bool``
        :param default_interface: the interface assigned to agents ON LOAD
            (NOTE: a single shared instance, so all such agents share one buffer)
        :param default_persistence_manager_cls: persistence manager class
            assigned to agents ON CREATION
        """
        # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
        self.active_agents = []

        # chaining = whether or not to run again if request_heartbeat=true
        self.chaining = chaining

        # if chaining == true, what's the max number of times we'll chain before yielding?
        # none = no limit, can go on forever
        self.max_chaining_steps = max_chaining_steps

        # The default interface that will get assigned to agents ON LOAD
        # self.default_interface_cls = default_interface_cls
        self.default_interface = default_interface

        # The default persistence manager that will get assigned to agents ON CREATION
        self.default_persistence_manager_cls = default_persistence_manager_cls

    def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
        """Get the agent object from the in-memory object store (None if not loaded)"""
        for d in self.active_agents:
            if d["user_id"] == user_id and d["agent_id"] == agent_id:
                return d["agent"]
        return None

    def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None:
        """Put an agent object inside the in-memory object store

        Raises KeyError if the (user_id, agent_id) pair is already loaded.
        """
        # Make sure the agent doesn't already exist
        if self._get_agent(user_id=user_id, agent_id=agent_id) is not None:
            raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded")
        # Add Agent instance to the in-memory list
        self.active_agents.append(
            {
                "user_id": user_id,
                "agent_id": agent_id,
                "agent": agent_obj,
            }
        )

    def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent:
        """Loads a saved agent into memory (if it doesn't exist, throw an error)"""
        from memgpt.utils import printd

        # If an interface isn't specified, use the default
        if interface is None:
            interface = self.default_interface

        # If the agent exists on disk, load it and put it into memory
        if AgentConfig.exists(agent_id):
            printd(f"(user={user_id}, agent={agent_id}) exists, loading into memory...")
            agent_config = AgentConfig.load(agent_id)
            memgpt_agent = Agent.load_agent(interface=interface, agent_config=agent_config)
            self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent)
            return memgpt_agent
        # If the agent doesn't exist, throw an error
        else:
            raise ValueError(f"agent_id {agent_id} does not exist")

    def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent:
        """Check if the agent is in-memory, then load"""
        memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
        if not memgpt_agent:
            memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
        return memgpt_agent

    def _step(self, user_id: str, agent_id: str, input_message: str) -> None:
        """Send the input message through the agent, chaining follow-up steps
        (heartbeats, token warnings, failed-function heartbeats) until the
        agent yields control or a chaining limit is hit."""
        from memgpt.utils import printd

        printd(f"Got input message: {input_message}")

        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        if memgpt_agent is None:
            raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded")

        printd(f"Starting agent step")
        no_verify = True
        next_input_message = input_message
        counter = 0
        while True:
            new_messages, heartbeat_request, function_failed, token_warning = memgpt_agent.step(
                next_input_message, first_message=False, skip_verify=no_verify
            )
            counter += 1

            # Chain stops
            if not self.chaining:
                printd("No chaining, stopping after one step")
                break
            elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
                printd(f"Hit max chaining steps, stopping after {counter} steps")
                break
            # Chain handlers
            elif token_warning:
                next_input_message = system.get_token_limit_warning()
                continue  # always chain
            elif function_failed:
                next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
                continue  # always chain
            elif heartbeat_request:
                next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
                continue  # always chain
            # MemGPT no-op / yield
            else:
                break

        # Tell the interface the step is over (e.g. so it can terminate a stream)
        memgpt_agent.interface.step_yield()
        printd(f"Finished agent step")

    def _command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Process a CLI command (the string WITHOUT its leading '/')

        Most commands return None; 'memory' returns a printable string.
        """
        from memgpt.utils import printd

        printd(f"Got command: {command}")

        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

        if command.lower() == "exit":
            # exit not supported on server.py
            raise ValueError(command)

        elif command.lower() == "save" or command.lower() == "savechat":
            memgpt_agent.save()

        elif command.lower() == "attach":
            # Different from CLI, we extract the data source name from the command
            command = command.strip().split()
            try:
                # BUG FIX: data source names are strings (they are compared
                # against StorageConnector.list_loaded_data() below); the old
                # int(...) cast raised for every valid name. The bare except
                # is also narrowed to the actual failure mode (missing arg).
                data_source = command[1]
            except IndexError:
                raise ValueError(command)

            # TODO: check if agent already has it
            data_source_options = StorageConnector.list_loaded_data()
            if len(data_source_options) == 0:
                raise ValueError('No sources available. You must load a source with "memgpt load ..." before running /attach.')
            elif data_source not in data_source_options:
                raise ValueError(f"Invalid data source name: {data_source} (options={data_source_options})")
            else:
                # attach new data
                attach(memgpt_agent.config.name, data_source)

                # update agent config
                memgpt_agent.config.attach_data_source(data_source)

                # reload agent with new data source
                # TODO: maybe make this less ugly...
                memgpt_agent.persistence_manager.archival_memory.storage = StorageConnector.get_storage_connector(
                    agent_config=memgpt_agent.config
                )

        elif command.lower() == "dump" or command.lower().startswith("dump "):
            # Check if there's an additional argument that's an integer
            command = command.strip().split()
            amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
            if amount == 0:
                memgpt_agent.interface.print_messages(memgpt_agent.messages, dump=True)
            else:
                memgpt_agent.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)

        elif command.lower() == "dumpraw":
            memgpt_agent.interface.print_messages_raw(memgpt_agent.messages)

        elif command.lower() == "memory":
            ret_str = (
                f"\nDumping memory contents:\n"
                + f"\n{str(memgpt_agent.memory)}"
                + f"\n{str(memgpt_agent.persistence_manager.archival_memory)}"
                + f"\n{str(memgpt_agent.persistence_manager.recall_memory)}"
            )
            return ret_str

        elif command.lower() == "pop" or command.lower().startswith("pop "):
            # Check if there's an additional argument that's an integer
            command = command.strip().split()
            pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
            n_messages = len(memgpt_agent.messages)
            MIN_MESSAGES = 2
            if n_messages <= MIN_MESSAGES:
                print(f"Agent only has {n_messages} messages in stack, none left to pop")
            elif n_messages - pop_amount < MIN_MESSAGES:
                print(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
            else:
                print(f"Popping last {pop_amount} messages from stack")
                for _ in range(min(pop_amount, len(memgpt_agent.messages))):
                    memgpt_agent.messages.pop()

        elif command.lower() == "retry":
            # TODO this needs to also modify the persistence manager
            # NOTE(review): the popped user message is captured below but never
            # re-sent through _step here — confirm whether retry is complete.
            print(f"Retrying for another answer")
            while len(memgpt_agent.messages) > 0:
                if memgpt_agent.messages[-1].get("role") == "user":
                    # we want to pop up to the last user message and send it again
                    user_message = memgpt_agent.messages[-1].get("content")
                    memgpt_agent.messages.pop()
                    break
                memgpt_agent.messages.pop()

        elif command.lower() == "rethink" or command.lower().startswith("rethink "):
            # TODO this needs to also modify the persistence manager
            if len(command) < len("rethink "):
                print("Missing text after the command")
            else:
                # Overwrite the content of the most recent assistant message
                for x in range(len(memgpt_agent.messages) - 1, 0, -1):
                    if memgpt_agent.messages[x].get("role") == "assistant":
                        text = command[len("rethink ") :].strip()
                        memgpt_agent.messages[x].update({"content": text})
                        break

        elif command.lower() == "rewrite" or command.lower().startswith("rewrite "):
            # TODO this needs to also modify the persistence manager
            if len(command) < len("rewrite "):
                print("Missing text after the command")
            else:
                # Rewrite the 'message' argument of the most recent assistant
                # function call (assumes that message carries a function_call)
                for x in range(len(memgpt_agent.messages) - 1, 0, -1):
                    if memgpt_agent.messages[x].get("role") == "assistant":
                        text = command[len("rewrite ") :].strip()
                        args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"))
                        args["message"] = text
                        memgpt_agent.messages[x].get("function_call").update({"arguments": json.dumps(args)})
                        break

        # No skip options
        elif command.lower() == "wipe":
            # wipe not supported on server.py (comment fixed; previously said "exit")
            raise ValueError(command)

        elif command.lower() == "heartbeat":
            input_message = system.get_heartbeat()
            self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)

        elif command.lower() == "memorywarning":
            input_message = system.get_token_limit_warning()
            self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)

    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Process an incoming user message and feed it through the MemGPT agent

        Raises ValueError for empty/non-string input or for input that looks
        like a command (use run_command for those).
        """
        from memgpt.utils import printd

        # Basic input sanitization
        if not isinstance(message, str) or len(message) == 0:
            raise ValueError(f"Invalid input: '{message}'")

        # If the input begins with a command prefix, reject
        elif message.startswith("/"):
            raise ValueError(f"Invalid input: '{message}'")

        # Else, process it as a user message to be fed to the agent
        else:
            # Package the user message first
            packaged_user_message = package_user_message(user_message=message)
            # Run the agent state forward
            self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)

    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Run a command on the agent

        Only inputs of the form '/<command>' are processed; anything else is
        ignored (returns None).
        """
        # If the input begins with a command prefix, attempt to process it as a command
        if command.startswith("/"):
            if len(command) > 1:
                command = command[1:]  # strip the prefix
                return self._command(user_id=user_id, agent_id=agent_id, command=command)

    def create_agent(
        self,
        user_id: str,
        agent_config: Union[dict, AgentConfig],
        interface: Union[AgentInterface, None] = None,
        persistence_manager: Union[PersistenceManager, None] = None,
    ) -> str:
        """Create a new agent using a config and return the new agent's name"""
        # Initialize the agent based on the provided configuration
        if isinstance(agent_config, dict):
            agent_config = AgentConfig(**agent_config)

        if interface is None:
            # interface = self.default_interface_cls()
            interface = self.default_interface

        if persistence_manager is None:
            persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)

        # Create agent via preset from config
        agent = presets.use_preset(
            agent_config.preset,
            agent_config,
            agent_config.model,
            utils.get_persona_text(agent_config.persona),
            utils.get_human_text(agent_config.human),
            interface,
            persistence_manager,
        )
        agent.save()
        print(f"Created new agent from config: {agent}")

        return agent.config.name

View File

@@ -1,80 +0,0 @@
import asyncio
import json
import websockets
import memgpt.server.websocket_protocol as protocol
from memgpt.server.websocket_server import WebSocketServer
from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
# CLEAN_RESPONSES = False # print the raw server responses (JSON)
CLEAN_RESPONSES = True  # make the server responses cleaner
# LOAD_AGENT = None # create a brand new agent
LOAD_AGENT = "agent_26"  # load an existing agent


async def basic_cli_client():
    """Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets

    Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages)
    """
    uri = f"ws://localhost:{DEFAULT_PORT}"

    async with websockets.connect(uri) as websocket:
        if LOAD_AGENT is not None:
            # Load existing agent
            print("Sending load message to server...")
            await websocket.send(protocol.client_command_load(LOAD_AGENT))
        else:
            # Initialize new agent
            print("Sending config to server...")
            example_config = {
                "persona": "sam_pov",
                "human": "cs_phd",
                "model": "gpt-4-1106-preview",  # gpt-4-turbo
            }
            await websocket.send(protocol.client_command_create(example_config))

        # Wait for the response (acknowledgement of the create/load command)
        response = await websocket.recv()
        response = json.loads(response)
        print(f"Server response:\n{json.dumps(response, indent=2)}")

        await asyncio.sleep(1)

        # Outer REPL loop: one user message per iteration
        while True:
            user_input = input("\nEnter your message: ")
            print("\n")
            # Send a message to the agent
            await websocket.send(protocol.client_user_message(str(user_input)))

            # Wait for messages in a loop, since the server may send a few
            while True:
                try:
                    response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT)
                    response = json.loads(response)

                    if CLEAN_RESPONSES:
                        print_server_response(response)
                    else:
                        print(f"Server response:\n{json.dumps(response, indent=2)}")

                    # Check for a specific condition to break the loop
                    if condition_to_stop_receiving(response):
                        break
                except asyncio.TimeoutError:
                    print("Timeout waiting for the server response.")
                    break
                except websockets.exceptions.ConnectionClosedError:
                    print("Connection to server was lost.")
                    break
                except Exception as e:
                    print(f"An error occurred: {e}")
                    break


asyncio.run(basic_cli_client())

View File

@@ -1,205 +0,0 @@
import asyncio
import json
import traceback
import websockets
from memgpt.server.websocket_interface import SyncWebSocketInterface
from memgpt.server.constants import DEFAULT_PORT
import memgpt.server.websocket_protocol as protocol
import memgpt.system as system
import memgpt.constants as memgpt_constants
class WebSocketServer:
    """Pre-refactor websocket server: holds a single agent directly on the
    server object (no internal SyncServer delegation)."""

    def __init__(self, host="localhost", port=DEFAULT_PORT):
        self.host = host
        self.port = port
        # Interface that streams agent output to every registered client
        self.interface = SyncWebSocketInterface()
        # The single agent this server drives (None until created or loaded)
        self.agent = None
        self.agent_name = None

    def run_step(self, user_message, first_message=False, no_verify=False):
        """Step the agent on a user message, auto-chaining heartbeats and
        token warnings until the agent yields control."""
        while True:
            new_messages, heartbeat_request, function_failed, token_warning = self.agent.step(
                user_message, first_message=first_message, skip_verify=no_verify
            )
            if token_warning:
                user_message = system.get_token_limit_warning()
            elif function_failed:
                user_message = system.get_heartbeat(memgpt_constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
            elif heartbeat_request:
                user_message = system.get_heartbeat(memgpt_constants.REQ_HEARTBEAT_MESSAGE)
            else:
                # return control
                break

    async def handle_client(self, websocket, path):
        """Per-connection receive loop: parses JSON packets and dispatches
        'command' and 'user_message' packet types."""
        self.interface.register_client(websocket)
        try:
            # async for message in websocket:
            while True:
                message = await websocket.recv()
                # Assuming the message is a JSON string
                try:
                    data = json.loads(message)
                except:
                    # NOTE(review): if json.loads fails on the first packet of a
                    # connection, `data` is unbound here (NameError) — these two
                    # lines should probably reference `message` instead.
                    print(f"[server] bad data from client:\n{data}")
                    await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
                    continue

                if "type" not in data:
                    print(f"[server] bad data from client (JSON but no type):\n{data}")
                    await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))

                elif data["type"] == "command":
                    # Create a new agent
                    if data["command"] == "create_agent":
                        try:
                            self.agent = self.create_new_agent(data["config"])
                            await websocket.send(protocol.server_command_response("OK: Agent initialized"))
                        except Exception as e:
                            self.agent = None
                            print(f"[server] self.create_new_agent failed with:\n{e}")
                            print(f"{traceback.format_exc()}")
                            await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))

                    # Load an existing agent
                    elif data["command"] == "load_agent":
                        agent_name = data.get("name")
                        if agent_name is not None:
                            try:
                                self.agent = self.load_agent(agent_name)
                                self.agent_name = agent_name
                                await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
                            except Exception as e:
                                print(f"[server] self.load_agent failed with:\n{e}")
                                print(f"{traceback.format_exc()}")
                                self.agent = None
                                await websocket.send(
                                    protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
                                )
                        else:
                            await websocket.send(protocol.server_command_response(f"Error: 'name' not provided"))

                    else:
                        print(f"[server] unrecognized client command type: {data}")
                        await websocket.send(protocol.server_error(f"unrecognized client command type: {data}"))

                elif data["type"] == "user_message":
                    user_message = data["message"]

                    if "agent_name" in data:
                        agent_name = data["agent_name"]
                        # If the agent requested the same one that's already loading?
                        # (only reload when the requested agent differs from the current one)
                        if self.agent_name is None or self.agent_name != data["agent_name"]:
                            try:
                                print(f"[server] loading agent {agent_name}")
                                self.agent = self.load_agent(agent_name)
                                self.agent_name = agent_name
                                # await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
                            except Exception as e:
                                print(f"[server] self.load_agent failed with:\n{e}")
                                print(f"{traceback.format_exc()}")
                                self.agent = None
                                await websocket.send(
                                    protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
                                )
                    else:
                        await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request"))
                        continue

                    if self.agent is None:
                        await websocket.send(protocol.server_agent_response_error("No agent has been initialized"))
                    else:
                        # Bracket the agent step with start/end protocol frames
                        await websocket.send(protocol.server_agent_response_start())
                        try:
                            self.run_step(user_message)
                        except Exception as e:
                            print(f"[server] self.run_step failed with:\n{e}")
                            print(f"{traceback.format_exc()}")
                            await websocket.send(protocol.server_agent_response_error(f"self.run_step failed with: {e}"))
                        await asyncio.sleep(1)  # pause before sending the terminating message, w/o this messages may be missed
                        await websocket.send(protocol.server_agent_response_end())

                # ... handle other message types as needed ...
                else:
                    print(f"[server] unrecognized client package data type: {data}")
                    await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}"))

        except websockets.exceptions.ConnectionClosed:
            print(f"[server] connection with client was closed")
        finally:
            # TODO autosave the agent
            self.interface.unregister_client(websocket)

    def create_new_agent(self, config):
        """Config is json that arrived over websocket, so we need to turn it into a config object"""
        from memgpt.config import AgentConfig
        import memgpt.presets.presets as presets
        import memgpt.utils as utils
        from memgpt.persistence_manager import InMemoryStateManager

        print("Creating new agent...")

        # Initialize the agent based on the provided configuration
        agent_config = AgentConfig(**config)

        # Use an in-state persistence manager
        persistence_manager = InMemoryStateManager()

        # Create agent via preset from config
        agent = presets.use_preset(
            agent_config.preset,
            agent_config,
            agent_config.model,
            utils.get_persona_text(agent_config.persona),
            utils.get_human_text(agent_config.human),
            self.interface,
            persistence_manager,
        )
        print("Created new agent from config")

        return agent

    def load_agent(self, agent_name):
        """Load an agent from a directory"""
        import memgpt.utils as utils
        from memgpt.config import AgentConfig
        from memgpt.agent import Agent

        print(f"Loading agent {agent_name}...")
        # Validate the requested name against the set of saved agent configs
        agent_files = utils.list_agent_config_files()
        agent_names = [AgentConfig.load(f).name for f in agent_files]
        if agent_name not in agent_names:
            raise ValueError(f"agent '{agent_name}' does not exist")

        agent_config = AgentConfig.load(agent_name)
        agent = Agent.load_agent(self.interface, agent_config)
        print("Created agent by loading existing config")
        return agent

    def initialize_server(self):
        """Print a startup banner (no real initialization work happens here)."""
        print("Server is initializing...")
        print(f"Listening on {self.host}:{self.port}...")

    async def start_server(self):
        """Serve handle_client on (host, port) until cancelled."""
        self.initialize_server()
        async with websockets.serve(self.handle_client, self.host, self.port):
            await asyncio.Future()  # Run forever

    def run(self):
        return self.start_server()  # Return the coroutine


if __name__ == "__main__":
    server = WebSocketServer()
    asyncio.run(server.run())

View File

View File

@@ -0,0 +1,106 @@
import asyncio
import json
import websockets
import memgpt.server.ws_api.protocol as protocol
from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
# CLEAN_RESPONSES = False # print the raw server responses (JSON)
CLEAN_RESPONSES = True  # make the server responses cleaner
# LOAD_AGENT = None # create a brand new agent
AGENT_NAME = "agent_26"  # load an existing agent
NEW_AGENT = False
# Reconnect policy: wait RECONNECT_DELAY seconds between attempts, give up
# after RECONNECT_MAX_TRIES consecutive failed attempts
RECONNECT_DELAY = 1
RECONNECT_MAX_TRIES = 5
async def send_message_and_print_replies(websocket, user_message, agent_id):
    """Send a message over websocket protocol and wait for the reply stream to end

    Raises asyncio.TimeoutError if a reply takes longer than CLIENT_TIMEOUT and
    propagates connection errors to the caller.
    """
    # Send a message to the agent
    await websocket.send(protocol.client_user_message(msg=str(user_message), agent_id=agent_id))

    # Wait for messages in a loop, since the server may send a few
    while True:
        response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT)
        response = json.loads(response)

        if CLEAN_RESPONSES:
            print_server_response(response)
        else:
            print(f"Server response:\n{json.dumps(response, indent=2)}")

        # Check for a specific condition to break the loop
        if condition_to_stop_receiving(response):
            break
async def basic_cli_client():
    """Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets

    Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages)
    """
    uri = f"ws://localhost:{DEFAULT_PORT}"
    # True when the connection dropped mid-message, so after reconnecting we
    # resend the pending input instead of prompting again
    closed_on_message = False
    retry_attempts = 0

    while True:  # Outer loop for reconnection attempts
        try:
            async with websockets.connect(uri) as websocket:
                if NEW_AGENT:
                    # Initialize new agent
                    print("Sending config to server...")
                    example_config = {
                        "persona": "sam_pov",
                        "human": "cs_phd",
                        "model": "gpt-4-1106-preview",  # gpt-4-turbo
                    }
                    await websocket.send(protocol.client_command_create(example_config))
                    # Wait for the response
                    response = await websocket.recv()
                    response = json.loads(response)
                    print(f"Server response:\n{json.dumps(response, indent=2)}")

                await asyncio.sleep(1)

                # Inner REPL loop: one user message per iteration
                while True:
                    if closed_on_message:
                        # If we're on a retry after a disconnect, don't ask for input again
                        closed_on_message = False
                    else:
                        user_input = input("\nEnter your message: ")
                        print("\n")

                    # Send a message to the agent
                    try:
                        await send_message_and_print_replies(websocket=websocket, user_message=user_input, agent_id=AGENT_NAME)
                        # Successful round trip resets the reconnect budget
                        retry_attempts = 0
                    except websockets.exceptions.ConnectionClosedError:
                        print("Connection to server was lost. Attempting to reconnect...")
                        closed_on_message = True
                        # Re-raise so the outer loop performs the reconnect
                        raise

        except websockets.exceptions.ConnectionClosedError:
            # Decide whether or not to retry the connection
            if retry_attempts < RECONNECT_MAX_TRIES:
                retry_attempts += 1
                await asyncio.sleep(RECONNECT_DELAY)  # Wait for N seconds before reconnecting
                continue
            else:
                print(f"Max attempts exceeded ({retry_attempts} > {RECONNECT_MAX_TRIES})")
                break
        except asyncio.TimeoutError:
            print("Timeout waiting for the server response.")
            continue
        except Exception as e:
            print(f"An error occurred: {e}")
            continue


asyncio.run(basic_cli_client())

View File

@@ -3,7 +3,7 @@ import threading
from memgpt.interface import AgentInterface
import memgpt.server.websocket_protocol as protocol
import memgpt.server.ws_api.protocol as protocol
class BaseWebSocketInterface(AgentInterface):
@@ -20,6 +20,9 @@ class BaseWebSocketInterface(AgentInterface):
"""Unregister a client connection"""
self.clients.remove(websocket)
def step_yield(self):
    """End-of-step hook (part of the AgentInterface contract).

    No-op here — presumably websocket interfaces deliver messages as they
    arrive and have nothing buffered to flush; confirm against
    QueuingInterface.step_yield, which enqueues a STOP marker instead.
    """
    pass
class AsyncWebSocketInterface(BaseWebSocketInterface):
"""WebSocket calls are async"""

View File

@@ -80,12 +80,12 @@ def server_agent_function_message(msg):
# Client -> server
def client_user_message(msg, agent_name=None):
def client_user_message(msg, agent_id=None):
    """Package a user message (optionally addressed to a specific agent) as a JSON packet."""
    payload = {
        "type": "user_message",
        "message": msg,
        "agent_id": agent_id,
    }
    return json.dumps(payload)
@@ -98,13 +98,3 @@ def client_command_create(config):
"config": config,
}
)
def client_command_load(agent_name):
    """Package a 'load_agent' command for the server as a JSON packet."""
    packet = {"type": "command", "command": "load_agent", "name": agent_name}
    return json.dumps(packet)

View File

@@ -0,0 +1,107 @@
import asyncio
import json
import traceback
import websockets
from memgpt.server.server import SyncServer
from memgpt.server.ws_api.interface import SyncWebSocketInterface
from memgpt.server.constants import DEFAULT_PORT
import memgpt.server.ws_api.protocol as protocol
import memgpt.system as system
import memgpt.constants as memgpt_constants
class WebSocketServer:
    """WebSocket front-end that bridges client connections to an internal SyncServer.

    Each connected client can send JSON packets of type "command" (e.g. create_agent)
    or "user_message"; replies are streamed back over the same socket via the
    protocol helpers.
    """

    def __init__(self, host="localhost", port=DEFAULT_PORT):
        self.host = host
        self.port = port
        self.interface = SyncWebSocketInterface()
        self.server = SyncServer(default_interface=self.interface)

    def __del__(self):
        # Release the interface's background thread/resources on teardown
        self.interface.close()

    def initialize_server(self):
        print("Server is initializing...")
        print(f"Listening on {self.host}:{self.port}...")

    async def start_server(self):
        self.initialize_server()
        # Can play with ping_interval and ping_timeout
        # See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
        # and https://github.com/cpacker/MemGPT/issues/471
        async with websockets.serve(self.handle_client, self.host, self.port):
            await asyncio.Future()  # Run forever

    def run(self):
        return self.start_server()  # Return the coroutine

    async def handle_client(self, websocket, path):
        """Per-connection loop: parse each JSON packet and dispatch on its "type"."""
        self.interface.register_client(websocket)
        try:
            # async for message in websocket:
            while True:
                message = await websocket.recv()
                # Assuming the message is a JSON string
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    # BUG FIX: the original referenced `data` here, but `data` is
                    # unbound when json.loads fails — report the raw `message` instead.
                    print(f"[server] bad data from client:\n{message}")
                    await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(message)}"))
                    continue

                if "type" not in data:
                    print(f"[server] bad data from client (JSON but no type):\n{data}")
                    await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))

                elif data["type"] == "command":
                    # Create a new agent
                    if data["command"] == "create_agent":
                        try:
                            # self.agent = self.create_new_agent(data["config"])
                            self.server.create_agent(user_id="NULL", agent_config=data["config"])
                            await websocket.send(protocol.server_command_response("OK: Agent initialized"))
                        except Exception as e:
                            self.agent = None
                            print(f"[server] self.create_new_agent failed with:\n{e}")
                            print(f"{traceback.format_exc()}")
                            await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))
                    else:
                        print(f"[server] unrecognized client command type: {data}")
                        await websocket.send(protocol.server_error(f"unrecognized client command type: {data}"))

                elif data["type"] == "user_message":
                    user_message = data["message"]

                    if "agent_id" not in data or data["agent_id"] is None:
                        # BUG FIX: error text said "agent_name" but the field checked is "agent_id"
                        await websocket.send(protocol.server_agent_response_error("agent_id was not specified in the request"))
                        continue

                    await websocket.send(protocol.server_agent_response_start())
                    try:
                        # self.run_step(user_message)
                        self.server.user_message(user_id="NULL", agent_id=data["agent_id"], message=user_message)
                    except Exception as e:
                        print(f"[server] self.server.user_message failed with:\n{e}")
                        print(f"{traceback.format_exc()}")
                        await websocket.send(protocol.server_agent_response_error(f"server.user_message failed with: {e}"))
                    await asyncio.sleep(1)  # pause before sending the terminating message, w/o this messages may be missed
                    await websocket.send(protocol.server_agent_response_end())

                # ... handle other message types as needed ...
                else:
                    print(f"[server] unrecognized client package data type: {data}")
                    await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}"))

        except websockets.exceptions.ConnectionClosed:
            print(f"[server] connection with client was closed")
        finally:
            self.interface.unregister_client(websocket)


if __name__ == "__main__":
    server = WebSocketServer()
    asyncio.run(server.run())

50
poetry.lock generated
View File

@@ -137,24 +137,24 @@ files = [
[[package]]
name = "anyio"
version = "4.1.0"
version = "3.7.1"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
python-versions = ">=3.7"
files = [
{file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"},
{file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"},
{file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"},
{file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"},
]
[package.dependencies]
exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
idna = ">=2.8"
sniffio = ">=1.1"
[package.extras]
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (>=0.23)"]
doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"]
test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (<0.22)"]
[[package]]
name = "asgiref"
@@ -737,19 +737,20 @@ test = ["pytest (>=6)"]
[[package]]
name = "fastapi"
version = "0.103.0"
version = "0.104.1"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "fastapi-0.103.0-py3-none-any.whl", hash = "sha256:61ab72c6c281205dd0cbaccf503e829a37e0be108d965ac223779a8479243665"},
{file = "fastapi-0.103.0.tar.gz", hash = "sha256:4166732f5ddf61c33e9fa4664f73780872511e0598d4d5434b1816dc1e6d9421"},
{file = "fastapi-0.104.1-py3-none-any.whl", hash = "sha256:752dc31160cdbd0436bb93bad51560b57e525cbb1d4bbf6f4904ceee75548241"},
{file = "fastapi-0.104.1.tar.gz", hash = "sha256:e5e4540a7c5e1dcfbbcf5b903c234feddcdcd881f191977a1c5dfd917487e7ae"},
]
[package.dependencies]
anyio = ">=3.7.1,<4.0.0"
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
starlette = ">=0.27.0,<0.28.0"
typing-extensions = ">=4.5.0"
typing-extensions = ">=4.8.0"
[package.extras]
all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
@@ -3074,6 +3075,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.23.2"
description = "Pytest support for asyncio"
optional = true
python-versions = ">=3.8"
files = [
{file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
{file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
]
[package.dependencies]
pytest = ">=7.0.0"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "python-box"
version = "7.1.1"
@@ -4774,11 +4793,12 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
[extras]
dev = ["black", "datasets", "pre-commit", "pytest"]
dev = ["black", "datasets", "pre-commit", "pytest", "pytest-asyncio"]
local = ["huggingface-hub", "torch", "transformers"]
postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"]
server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.12,>=3.9"
content-hash = "2d68f2515a73a9b2cafb445138c667f61153ac6feb23c124032fc2c2d56baf4a"
content-hash = "4f675213d5a79f001bfb7441c9fba23ae114079ec61a30b0c88833c5427f152e"

View File

@@ -26,7 +26,7 @@ pytz = "^2023.3.post1"
tqdm = "^4.66.1"
black = { version = "^23.10.1", optional = true }
pytest = { version = "^7.4.3", optional = true }
llama-index = "0.9.13"
llama-index = "^0.9.13"
setuptools = "^68.2.2"
datasets = { version = "^2.14.6", optional = true}
prettytable = "^3.9.0"
@@ -39,7 +39,7 @@ transformers = { version = "4.34.1", optional = true }
pre-commit = {version = "^3.5.0", optional = true }
pg8000 = {version = "^1.30.3", optional = true}
torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
websockets = "^12.0"
websockets = {version = "^12.0", optional = true}
docstring-parser = "^0.15"
lancedb = "^0.3.3"
httpx = "^0.25.2"
@@ -49,12 +49,17 @@ tiktoken = "^0.5.1"
python-box = "^7.1.1"
pypdf = "^3.17.1"
pyyaml = "^6.0.1"
fastapi = {version = "^0.104.1", optional = true}
uvicorn = {version = "^0.24.0.post1", optional = true}
chromadb = "^0.4.18"
pytest-asyncio = {version = "^0.23.2", optional = true}
pydantic = "^2.5.2"
[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
dev = ["pytest", "black", "pre-commit", "datasets"]
dev = ["pytest", "pytest-asyncio", "black", "pre-commit", "datasets"]
server = ["websockets", "fastapi", "uvicorn"]
[build-system]
requires = ["poetry-core"]

43
tests/test_server.py Normal file
View File

@@ -0,0 +1,43 @@
import memgpt.utils as utils
utils.DEBUG = True
from memgpt.server.server import SyncServer
def test_server():
    """Smoke-test the SyncServer message API.

    Each call is wrapped so that an expected ValueError (e.g. agent state issues,
    nonexistent agent) is printed and tolerated, while any other exception
    propagates and fails the test.
    NOTE: the original code appended a bare `except: raise` to every try block;
    that is a no-op (unhandled exceptions re-raise by default) and was removed.
    """
    user_id = "NULL"
    agent_id = "agent_26"

    server = SyncServer()

    try:
        server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
    except ValueError as e:
        print(e)

    try:
        server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
    except ValueError as e:
        print(e)

    try:
        print(server.run_command(user_id=user_id, agent_id=agent_id, command="/memory"))
    except ValueError as e:
        print(e)

    try:
        # an agent that does not exist should raise (and be tolerated here)
        server.user_message(user_id=user_id, agent_id="agent no-exist", message="Hello?")
    except ValueError as e:
        print(e)


if __name__ == "__main__":
    test_server()

View File

@@ -1,9 +1,10 @@
import os
import pytest
from unittest.mock import Mock, AsyncMock, MagicMock
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.server.websocket_interface import SyncWebSocketInterface
import memgpt.presets as presets
from memgpt.server.ws_api.interface import SyncWebSocketInterface
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.system as system
from memgpt.persistence_manager import LocalStateManager
@@ -54,19 +55,32 @@ async def test_websockets():
ws_interface.register_client(mock_websocket)
# Create an agent and hook it up to the WebSocket interface
config = MemGPTConfig()
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
ws_interface.close()
return
config = MemGPTConfig.load()
if config.openai_key is None:
config.openai_key = api_key
config.save()
# Mock the persistence manager
# create agents with defaults
agent_config = AgentConfig(persona="sam_pov", human="basic", model="gpt-4-1106-preview")
agent_config = AgentConfig(
persona="sam_pov",
human="basic",
model="gpt-4-1106-preview",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
)
persistence_manager = LocalStateManager(agent_config=agent_config)
memgpt_agent = presets.use_preset(
presets.DEFAULT_PRESET,
config, # no agent config to provide
"gpt-4-1106-preview",
utils.get_persona_text("sam_pov"),
utils.get_human_text("basic"),
agent_config.preset,
agent_config,
agent_config.model,
agent_config.persona, # note: extracting the raw text, not pulling from a file
agent_config.human, # note: extracting raw text, not pulling from a file
ws_interface,
persistence_manager,
)

View File

@@ -5,7 +5,7 @@ import websockets
import pytest
from memgpt.server.constants import DEFAULT_PORT
from memgpt.server.websocket_server import WebSocketServer
from memgpt.server.ws_api.server import WebSocketServer
from memgpt.config import AgentConfig
@@ -16,7 +16,9 @@ async def test_dummy():
@pytest.mark.asyncio
async def test_websocket_server():
server = WebSocketServer()
# host = "127.0.0.1"
host = "localhost"
server = WebSocketServer(host=host)
server_task = asyncio.create_task(server.run()) # Create a task for the server
# the agent config we want to ask the server to instantiate with
@@ -28,23 +30,26 @@ async def test_websocket_server():
# )
test_config = {}
uri = f"ws://localhost:{DEFAULT_PORT}"
async with websockets.connect(uri) as websocket:
# Initialize the server with a test config
print("Sending config to server...")
await websocket.send(json.dumps({"type": "initialize", "config": test_config}))
# Wait for the response
response = await websocket.recv()
print(f"Response from the agent: {response}")
uri = f"ws://{host}:{DEFAULT_PORT}"
try:
async with websockets.connect(uri) as websocket:
# Initialize the server with a test config
print("Sending config to server...")
await websocket.send(json.dumps({"type": "initialize", "config": test_config}))
# Wait for the response
response = await websocket.recv()
print(f"Response from the agent: {response}")
await asyncio.sleep(1) # just in case
await asyncio.sleep(1) # just in case
# Send a message to the agent
print("Sending message to server...")
await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}))
# Wait for the response
# NOTE: we should be waiting for multiple responses
response = await websocket.recv()
print(f"Response from the agent: {response}")
server_task.cancel() # Cancel the server task after the test
# Send a message to the agent
print("Sending message to server...")
await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}))
# Wait for the response
# NOTE: we should be waiting for multiple responses
response = await websocket.recv()
print(f"Response from the agent: {response}")
except (OSError, ConnectionRefusedError) as e:
print(f"Was unable to connect: {e}")
finally:
server_task.cancel() # Cancel the server task after the test