API server refactor + REST API (#593)
* init server refactor * refactored websockets server/client code to use internal server API * added intentional fail on test * update workflow to try and get test to pass remotely * refactor to put websocket code in a separate subdirectory * added fastapi rest server * add error handling * modified interface return style * disabled certain tests on remote * added SSE response option for user_message * fix ws interface test * fallback for oai key * add soft fail for test when localhost is borked * add step_yield for all server related interfaces * extra catch * update toml + lock with server add-ons (add uvicorn+fastapi, move websockets to server extra) * regen lock file * added pytest-asyncio as an extra in dev * add pydantic to deps * renamed CreateConfig to CreateAgentConfig * fixed POST request for creating agent + tested it
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
||||
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
run: |
|
||||
poetry install -E dev -E postgres -E local
|
||||
poetry install -E dev -E postgres -E local -E server
|
||||
|
||||
- name: Set Poetry config
|
||||
env:
|
||||
|
||||
0
memgpt/server/rest_api/__init__.py
Normal file
0
memgpt/server/rest_api/__init__.py
Normal file
75
memgpt/server/rest_api/interface.py
Normal file
75
memgpt/server/rest_api/interface.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import asyncio
|
||||
import queue
|
||||
|
||||
from memgpt.interface import AgentInterface
|
||||
|
||||
|
||||
class QueuingInterface(AgentInterface):
    """Agent interface that parks every event in an internal queue.

    Nothing is pushed to a client directly: callers either drain the buffer
    synchronously via `to_list()` or consume it as an async stream via
    `message_generator()`, which terminates on the "STOP" sentinel enqueued
    by `step_yield()`.
    """

    def __init__(self):
        # FIFO buffer of event dicts (plus the "STOP" sentinel string)
        self.buffer = queue.Queue()

    def to_list(self):
        """Drain the queue into a list (the queue is emptied as a side effect)."""
        drained = []
        while not self.buffer.empty():
            try:
                drained.append(self.buffer.get_nowait())
            except queue.Empty:
                # another consumer raced us; stop draining
                break
        return drained

    def clear(self):
        """Discard every queued message."""
        with self.buffer.mutex:
            # Clear the underlying deque in one shot while holding the lock
            self.buffer.queue.clear()

    async def message_generator(self):
        """Async-yield queued messages until the "STOP" sentinel is seen."""
        while True:
            if self.buffer.empty():
                await asyncio.sleep(0.1)  # Small sleep to prevent a busy loop
                continue
            message = self.buffer.get()
            if message == "STOP":
                break
            yield message

    def step_yield(self):
        """Enqueue the special stop sentinel that ends `message_generator`."""
        self.buffer.put("STOP")

    def user_message(self, msg: str):
        """Handle reception of a user message (no-op: the user already has it)."""
        pass

    def internal_monologue(self, msg: str) -> None:
        """Queue the agent's internal monologue (also echoed to stdout)."""
        print(msg)
        self.buffer.put({"internal_monologue": msg})

    def assistant_message(self, msg: str) -> None:
        """Queue a message the agent sends to the user (also echoed to stdout)."""
        print(msg)
        self.buffer.put({"assistant_message": msg})

    def function_message(self, msg: str) -> None:
        """Queue a function-call event, classified by its string prefix."""
        print(msg)

        if msg.startswith("Running "):
            self.buffer.put({"function_call": msg.replace("Running ", "")})
        elif msg.startswith("Success: "):
            self.buffer.put({"function_return": msg.replace("Success: ", ""), "status": "success"})
        elif msg.startswith("Error: "):
            self.buffer.put({"function_return": msg.replace("Error: ", ""), "status": "error"})
        else:
            # NOTE: generic, should not happen
            self.buffer.put({"function_message": msg})
|
||||
112
memgpt/server/rest_api/server.py
Normal file
112
memgpt/server/rest_api/server.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
import memgpt.utils as utils
|
||||
|
||||
|
||||
"""
|
||||
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
|
||||
|
||||
Start the server with:
|
||||
cd memgpt/server/rest_api
|
||||
poetry run uvicorn server:app --reload
|
||||
"""
|
||||
|
||||
|
||||
class CreateAgentConfig(BaseModel):
    """Request body for POST /agents (create a new agent)."""

    # ID of the user the new agent belongs to
    user_id: str
    # Raw agent configuration; forwarded to SyncServer.create_agent, which
    # expands it into an AgentConfig
    config: dict
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
    """Request body for POST /agents/message."""

    # ID of the requesting user
    user_id: str
    # Agent that should process the message
    agent_id: str
    # Raw user message text
    message: str
    # When True, respond with a server-sent-events stream instead of JSON
    stream: bool = False
|
||||
|
||||
|
||||
class Command(BaseModel):
    """Request body for POST /agents/command."""

    # ID of the requesting user
    user_id: str
    # Agent the command is addressed to
    agent_id: str
    # CLI-style command string, e.g. "/memory"
    command: str
|
||||
|
||||
|
||||
# Module-level singletons: one FastAPI app, one shared queuing interface, and
# one SyncServer that routes every request. NOTE(review): a single shared
# interface means concurrent requests can interleave their buffered messages —
# confirm single-client usage is assumed.
app = FastAPI()
interface = QueuingInterface()
server = SyncServer(default_interface=interface)
|
||||
|
||||
|
||||
# server.list_agents
@app.get("/agents")
def list_agents(user_id: str):
    """Return the number and names of all saved agents.

    NOTE(review): `user_id` is accepted but unused — all agent config files
    are listed regardless of user (matches the class-level TODO in SyncServer).
    """
    interface.clear()
    agents_list = utils.list_agent_config_files()
    return {"num_agents": len(agents_list), "agent_names": agents_list}
|
||||
|
||||
|
||||
# server.create_agent
@app.post("/agents")
def create_agents(body: CreateAgentConfig):
    """Create a new agent from the posted config and return its ID.

    Any failure inside the server is surfaced as a 500 with the error text.
    """
    interface.clear()
    try:
        new_agent_id = server.create_agent(user_id=body.user_id, agent_config=body.config)
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"{err}")
    return {"agent_id": new_agent_id}
