fix: fixed the type hints in server to use uuid, patched tests that used strings as dummy users

This commit is contained in:
cpacker
2024-01-15 14:01:06 -08:00
parent e7e27fed17
commit 41f9640364
4 changed files with 67 additions and 40 deletions

View File

@@ -1,4 +1,5 @@
import os
import uuid
from typing import Dict, List, Union
from memgpt.data_types import AgentState
@@ -57,7 +58,16 @@ class Client(object):
if config is not None:
set_config_with_dict(config)
self.user_id = MemGPTConfig.load().anon_clientid if user_id is None else user_id
if user_id is None:
# the default user_id
config = MemGPTConfig.load()
self.user_id = uuid.UUID(config.anon_clientid)
elif isinstance(user_id, str):
self.user_id = uuid.UUID(user_id)
elif isinstance(user_id, uuid.UUID):
self.user_id = user_id
else:
raise TypeError(user_id)
self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface=self.interface)

View File

@@ -33,39 +33,39 @@ class Server(object):
"""Abstract server class that supports multi-agent multi-user"""
@abstractmethod
def list_agents(self, user_id: str) -> dict:
def list_agents(self, user_id: uuid.UUID) -> dict:
"""List all available agents to a user"""
raise NotImplementedError
@abstractmethod
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of in-context messages in agent message queue"""
raise NotImplementedError
@abstractmethod
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
raise NotImplementedError
@abstractmethod
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the config of an agent"""
raise NotImplementedError
@abstractmethod
def get_server_config(self, user_id: str) -> dict:
def get_server_config(self, user_id: uuid.UUID) -> dict:
"""Return the base config"""
raise NotImplementedError
@abstractmethod
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict:
"""Update the agents core memory block, return the new state"""
raise NotImplementedError
@abstractmethod
def create_agent(
self,
user_id: str,
user_id: uuid.UUID,
agent_config: Union[dict, AgentState],
interface: Union[AgentInterface, None],
# persistence_manager: Union[PersistenceManager, None],
@@ -74,17 +74,17 @@ class Server(object):
raise NotImplementedError
@abstractmethod
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process a message from the user, internally calls step"""
raise NotImplementedError
@abstractmethod
def system_message(self, user_id: str, agent_id: str, message: str) -> None:
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process a message from the system, internally calls step"""
raise NotImplementedError
@abstractmethod
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Run a command on the agent, e.g. /memory
May return a string with a message generated by the command
@@ -101,7 +101,7 @@ class LockingServer(Server):
@staticmethod
def agent_lock_decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(self, user_id: str, agent_id: str, *args, **kwargs):
def wrapper(self, user_id: uuid.UUID, agent_id: uuid.UUID, *args, **kwargs):
# logger.info("Locking check")
# Initialize the lock for the agent_id if it doesn't exist
@@ -126,11 +126,11 @@ class LockingServer(Server):
return wrapper
@agent_lock_decorator
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
raise NotImplementedError
@agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
raise NotImplementedError
@@ -184,14 +184,14 @@ class SyncServer(LockingServer):
except Exception as e:
logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}")
def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
def _get_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Union[Agent, None]:
"""Get the agent object from the in-memory object store"""
for d in self.active_agents:
if d["user_id"] == user_id and d["agent_id"] == agent_id:
if d["user_id"] == str(user_id) and d["agent_id"] == str(agent_id):
return d["agent"]
return None
def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None:
def _add_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, agent_obj: Agent) -> None:
"""Put an agent object inside the in-memory object store"""
# Make sure the agent doesn't already exist
if self._get_agent(user_id=user_id, agent_id=agent_id) is not None:
@@ -199,42 +199,49 @@ class SyncServer(LockingServer):
# Add Agent instance to the in-memory list
self.active_agents.append(
{
"user_id": user_id,
"agent_id": agent_id,
"user_id": str(user_id),
"agent_id": str(agent_id),
"agent": agent_obj,
}
)
def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent:
def _load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, interface: Union[AgentInterface, None] = None) -> Agent:
"""Loads a saved agent into memory (if it doesn't exist, throw an error)"""
assert isinstance(user_id, uuid.UUID), user_id
assert isinstance(agent_id, uuid.UUID), agent_id
# If an interface isn't specified, use the default
if interface is None:
interface = self.default_interface
try:
logger.info(f"Grabbing agent user_id={user_id} agent_id={agent_id} from database")
agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
if not agent_state:
logger.exception(f"agent_id {agent_id} does not exist")
raise ValueError(f"agent_id {agent_id} does not exist")
# Instantiate an agent object using the state retrieved
logger.info(f"Creating an agent object")
memgpt_agent = Agent(agent_state=agent_state, interface=interface)
# Add the agent to the in-memory store and return its reference
self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent)
logger.info(f"Creating an agent object")
return memgpt_agent
except Exception as e:
logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}")
def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent:
def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent:
"""Check if the agent is in-memory, then load"""
memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
if not memgpt_agent:
logger.info(f"Loading agent user_id={user_id} agent_id={agent_id}")
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
return memgpt_agent
def _step(self, user_id: str, agent_id: str, input_message: str) -> None:
def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: str) -> None:
"""Send the input message through the agent"""
logger.debug(f"Got input message: {input_message}")
@@ -278,7 +285,7 @@ class SyncServer(LockingServer):
memgpt_agent.interface.step_yield()
logger.debug(f"Finished agent step")
def _command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
def _command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Process a CLI command"""
logger.debug(f"Got command: {command}")
@@ -407,7 +414,7 @@ class SyncServer(LockingServer):
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
@LockingServer.agent_lock_decorator
def user_message(self, user_id: str, agent_id: str, message: str) -> None:
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process an incoming user message and feed it through the MemGPT agent"""
# Basic input sanitization
@@ -426,7 +433,7 @@ class SyncServer(LockingServer):
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
@LockingServer.agent_lock_decorator
def system_message(self, user_id: str, agent_id: str, message: str) -> None:
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process an incoming system message and feed it through the MemGPT agent"""
from memgpt.utils import printd
@@ -446,7 +453,7 @@ class SyncServer(LockingServer):
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message)
@LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Run a command on the agent"""
# If the input begins with a command prefix, attempt to process it as a command
if command.startswith("/"):
@@ -456,7 +463,7 @@ class SyncServer(LockingServer):
def create_agent(
self,
user_id: str,
user_id: uuid.UUID,
agent_config: Union[dict, AgentState],
interface: Union[AgentInterface, None] = None,
# persistence_manager: Union[PersistenceManager, None] = None,
@@ -505,8 +512,8 @@ class SyncServer(LockingServer):
def delete_agent(
self,
user_id: str,
agent_id: str,
user_id: uuid.UUID,
agent_id: uuid.UUID,
):
# Make sure the user owns the agent
# TODO use real user_id
@@ -515,7 +522,7 @@ class SyncServer(LockingServer):
if agent is not None:
self.ms.delete_agent(agent_id=agent_id)
def list_agents(self, user_id: str) -> dict:
def list_agents(self, user_id: uuid.UUID) -> dict:
"""List all available agents to a user"""
# TODO actually use the user_id that was passed into the server
user_id = uuid.UUID(self.config.anon_clientid)
@@ -535,7 +542,7 @@ class SyncServer(LockingServer):
],
}
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
# Get the agent object (loaded in memory)
# TODO: use real user_id
@@ -556,7 +563,7 @@ class SyncServer(LockingServer):
return memory_obj
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of in-context messages in agent message queue"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
@@ -579,7 +586,7 @@ class SyncServer(LockingServer):
return paginated_messages
def get_agent_config(self, user_id: str, agent_id: str) -> dict:
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the config of an agent"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
@@ -602,7 +609,7 @@ class SyncServer(LockingServer):
clean_base_config = clean_keys(base_config)
return clean_base_config
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict:
"""Update the agents core memory block, return the new state"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

View File

@@ -1,7 +1,10 @@
import uuid
import os
from memgpt import MemGPT
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig
import os
from .utils import wipe_config
@@ -52,13 +55,16 @@ def test_save_load():
assert test_agent_state is not None, "Run create_agent test first"
assert test_agent_state_post_message is not None, "Run test_user_message test first"
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
# Create a new client (not thread safe), and load the same agent
# The agent state inside should correspond to the initial state pre-message
if os.getenv("OPENAI_API_KEY"):
client2 = MemGPT(quickstart="openai")
else:
client2 = MemGPT(quickstart="memgpt_hosted")
client2_agent_obj = client2.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id)
client2_agent_obj = client2.server._get_or_load_agent(user_id=user_id, agent_id=test_agent_state.id)
client2_agent_state = client2_agent_obj.to_agent_state()
# assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}"
@@ -81,7 +87,7 @@ def test_save_load():
client3 = MemGPT(quickstart="openai")
else:
client3 = MemGPT(quickstart="memgpt_hosted")
client3_agent_obj = client3.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id)
client3_agent_obj = client3.server._get_or_load_agent(user_id=user_id, agent_id=test_agent_state.id)
client3_agent_state = client3_agent_obj.to_agent_state()
check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state))

View File

@@ -1,6 +1,9 @@
import uuid
import memgpt.utils as utils
utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.server.server import SyncServer
from .utils import wipe_config, wipe_memgpt_home
@@ -8,12 +11,13 @@ from .utils import wipe_config, wipe_memgpt_home
def test_server():
wipe_memgpt_home()
user_id = "NULL"
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
server = SyncServer()
try:
server.user_message(user_id=user_id, agent_id="agent no exist", message="Hello?")
fake_agent_id = uuid.uuid4()
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
# Error is expected