|
||||
|
||||
|
||||
# server.user_message
@app.post("/agents/message")
async def user_message(body: UserMessage):
    """Feed a user message to an agent.

    Two response modes, selected by `body.stream`:
      - stream=True: kick off the agent step in the background and return a
        server-sent-events stream of messages drained from the shared interface.
      - stream=False: run the step synchronously and return all buffered
        messages as one JSON payload.
    """
    if body.stream:
        # For streaming response
        try:
            # Start the generation process (similar to the non-streaming case)
            # This should be a non-blocking call or run in a background task
            # NOTE(review): unlike the non-streaming branch, interface.clear()
            # is NOT called here — confirm stale messages from a previous
            # request can't leak into the stream.

            # Check if server.user_message is an async function
            if asyncio.iscoroutinefunction(server.user_message):
                # Start the async task
                asyncio.create_task(server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message))
            else:
                # Run the synchronous function in a thread pool
                # NOTE(review): the returned future is never awaited, so
                # exceptions raised inside server.user_message are lost here;
                # the client only sees the stream end.
                loop = asyncio.get_event_loop()
                loop.run_in_executor(None, server.user_message, body.user_id, body.agent_id, body.message)

            async def formatted_message_generator():
                # Wrap each queued message in the SSE "data: ...\n\n" framing
                async for message in interface.message_generator():
                    formatted_message = f"data: {json.dumps(message)}\n\n"
                    yield formatted_message
                    await asyncio.sleep(1)

            # Return the streaming response using the generator
            return StreamingResponse(formatted_message_generator(), media_type="text/event-stream")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    else:
        # Synchronous path: drop any stale buffered messages, run the step,
        # then return everything the agent produced in one response.
        interface.clear()
        try:
            server.user_message(user_id=body.user_id, agent_id=body.agent_id, message=body.message)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return {"messages": interface.to_list()}
|
||||
|
||||
|
||||
# server.run_command
@app.post("/agents/command")
def run_command(body: Command):
    """Run a CLI-style command (e.g. "/memory") on an agent.

    Returns whatever string the command produced (may be None).
    Raises HTTP 500 with the error text if the command fails.

    BUGFIX: the original called server.run_command a second time after the
    try/except, executing the command twice (duplicating any side effects such
    as popping messages) with the second call outside the error handler.
    """
    interface.clear()
    try:
        response = server.run_command(user_id=body.user_id, agent_id=body.agent_id, command=body.command)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
    return {"response": response}
|
||||
365
memgpt/server/server.py
Normal file
365
memgpt/server/server.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
import json
|
||||
|
||||
from memgpt.system import package_user_message
|
||||
from memgpt.config import AgentConfig
|
||||
from memgpt.agent import Agent
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as constants
|
||||
from memgpt.cli.cli import attach
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.utils as utils
|
||||
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
|
||||
|
||||
# TODO use custom interface
|
||||
from memgpt.interface import CLIInterface # for printing to terminal
|
||||
from memgpt.interface import AgentInterface # abstract
|
||||
|
||||
|
||||
class Server(object):
    """Abstract server class that supports multi-agent multi-user

    NOTE(review): this class does not inherit from abc.ABC, so the
    @abstractmethod decorators are not enforced at instantiation time —
    confirm whether that is intentional.
    """

    @abstractmethod
    def list_agents(self, user_id: str, agent_id: str) -> str:
        """List all available agents to a user"""
        raise NotImplementedError

    @abstractmethod
    def create_agent(
        self,
        user_id: str,
        agent_config: Union[dict, AgentConfig],
        interface: Union[AgentInterface, None],
        persistence_manager: Union[PersistenceManager, None],
    ) -> str:
        """Create a new agent using a config, returning its identifier."""
        raise NotImplementedError

    @abstractmethod
    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Process a message from the user, internally calls step"""
        raise NotImplementedError

    @abstractmethod
    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Run a command on the agent, e.g. /memory

        May return a string with a message generated by the command
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
# TODO actually use "user_id" for something
|
||||
class SyncServer(Server):
    """Simple single-threaded / blocking server process"""

    def __init__(
        self,
        chaining: bool = True,
        # BUGFIX: was annotated `bool`, but this is an integer step cap
        # (compared against a step counter in _step) or None for no limit
        max_chaining_steps: Union[int, None] = None,
        # default_interface_cls: AgentInterface = CLIInterface,
        default_interface: AgentInterface = CLIInterface(),
        default_persistence_manager_cls: PersistenceManager = LocalStateManager,
    ):
        """Server process holds in-memory agents that are being run

        Args:
            chaining: whether to keep stepping while the agent requests heartbeats
            max_chaining_steps: max chained steps per request (None = unlimited)
            default_interface: interface instance assigned to agents on load
            default_persistence_manager_cls: persistence manager class used on creation
        """

        # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
        self.active_agents = []

        # chaining = whether or not to run again if request_heartbeat=true
        self.chaining = chaining

        # if chaining == true, what's the max number of times we'll chain before yielding?
        # none = no limit, can go on forever
        self.max_chaining_steps = max_chaining_steps

        # The default interface that will get assigned to agents ON LOAD
        # self.default_interface_cls = default_interface_cls
        self.default_interface = default_interface

        # The default persistence manager that will get assigned to agents ON CREATION
        self.default_persistence_manager_cls = default_persistence_manager_cls
|
||||
|
||||
def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
|
||||
"""Get the agent object from the in-memory object store"""
|
||||
for d in self.active_agents:
|
||||
if d["user_id"] == user_id and d["agent_id"] == agent_id:
|
||||
return d["agent"]
|
||||
return None
|
||||
|
||||
def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None:
|
||||
"""Put an agent object inside the in-memory object store"""
|
||||
# Make sure the agent doesn't already exist
|
||||
if self._get_agent(user_id=user_id, agent_id=agent_id) is not None:
|
||||
raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded")
|
||||
# Add Agent instance to the in-memory list
|
||||
self.active_agents.append(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"agent": agent_obj,
|
||||
}
|
||||
)
|
||||
|
||||
    def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent:
        """Load a saved agent from disk into the in-memory store and return it.

        Raises:
            ValueError: if no saved agent config exists for `agent_id`.
        """
        from memgpt.utils import printd

        # If an interface isn't specified, use the default
        if interface is None:
            interface = self.default_interface

        # If a saved config exists, load the agent and register it in memory
        if AgentConfig.exists(agent_id):
            printd(f"(user={user_id}, agent={agent_id}) exists, loading into memory...")
            agent_config = AgentConfig.load(agent_id)
            memgpt_agent = Agent.load_agent(interface=interface, agent_config=agent_config)
            self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent)
            return memgpt_agent

        # If the agent doesn't exist, throw an error
        else:
            raise ValueError(f"agent_id {agent_id} does not exist")
|
||||
|
||||
def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent:
|
||||
"""Check if the agent is in-memory, then load"""
|
||||
memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
|
||||
if not memgpt_agent:
|
||||
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
|
||||
return memgpt_agent
|
||||
|
||||
    def _step(self, user_id: str, agent_id: str, input_message: str) -> None:
        """Send the input message through the agent, chaining follow-up steps
        (heartbeats, token warnings, function-failure retries) until the agent
        yields or a chaining limit is hit. Signals completion to the interface
        via step_yield().
        """
        from memgpt.utils import printd

        printd(f"Got input message: {input_message}")

        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        if memgpt_agent is None:
            raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded")

        printd(f"Starting agent step")
        no_verify = True
        next_input_message = input_message
        counter = 0  # number of chained steps taken so far
        while True:
            new_messages, heartbeat_request, function_failed, token_warning = memgpt_agent.step(
                next_input_message, first_message=False, skip_verify=no_verify
            )
            counter += 1

            # Chain stops
            if not self.chaining:
                printd("No chaining, stopping after one step")
                break
            elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
                printd(f"Hit max chaining steps, stopping after {counter} steps")
                break
            # Chain handlers: each feeds a synthetic system message back in
            elif token_warning:
                next_input_message = system.get_token_limit_warning()
                continue  # always chain
            elif function_failed:
                next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
                continue  # always chain
            elif heartbeat_request:
                next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
                continue  # always chain
            # MemGPT no-op / yield
            else:
                break

        # Tell the interface the step sequence is complete (e.g. ends SSE streams)
        memgpt_agent.interface.step_yield()
        printd(f"Finished agent step")
|
||||
|
||||
def _command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
||||
"""Process a CLI command"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
printd(f"Got command: {command}")
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
if command.lower() == "exit":
|
||||
# exit not supported on server.py
|
||||
raise ValueError(command)
|
||||
|
||||
elif command.lower() == "save" or command.lower() == "savechat":
|
||||
memgpt_agent.save()
|
||||
|
||||
elif command.lower() == "attach":
|
||||
# Different from CLI, we extract the data source name from the command
|
||||
command = command.strip().split()
|
||||
try:
|
||||
data_source = int(command[1])
|
||||
except:
|
||||
raise ValueError(command)
|
||||
|
||||
# TODO: check if agent already has it
|
||||
data_source_options = StorageConnector.list_loaded_data()
|
||||
if len(data_source_options) == 0:
|
||||
raise ValueError('No sources available. You must load a souce with "memgpt load ..." before running /attach.')
|
||||
elif data_source not in data_source_options:
|
||||
raise ValueError(f"Invalid data source name: {data_source} (options={data_source_options})")
|
||||
else:
|
||||
# attach new data
|
||||
attach(memgpt_agent.config.name, data_source)
|
||||
|
||||
# update agent config
|
||||
memgpt_agent.config.attach_data_source(data_source)
|
||||
|
||||
# reload agent with new data source
|
||||
# TODO: maybe make this less ugly...
|
||||
memgpt_agent.persistence_manager.archival_memory.storage = StorageConnector.get_storage_connector(
|
||||
agent_config=memgpt_agent.config
|
||||
)
|
||||
|
||||
elif command.lower() == "dump" or command.lower().startswith("dump "):
|
||||
# Check if there's an additional argument that's an integer
|
||||
command = command.strip().split()
|
||||
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
|
||||
if amount == 0:
|
||||
memgpt_agent.interface.print_messages(memgpt_agent.messages, dump=True)
|
||||
else:
|
||||
memgpt_agent.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
|
||||
|
||||
elif command.lower() == "dumpraw":
|
||||
memgpt_agent.interface.print_messages_raw(memgpt_agent.messages)
|
||||
|
||||
elif command.lower() == "memory":
|
||||
ret_str = (
|
||||
f"\nDumping memory contents:\n"
|
||||
+ f"\n{str(memgpt_agent.memory)}"
|
||||
+ f"\n{str(memgpt_agent.persistence_manager.archival_memory)}"
|
||||
+ f"\n{str(memgpt_agent.persistence_manager.recall_memory)}"
|
||||
)
|
||||
return ret_str
|
||||
|
||||
elif command.lower() == "pop" or command.lower().startswith("pop "):
|
||||
# Check if there's an additional argument that's an integer
|
||||
command = command.strip().split()
|
||||
pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
|
||||
n_messages = len(memgpt_agent.messages)
|
||||
MIN_MESSAGES = 2
|
||||
if n_messages <= MIN_MESSAGES:
|
||||
print(f"Agent only has {n_messages} messages in stack, none left to pop")
|
||||
elif n_messages - pop_amount < MIN_MESSAGES:
|
||||
print(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
|
||||
else:
|
||||
print(f"Popping last {pop_amount} messages from stack")
|
||||
for _ in range(min(pop_amount, len(memgpt_agent.messages))):
|
||||
memgpt_agent.messages.pop()
|
||||
|
||||
elif command.lower() == "retry":
|
||||
# TODO this needs to also modify the persistence manager
|
||||
print(f"Retrying for another answer")
|
||||
while len(memgpt_agent.messages) > 0:
|
||||
if memgpt_agent.messages[-1].get("role") == "user":
|
||||
# we want to pop up to the last user message and send it again
|
||||
user_message = memgpt_agent.messages[-1].get("content")
|
||||
memgpt_agent.messages.pop()
|
||||
break
|
||||
memgpt_agent.messages.pop()
|
||||
|
||||
elif command.lower() == "rethink" or command.lower().startswith("rethink "):
|
||||
# TODO this needs to also modify the persistence manager
|
||||
if len(command) < len("rethink "):
|
||||
print("Missing text after the command")
|
||||
else:
|
||||
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
|
||||
if memgpt_agent.messages[x].get("role") == "assistant":
|
||||
text = command[len("rethink ") :].strip()
|
||||
memgpt_agent.messages[x].update({"content": text})
|
||||
break
|
||||
|
||||
elif command.lower() == "rewrite" or command.lower().startswith("rewrite "):
|
||||
# TODO this needs to also modify the persistence manager
|
||||
if len(command) < len("rewrite "):
|
||||
print("Missing text after the command")
|
||||
else:
|
||||
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
|
||||
if memgpt_agent.messages[x].get("role") == "assistant":
|
||||
text = command[len("rewrite ") :].strip()
|
||||
args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"))
|
||||
args["message"] = text
|
||||
memgpt_agent.messages[x].get("function_call").update({"arguments": json.dumps(args)})
|
||||
break
|
||||
|
||||
# No skip options
|
||||
elif command.lower() == "wipe":
|
||||
# exit not supported on server.py
|
||||
raise ValueError(command)
|
||||
|
||||
elif command.lower() == "heartbeat":
|
||||
input_message = system.get_heartbeat()
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
|
||||
|
||||
elif command.lower() == "memorywarning":
|
||||
input_message = system.get_token_limit_warning()
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
|
||||
|
||||
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
|
||||
"""Process an incoming user message and feed it through the MemGPT agent"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
# Basic input sanitization
|
||||
if not isinstance(message, str) or len(message) == 0:
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# If the input begins with a command prefix, reject
|
||||
elif message.startswith("/"):
|
||||
raise ValueError(f"Invalid input: '{message}'")
|
||||
|
||||
# Else, process it as a user message to be fed to the agent
|
||||
else:
|
||||
# Package the user message first
|
||||
packaged_user_message = package_user_message(user_message=message)
|
||||
# Run the agent state forward
|
||||
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
|
||||
|
||||
    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Run a command on the agent, stripping an optional leading "/".

        Returns whatever _command returns (a string for output-producing
        commands, otherwise None).
        """
        # If the input begins with a command prefix, strip it before dispatch
        # (a bare "/" is passed through unchanged)
        if command.startswith("/"):
            if len(command) > 1:
                command = command[1:]  # strip the prefix
        return self._command(user_id=user_id, agent_id=agent_id, command=command)
|
||||
|
||||
    def create_agent(
        self,
        user_id: str,
        agent_config: Union[dict, AgentConfig],
        interface: Union[AgentInterface, None] = None,
        persistence_manager: Union[PersistenceManager, None] = None,
    ) -> str:
        """Create a new agent from a config, save it, and return its name.

        Args:
            user_id: owner of the new agent (currently unused downstream — see class TODO).
            agent_config: either a ready AgentConfig or a dict of its kwargs.
            interface: agent interface; defaults to the server-wide default.
            persistence_manager: defaults to a new default_persistence_manager_cls.

        Returns:
            The created agent's config name (used as its agent_id).
        """

        # Initialize the agent based on the provided configuration
        if isinstance(agent_config, dict):
            agent_config = AgentConfig(**agent_config)

        if interface is None:
            # interface = self.default_interface_cls()
            interface = self.default_interface

        if persistence_manager is None:
            persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)

        # Create agent via preset from config
        agent = presets.use_preset(
            agent_config.preset,
            agent_config,
            agent_config.model,
            utils.get_persona_text(agent_config.persona),
            utils.get_human_text(agent_config.human),
            interface,
            persistence_manager,
        )
        # Persist immediately so the agent can be loaded by name later
        agent.save()
        print(f"Created new agent from config: {agent}")

        return agent.config.name
|
||||
@@ -1,80 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import websockets
|
||||
|
||||
import memgpt.server.websocket_protocol as protocol
|
||||
from memgpt.server.websocket_server import WebSocketServer
|
||||
from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT
|
||||
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
|
||||
|
||||
|
||||
# CLEAN_RESPONSES = False  # print the raw server responses (JSON)
CLEAN_RESPONSES = True  # make the server responses cleaner

# LOAD_AGENT = None  # create a brand new agent
LOAD_AGENT = "agent_26"  # load an existing agent


async def basic_cli_client():
    """Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets

    Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages)
    """
    uri = f"ws://localhost:{DEFAULT_PORT}"

    async with websockets.connect(uri) as websocket:
        if LOAD_AGENT is not None:
            # Load existing agent
            print("Sending load message to server...")
            await websocket.send(protocol.client_command_load(LOAD_AGENT))

        else:
            # Initialize new agent
            print("Sending config to server...")
            example_config = {
                "persona": "sam_pov",
                "human": "cs_phd",
                "model": "gpt-4-1106-preview",  # gpt-4-turbo
            }
            await websocket.send(protocol.client_command_create(example_config))

        # Wait for the (create/load) acknowledgement from the server
        response = await websocket.recv()
        response = json.loads(response)
        print(f"Server response:\n{json.dumps(response, indent=2)}")

        await asyncio.sleep(1)

        # Main REPL: read user input, send it, then drain server replies
        while True:
            user_input = input("\nEnter your message: ")
            print("\n")

            # Send a message to the agent
            await websocket.send(protocol.client_user_message(str(user_input)))

            # Wait for messages in a loop, since the server may send a few
            while True:
                try:
                    response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT)
                    response = json.loads(response)

                    if CLEAN_RESPONSES:
                        print_server_response(response)
                    else:
                        print(f"Server response:\n{json.dumps(response, indent=2)}")

                    # Check for a specific condition to break the loop
                    if condition_to_stop_receiving(response):
                        break
                except asyncio.TimeoutError:
                    print("Timeout waiting for the server response.")
                    break
                except websockets.exceptions.ConnectionClosedError:
                    print("Connection to server was lost.")
                    break
                except Exception as e:
                    print(f"An error occurred: {e}")
                    break


asyncio.run(basic_cli_client())
|
||||
@@ -1,205 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
|
||||
import websockets
|
||||
|
||||
from memgpt.server.websocket_interface import SyncWebSocketInterface
|
||||
from memgpt.server.constants import DEFAULT_PORT
|
||||
import memgpt.server.websocket_protocol as protocol
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as memgpt_constants
|
||||
|
||||
|
||||
class WebSocketServer:
    """WebSocket front-end that hosts a single MemGPT agent per server instance."""

    def __init__(self, host="localhost", port=DEFAULT_PORT):
        # Bind address for websockets.serve
        self.host = host
        self.port = port
        # Interface that fans agent output out to registered websocket clients
        self.interface = SyncWebSocketInterface()

        # Currently loaded agent (one at a time) and its name, if any
        self.agent = None
        self.agent_name = None
|
||||
|
||||
    def run_step(self, user_message, first_message=False, no_verify=False):
        """Step the agent on a user message, chaining heartbeats / warnings /
        failure retries until the agent yields control.

        NOTE(review): `new_messages` is unpacked but never used; output reaches
        clients through self.interface instead.
        """
        while True:
            new_messages, heartbeat_request, function_failed, token_warning = self.agent.step(
                user_message, first_message=first_message, skip_verify=no_verify
            )

            # Feed synthetic system messages back in to keep the chain going
            if token_warning:
                user_message = system.get_token_limit_warning()
            elif function_failed:
                user_message = system.get_heartbeat(memgpt_constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
            elif heartbeat_request:
                user_message = system.get_heartbeat(memgpt_constants.REQ_HEARTBEAT_MESSAGE)
            else:
                # return control
                break
|
||||
|
||||
async def handle_client(self, websocket, path):
|
||||
self.interface.register_client(websocket)
|
||||
try:
|
||||
# async for message in websocket:
|
||||
while True:
|
||||
message = await websocket.recv()
|
||||
|
||||
# Assuming the message is a JSON string
|
||||
try:
|
||||
data = json.loads(message)
|
||||
except:
|
||||
print(f"[server] bad data from client:\n{data}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
|
||||
continue
|
||||
|
||||
if "type" not in data:
|
||||
print(f"[server] bad data from client (JSON but no type):\n{data}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
|
||||
|
||||
elif data["type"] == "command":
|
||||
# Create a new agent
|
||||
if data["command"] == "create_agent":
|
||||
try:
|
||||
self.agent = self.create_new_agent(data["config"])
|
||||
await websocket.send(protocol.server_command_response("OK: Agent initialized"))
|
||||
except Exception as e:
|
||||
self.agent = None
|
||||
print(f"[server] self.create_new_agent failed with:\n{e}")
|
||||
print(f"{traceback.format_exc()}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))
|
||||
|
||||
# Load an existing agent
|
||||
elif data["command"] == "load_agent":
|
||||
agent_name = data.get("name")
|
||||
if agent_name is not None:
|
||||
try:
|
||||
self.agent = self.load_agent(agent_name)
|
||||
self.agent_name = agent_name
|
||||
await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
|
||||
except Exception as e:
|
||||
print(f"[server] self.load_agent failed with:\n{e}")
|
||||
print(f"{traceback.format_exc()}")
|
||||
self.agent = None
|
||||
await websocket.send(
|
||||
protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
|
||||
)
|
||||
else:
|
||||
await websocket.send(protocol.server_command_response(f"Error: 'name' not provided"))
|
||||
|
||||
else:
|
||||
print(f"[server] unrecognized client command type: {data}")
|
||||
await websocket.send(protocol.server_error(f"unrecognized client command type: {data}"))
|
||||
|
||||
elif data["type"] == "user_message":
|
||||
user_message = data["message"]
|
||||
|
||||
if "agent_name" in data:
|
||||
agent_name = data["agent_name"]
|
||||
# If the agent requested the same one that's already loading?
|
||||
if self.agent_name is None or self.agent_name != data["agent_name"]:
|
||||
try:
|
||||
print(f"[server] loading agent {agent_name}")
|
||||
self.agent = self.load_agent(agent_name)
|
||||
self.agent_name = agent_name
|
||||
# await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
|
||||
except Exception as e:
|
||||
print(f"[server] self.load_agent failed with:\n{e}")
|
||||
print(f"{traceback.format_exc()}")
|
||||
self.agent = None
|
||||
await websocket.send(
|
||||
protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
|
||||
)
|
||||
else:
|
||||
await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request"))
|
||||
continue
|
||||
|
||||
if self.agent is None:
|
||||
await websocket.send(protocol.server_agent_response_error("No agent has been initialized"))
|
||||
else:
|
||||
await websocket.send(protocol.server_agent_response_start())
|
||||
try:
|
||||
self.run_step(user_message)
|
||||
except Exception as e:
|
||||
print(f"[server] self.run_step failed with:\n{e}")
|
||||
print(f"{traceback.format_exc()}")
|
||||
await websocket.send(protocol.server_agent_response_error(f"self.run_step failed with: {e}"))
|
||||
|
||||
await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed
|
||||
await websocket.send(protocol.server_agent_response_end())
|
||||
|
||||
# ... handle other message types as needed ...
|
||||
else:
|
||||
print(f"[server] unrecognized client package data type: {data}")
|
||||
await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}"))
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print(f"[server] connection with client was closed")
|
||||
finally:
|
||||
# TODO autosave the agent
|
||||
|
||||
self.interface.unregister_client(websocket)
|
||||
|
||||
    def create_new_agent(self, config):
        """Config is json that arrived over websocket, so we need to turn it into a config object

        Returns the freshly created Agent (not persisted to disk here).
        """
        from memgpt.config import AgentConfig
        import memgpt.presets.presets as presets
        import memgpt.utils as utils
        from memgpt.persistence_manager import InMemoryStateManager

        print("Creating new agent...")

        # Initialize the agent based on the provided configuration
        agent_config = AgentConfig(**config)

        # Use an in-state persistence manager
        persistence_manager = InMemoryStateManager()

        # Create agent via preset from config
        agent = presets.use_preset(
            agent_config.preset,
            agent_config,
            agent_config.model,
            utils.get_persona_text(agent_config.persona),
            utils.get_human_text(agent_config.human),
            self.interface,
            persistence_manager,
        )
        print("Created new agent from config")

        return agent
|
||||
|
||||
def load_agent(self, agent_name):
|
||||
"""Load an agent from a directory"""
|
||||
import memgpt.utils as utils
|
||||
from memgpt.config import AgentConfig
|
||||
from memgpt.agent import Agent
|
||||
|
||||
print(f"Loading agent {agent_name}...")
|
||||
|
||||
agent_files = utils.list_agent_config_files()
|
||||
agent_names = [AgentConfig.load(f).name for f in agent_files]
|
||||
|
||||
if agent_name not in agent_names:
|
||||
raise ValueError(f"agent '{agent_name}' does not exist")
|
||||
|
||||
agent_config = AgentConfig.load(agent_name)
|
||||
agent = Agent.load_agent(self.interface, agent_config)
|
||||
print("Created agent by loading existing config")
|
||||
|
||||
return agent
|
||||
|
||||
def initialize_server(self):
|
||||
print("Server is initializing...")
|
||||
print(f"Listening on {self.host}:{self.port}...")
|
||||
|
||||
async def start_server(self):
|
||||
self.initialize_server()
|
||||
async with websockets.serve(self.handle_client, self.host, self.port):
|
||||
await asyncio.Future() # Run forever
|
||||
|
||||
def run(self):
|
||||
return self.start_server() # Return the coroutine
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
server = WebSocketServer()
|
||||
asyncio.run(server.run())
|
||||
0
memgpt/server/ws_api/__init__.py
Normal file
0
memgpt/server/ws_api/__init__.py
Normal file
106
memgpt/server/ws_api/example_client.py
Normal file
106
memgpt/server/ws_api/example_client.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import websockets
|
||||
|
||||
import memgpt.server.ws_api.protocol as protocol
|
||||
from memgpt.server.constants import DEFAULT_PORT, CLIENT_TIMEOUT
|
||||
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
|
||||
|
||||
|
||||
# CLEAN_RESPONSES = False # print the raw server responses (JSON)
|
||||
CLEAN_RESPONSES = True # make the server responses cleaner
|
||||
|
||||
# LOAD_AGENT = None # create a brand new agent
|
||||
AGENT_NAME = "agent_26" # load an existing agent
|
||||
NEW_AGENT = False
|
||||
|
||||
RECONNECT_DELAY = 1
|
||||
RECONNECT_MAX_TRIES = 5
|
||||
|
||||
|
||||
async def send_message_and_print_replies(websocket, user_message, agent_id):
|
||||
"""Send a message over websocket protocol and wait for the reply stream to end"""
|
||||
# Send a message to the agent
|
||||
await websocket.send(protocol.client_user_message(msg=str(user_message), agent_id=agent_id))
|
||||
|
||||
# Wait for messages in a loop, since the server may send a few
|
||||
while True:
|
||||
response = await asyncio.wait_for(websocket.recv(), CLIENT_TIMEOUT)
|
||||
response = json.loads(response)
|
||||
|
||||
if CLEAN_RESPONSES:
|
||||
print_server_response(response)
|
||||
else:
|
||||
print(f"Server response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for a specific condition to break the loop
|
||||
if condition_to_stop_receiving(response):
|
||||
break
|
||||
|
||||
|
||||
async def basic_cli_client():
|
||||
"""Basic example of a MemGPT CLI client that connects to a MemGPT server.py process via WebSockets
|
||||
|
||||
Meant to illustrate how to use the server.py process, so limited in features (only supports sending user messages)
|
||||
"""
|
||||
uri = f"ws://localhost:{DEFAULT_PORT}"
|
||||
|
||||
closed_on_message = False
|
||||
retry_attempts = 0
|
||||
while True: # Outer loop for reconnection attempts
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
if NEW_AGENT:
|
||||
# Initialize new agent
|
||||
print("Sending config to server...")
|
||||
example_config = {
|
||||
"persona": "sam_pov",
|
||||
"human": "cs_phd",
|
||||
"model": "gpt-4-1106-preview", # gpt-4-turbo
|
||||
}
|
||||
await websocket.send(protocol.client_command_create(example_config))
|
||||
# Wait for the response
|
||||
response = await websocket.recv()
|
||||
response = json.loads(response)
|
||||
print(f"Server response:\n{json.dumps(response, indent=2)}")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
while True:
|
||||
if closed_on_message:
|
||||
# If we're on a retry after a disconnect, don't ask for input again
|
||||
closed_on_message = False
|
||||
else:
|
||||
user_input = input("\nEnter your message: ")
|
||||
print("\n")
|
||||
|
||||
# Send a message to the agent
|
||||
try:
|
||||
await send_message_and_print_replies(websocket=websocket, user_message=user_input, agent_id=AGENT_NAME)
|
||||
retry_attempts = 0
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
print("Connection to server was lost. Attempting to reconnect...")
|
||||
closed_on_message = True
|
||||
raise
|
||||
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
# Decide whether or not to retry the connection
|
||||
if retry_attempts < RECONNECT_MAX_TRIES:
|
||||
retry_attempts += 1
|
||||
await asyncio.sleep(RECONNECT_DELAY) # Wait for N seconds before reconnecting
|
||||
continue
|
||||
else:
|
||||
print(f"Max attempts exceeded ({retry_attempts} > {RECONNECT_MAX_TRIES})")
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("Timeout waiting for the server response.")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
continue
|
||||
|
||||
|
||||
asyncio.run(basic_cli_client())
|
||||
@@ -3,7 +3,7 @@ import threading
|
||||
|
||||
|
||||
from memgpt.interface import AgentInterface
|
||||
import memgpt.server.websocket_protocol as protocol
|
||||
import memgpt.server.ws_api.protocol as protocol
|
||||
|
||||
|
||||
class BaseWebSocketInterface(AgentInterface):
|
||||
@@ -20,6 +20,9 @@ class BaseWebSocketInterface(AgentInterface):
|
||||
"""Unregister a client connection"""
|
||||
self.clients.remove(websocket)
|
||||
|
||||
def step_yield(self):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncWebSocketInterface(BaseWebSocketInterface):
|
||||
"""WebSocket calls are async"""
|
||||
@@ -80,12 +80,12 @@ def server_agent_function_message(msg):
|
||||
# Client -> server
|
||||
|
||||
|
||||
def client_user_message(msg, agent_name=None):
|
||||
def client_user_message(msg, agent_id=None):
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "user_message",
|
||||
"message": msg,
|
||||
"agent_name": agent_name,
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -98,13 +98,3 @@ def client_command_create(config):
|
||||
"config": config,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def client_command_load(agent_name):
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "command",
|
||||
"command": "load_agent",
|
||||
"name": agent_name,
|
||||
}
|
||||
)
|
||||
107
memgpt/server/ws_api/server.py
Normal file
107
memgpt/server/ws_api/server.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
|
||||
import websockets
|
||||
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.ws_api.interface import SyncWebSocketInterface
|
||||
from memgpt.server.constants import DEFAULT_PORT
|
||||
import memgpt.server.ws_api.protocol as protocol
|
||||
import memgpt.system as system
|
||||
import memgpt.constants as memgpt_constants
|
||||
|
||||
|
||||
class WebSocketServer:
|
||||
def __init__(self, host="localhost", port=DEFAULT_PORT):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.interface = SyncWebSocketInterface()
|
||||
self.server = SyncServer(default_interface=self.interface)
|
||||
|
||||
def __del__(self):
|
||||
self.interface.close()
|
||||
|
||||
def initialize_server(self):
|
||||
print("Server is initializing...")
|
||||
print(f"Listening on {self.host}:{self.port}...")
|
||||
|
||||
async def start_server(self):
|
||||
self.initialize_server()
|
||||
# Can play with ping_interval and ping_timeout
|
||||
# See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
|
||||
# and https://github.com/cpacker/MemGPT/issues/471
|
||||
async with websockets.serve(self.handle_client, self.host, self.port):
|
||||
await asyncio.Future() # Run forever
|
||||
|
||||
def run(self):
|
||||
return self.start_server() # Return the coroutine
|
||||
|
||||
async def handle_client(self, websocket, path):
|
||||
self.interface.register_client(websocket)
|
||||
try:
|
||||
# async for message in websocket:
|
||||
while True:
|
||||
message = await websocket.recv()
|
||||
|
||||
# Assuming the message is a JSON string
|
||||
try:
|
||||
data = json.loads(message)
|
||||
except:
|
||||
print(f"[server] bad data from client:\n{data}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
|
||||
continue
|
||||
|
||||
if "type" not in data:
|
||||
print(f"[server] bad data from client (JSON but no type):\n{data}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
|
||||
|
||||
elif data["type"] == "command":
|
||||
# Create a new agent
|
||||
if data["command"] == "create_agent":
|
||||
try:
|
||||
# self.agent = self.create_new_agent(data["config"])
|
||||
self.server.create_agent(user_id="NULL", agent_config=data["config"])
|
||||
await websocket.send(protocol.server_command_response("OK: Agent initialized"))
|
||||
except Exception as e:
|
||||
self.agent = None
|
||||
print(f"[server] self.create_new_agent failed with:\n{e}")
|
||||
print(f"{traceback.format_exc()}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))
|
||||
|
||||
else:
|
||||
print(f"[server] unrecognized client command type: {data}")
|
||||
await websocket.send(protocol.server_error(f"unrecognized client command type: {data}"))
|
||||
|
||||
elif data["type"] == "user_message":
|
||||
user_message = data["message"]
|
||||
|
||||
if "agent_id" not in data or data["agent_id"] is None:
|
||||
await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request"))
|
||||
continue
|
||||
|
||||
await websocket.send(protocol.server_agent_response_start())
|
||||
try:
|
||||
# self.run_step(user_message)
|
||||
self.server.user_message(user_id="NULL", agent_id=data["agent_id"], message=user_message)
|
||||
except Exception as e:
|
||||
print(f"[server] self.server.user_message failed with:\n{e}")
|
||||
print(f"{traceback.format_exc()}")
|
||||
await websocket.send(protocol.server_agent_response_error(f"server.user_message failed with: {e}"))
|
||||
await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed
|
||||
await websocket.send(protocol.server_agent_response_end())
|
||||
|
||||
# ... handle other message types as needed ...
|
||||
else:
|
||||
print(f"[server] unrecognized client package data type: {data}")
|
||||
await websocket.send(protocol.server_error(f"unrecognized client package data type: {data}"))
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print(f"[server] connection with client was closed")
|
||||
finally:
|
||||
self.interface.unregister_client(websocket)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
server = WebSocketServer()
|
||||
asyncio.run(server.run())
|
||||
50
poetry.lock
generated
50
poetry.lock
generated
@@ -137,24 +137,24 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.1.0"
|
||||
version = "3.7.1"
|
||||
description = "High level compatibility layer for multiple asynchronous event loop implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"},
|
||||
{file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"},
|
||||
{file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"},
|
||||
{file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
|
||||
exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
|
||||
idna = ">=2.8"
|
||||
sniffio = ">=1.1"
|
||||
|
||||
[package.extras]
|
||||
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
|
||||
trio = ["trio (>=0.23)"]
|
||||
doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
|
||||
trio = ["trio (<0.22)"]
|
||||
|
||||
[[package]]
|
||||
name = "asgiref"
|
||||
@@ -737,19 +737,20 @@ test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.103.0"
|
||||
version = "0.104.1"
|
||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "fastapi-0.103.0-py3-none-any.whl", hash = "sha256:61ab72c6c281205dd0cbaccf503e829a37e0be108d965ac223779a8479243665"},
|
||||
{file = "fastapi-0.103.0.tar.gz", hash = "sha256:4166732f5ddf61c33e9fa4664f73780872511e0598d4d5434b1816dc1e6d9421"},
|
||||
{file = "fastapi-0.104.1-py3-none-any.whl", hash = "sha256:752dc31160cdbd0436bb93bad51560b57e525cbb1d4bbf6f4904ceee75548241"},
|
||||
{file = "fastapi-0.104.1.tar.gz", hash = "sha256:e5e4540a7c5e1dcfbbcf5b903c234feddcdcd881f191977a1c5dfd917487e7ae"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3.7.1,<4.0.0"
|
||||
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
|
||||
starlette = ">=0.27.0,<0.28.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
typing-extensions = ">=4.8.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||
@@ -3074,6 +3075,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.23.2"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
|
||||
{file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-box"
|
||||
version = "7.1.1"
|
||||
@@ -4774,11 +4793,12 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
|
||||
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
|
||||
|
||||
[extras]
|
||||
dev = ["black", "datasets", "pre-commit", "pytest"]
|
||||
dev = ["black", "datasets", "pre-commit", "pytest", "pytest-asyncio"]
|
||||
local = ["huggingface-hub", "torch", "transformers"]
|
||||
postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"]
|
||||
server = ["fastapi", "uvicorn", "websockets"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.9"
|
||||
content-hash = "2d68f2515a73a9b2cafb445138c667f61153ac6feb23c124032fc2c2d56baf4a"
|
||||
content-hash = "4f675213d5a79f001bfb7441c9fba23ae114079ec61a30b0c88833c5427f152e"
|
||||
|
||||
@@ -26,7 +26,7 @@ pytz = "^2023.3.post1"
|
||||
tqdm = "^4.66.1"
|
||||
black = { version = "^23.10.1", optional = true }
|
||||
pytest = { version = "^7.4.3", optional = true }
|
||||
llama-index = "0.9.13"
|
||||
llama-index = "^0.9.13"
|
||||
setuptools = "^68.2.2"
|
||||
datasets = { version = "^2.14.6", optional = true}
|
||||
prettytable = "^3.9.0"
|
||||
@@ -39,7 +39,7 @@ transformers = { version = "4.34.1", optional = true }
|
||||
pre-commit = {version = "^3.5.0", optional = true }
|
||||
pg8000 = {version = "^1.30.3", optional = true}
|
||||
torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
|
||||
websockets = "^12.0"
|
||||
websockets = {version = "^12.0", optional = true}
|
||||
docstring-parser = "^0.15"
|
||||
lancedb = "^0.3.3"
|
||||
httpx = "^0.25.2"
|
||||
@@ -49,12 +49,17 @@ tiktoken = "^0.5.1"
|
||||
python-box = "^7.1.1"
|
||||
pypdf = "^3.17.1"
|
||||
pyyaml = "^6.0.1"
|
||||
fastapi = {version = "^0.104.1", optional = true}
|
||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
||||
chromadb = "^0.4.18"
|
||||
pytest-asyncio = {version = "^0.23.2", optional = true}
|
||||
pydantic = "^2.5.2"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
|
||||
dev = ["pytest", "black", "pre-commit", "datasets"]
|
||||
dev = ["pytest", "pytest-asyncio", "black", "pre-commit", "datasets"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
43
tests/test_server.py
Normal file
43
tests/test_server.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import memgpt.utils as utils
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
|
||||
def test_server():
|
||||
user_id = "NULL"
|
||||
agent_id = "agent_26"
|
||||
|
||||
server = SyncServer()
|
||||
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
except:
|
||||
raise
|
||||
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
except:
|
||||
raise
|
||||
|
||||
try:
|
||||
print(server.run_command(user_id=user_id, agent_id=agent_id, command="/memory"))
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
except:
|
||||
raise
|
||||
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id="agent no-exist", message="Hello?")
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
except:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_server()
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock
|
||||
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.server.websocket_interface import SyncWebSocketInterface
|
||||
import memgpt.presets as presets
|
||||
from memgpt.server.ws_api.interface import SyncWebSocketInterface
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.utils as utils
|
||||
import memgpt.system as system
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
@@ -54,19 +55,32 @@ async def test_websockets():
|
||||
ws_interface.register_client(mock_websocket)
|
||||
|
||||
# Create an agent and hook it up to the WebSocket interface
|
||||
config = MemGPTConfig()
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key is None:
|
||||
ws_interface.close()
|
||||
return
|
||||
config = MemGPTConfig.load()
|
||||
if config.openai_key is None:
|
||||
config.openai_key = api_key
|
||||
config.save()
|
||||
|
||||
# Mock the persistence manager
|
||||
# create agents with defaults
|
||||
agent_config = AgentConfig(persona="sam_pov", human="basic", model="gpt-4-1106-preview")
|
||||
agent_config = AgentConfig(
|
||||
persona="sam_pov",
|
||||
human="basic",
|
||||
model="gpt-4-1106-preview",
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
)
|
||||
persistence_manager = LocalStateManager(agent_config=agent_config)
|
||||
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT_PRESET,
|
||||
config, # no agent config to provide
|
||||
"gpt-4-1106-preview",
|
||||
utils.get_persona_text("sam_pov"),
|
||||
utils.get_human_text("basic"),
|
||||
agent_config.preset,
|
||||
agent_config,
|
||||
agent_config.model,
|
||||
agent_config.persona, # note: extracting the raw text, not pulling from a file
|
||||
agent_config.human, # note: extracting raw text, not pulling from a file
|
||||
ws_interface,
|
||||
persistence_manager,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import websockets
|
||||
import pytest
|
||||
|
||||
from memgpt.server.constants import DEFAULT_PORT
|
||||
from memgpt.server.websocket_server import WebSocketServer
|
||||
from memgpt.server.ws_api.server import WebSocketServer
|
||||
from memgpt.config import AgentConfig
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ async def test_dummy():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_server():
|
||||
server = WebSocketServer()
|
||||
# host = "127.0.0.1"
|
||||
host = "localhost"
|
||||
server = WebSocketServer(host=host)
|
||||
server_task = asyncio.create_task(server.run()) # Create a task for the server
|
||||
|
||||
# the agent config we want to ask the server to instantiate with
|
||||
@@ -28,23 +30,26 @@ async def test_websocket_server():
|
||||
# )
|
||||
test_config = {}
|
||||
|
||||
uri = f"ws://localhost:{DEFAULT_PORT}"
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Initialize the server with a test config
|
||||
print("Sending config to server...")
|
||||
await websocket.send(json.dumps({"type": "initialize", "config": test_config}))
|
||||
# Wait for the response
|
||||
response = await websocket.recv()
|
||||
print(f"Response from the agent: {response}")
|
||||
uri = f"ws://{host}:{DEFAULT_PORT}"
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Initialize the server with a test config
|
||||
print("Sending config to server...")
|
||||
await websocket.send(json.dumps({"type": "initialize", "config": test_config}))
|
||||
# Wait for the response
|
||||
response = await websocket.recv()
|
||||
print(f"Response from the agent: {response}")
|
||||
|
||||
await asyncio.sleep(1) # just in case
|
||||
await asyncio.sleep(1) # just in case
|
||||
|
||||
# Send a message to the agent
|
||||
print("Sending message to server...")
|
||||
await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}))
|
||||
# Wait for the response
|
||||
# NOTE: we should be waiting for multiple responses
|
||||
response = await websocket.recv()
|
||||
print(f"Response from the agent: {response}")
|
||||
|
||||
server_task.cancel() # Cancel the server task after the test
|
||||
# Send a message to the agent
|
||||
print("Sending message to server...")
|
||||
await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}))
|
||||
# Wait for the response
|
||||
# NOTE: we should be waiting for multiple responses
|
||||
response = await websocket.recv()
|
||||
print(f"Response from the agent: {response}")
|
||||
except (OSError, ConnectionRefusedError) as e:
|
||||
print(f"Was unable to connect: {e}")
|
||||
finally:
|
||||
server_task.cancel() # Cancel the server task after the test
|
||||
|
||||
Reference in New Issue
Block a user