diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 3c2217ad..2abde109 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -1,4 +1,5 @@ import os +import uuid from typing import Dict, List, Union from memgpt.data_types import AgentState @@ -57,7 +58,16 @@ class Client(object): if config is not None: set_config_with_dict(config) - self.user_id = MemGPTConfig.load().anon_clientid if user_id is None else user_id + if user_id is None: + # the default user_id + config = MemGPTConfig.load() + self.user_id = uuid.UUID(config.anon_clientid) + elif isinstance(user_id, str): + self.user_id = uuid.UUID(user_id) + elif isinstance(user_id, uuid.UUID): + self.user_id = user_id + else: + raise TypeError(user_id) self.interface = QueuingInterface(debug=debug) self.server = SyncServer(default_interface=self.interface) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index c33e8447..99e49c31 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -33,39 +33,39 @@ class Server(object): """Abstract server class that supports multi-agent multi-user""" @abstractmethod - def list_agents(self, user_id: str) -> dict: + def list_agents(self, user_id: uuid.UUID) -> dict: """List all available agents to a user""" raise NotImplementedError @abstractmethod - def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list: + def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of in-context messages in agent message queue""" raise NotImplementedError @abstractmethod - def get_agent_memory(self, user_id: str, agent_id: str) -> dict: + def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the memory of an agent (core memory + non-core statistics)""" raise NotImplementedError @abstractmethod - def get_agent_config(self, user_id: str, agent_id: str) -> dict: + def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the config of an agent""" raise NotImplementedError @abstractmethod - def get_server_config(self, user_id: str) -> dict: + def get_server_config(self, user_id: uuid.UUID) -> dict: """Return the base config""" raise NotImplementedError @abstractmethod - def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict: + def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict: """Update the agents core memory block, return the new state""" raise NotImplementedError @abstractmethod def create_agent( self, - user_id: str, + user_id: uuid.UUID, agent_config: Union[dict, AgentState], interface: Union[AgentInterface, None], # persistence_manager: Union[PersistenceManager, None], @@ -74,17 +74,17 @@ class Server(object): raise NotImplementedError @abstractmethod - def user_message(self, user_id: str, agent_id: str, message: str) -> None: + def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: """Process a message from the user, internally calls step""" raise NotImplementedError @abstractmethod - def system_message(self, user_id: str, agent_id: str, message: str) -> None: + def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: """Process a message from the system, internally calls step""" raise NotImplementedError @abstractmethod - def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: """Run a command on the agent, e.g. /memory May return a string with a message generated by the command @@ -101,7 +101,7 @@ class LockingServer(Server): @staticmethod def agent_lock_decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(self, user_id: str, agent_id: str, *args, **kwargs): + def wrapper(self, user_id: uuid.UUID, agent_id: uuid.UUID, *args, **kwargs): # logger.info("Locking check") # Initialize the lock for the agent_id if it doesn't exist @@ -126,11 +126,11 @@ class LockingServer(Server): return wrapper @agent_lock_decorator - def user_message(self, user_id: str, agent_id: str, message: str) -> None: + def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: raise NotImplementedError @agent_lock_decorator - def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: raise NotImplementedError @@ -184,14 +184,14 @@ class SyncServer(LockingServer): except Exception as e: logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}") - def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]: + def _get_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Union[Agent, None]: """Get the agent object from the in-memory object store""" for d in self.active_agents: - if d["user_id"] == user_id and d["agent_id"] == agent_id: + if d["user_id"] == str(user_id) and d["agent_id"] == str(agent_id): return d["agent"] return None - def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None: + def _add_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, agent_obj: Agent) -> None: """Put an agent object inside the in-memory object store""" # Make sure the agent doesn't already exist if self._get_agent(user_id=user_id, agent_id=agent_id) is not None: @@ -199,42 +199,49 @@ class SyncServer(LockingServer): # Add Agent instance to the in-memory list self.active_agents.append( { - "user_id": user_id, - "agent_id": agent_id, + "user_id": str(user_id), + "agent_id": str(agent_id), "agent": agent_obj, } ) - def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent: + def _load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, interface: Union[AgentInterface, None] = None) -> Agent: """Loads a saved agent into memory (if it doesn't exist, throw an error)""" + assert isinstance(user_id, uuid.UUID), user_id + assert isinstance(agent_id, uuid.UUID), agent_id # If an interface isn't specified, use the default if interface is None: interface = self.default_interface try: + logger.info(f"Grabbing agent user_id={user_id} agent_id={agent_id} from database") agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id) if not agent_state: + logger.exception(f"agent_id {agent_id} does not exist") raise ValueError(f"agent_id {agent_id} does not exist") # Instantiate an agent object using the state retrieved + logger.info(f"Creating an agent object") memgpt_agent = Agent(agent_state=agent_state, interface=interface) # Add the agent to the in-memory store and return its reference self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent) + logger.info(f"Creating an agent object") return memgpt_agent except Exception as e: logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}") - def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent: + def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent: """Check if the agent is in-memory, then load""" memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id) if not memgpt_agent: + logger.info(f"Loading agent user_id={user_id} agent_id={agent_id}") memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id) return memgpt_agent - def _step(self, user_id: str, agent_id: str, input_message: str) -> None: + def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: str) -> None: """Send the input message through the agent""" logger.debug(f"Got input message: {input_message}") @@ -278,7 +285,7 @@ class SyncServer(LockingServer): memgpt_agent.interface.step_yield() logger.debug(f"Finished agent step") - def _command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + def _command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: """Process a CLI command""" logger.debug(f"Got command: {command}") @@ -407,7 +414,7 @@ class SyncServer(LockingServer): self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) @LockingServer.agent_lock_decorator - def user_message(self, user_id: str, agent_id: str, message: str) -> None: + def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: """Process an incoming user message and feed it through the MemGPT agent""" # Basic input sanitization @@ -426,7 +433,7 @@ class SyncServer(LockingServer): self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message) @LockingServer.agent_lock_decorator - def system_message(self, user_id: str, agent_id: str, message: str) -> None: + def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: """Process an incoming system message and feed it through the MemGPT agent""" from memgpt.utils import printd @@ -446,7 +453,7 @@ class SyncServer(LockingServer): self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message) @LockingServer.agent_lock_decorator - def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]: + def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: """Run a command on the agent""" # If the input begins with a command prefix, attempt to process it as a command if command.startswith("/"): @@ -456,7 +463,7 @@ class SyncServer(LockingServer): def create_agent( self, - user_id: str, + user_id: uuid.UUID, agent_config: Union[dict, AgentState], interface: Union[AgentInterface, None] = None, # persistence_manager: Union[PersistenceManager, None] = None, @@ -505,8 +512,8 @@ class SyncServer(LockingServer): def delete_agent( self, - user_id: str, - agent_id: str, + user_id: uuid.UUID, + agent_id: uuid.UUID, ): # Make sure the user owns the agent # TODO use real user_id @@ -515,7 +522,7 @@ class SyncServer(LockingServer): if agent is not None: self.ms.delete_agent(agent_id=agent_id) - def list_agents(self, user_id: str) -> dict: + def list_agents(self, user_id: uuid.UUID) -> dict: """List all available agents to a user""" # TODO actually use the user_id that was passed into the server user_id = uuid.UUID(self.config.anon_clientid) @@ -535,7 +542,7 @@ class SyncServer(LockingServer): ], } - def get_agent_memory(self, user_id: str, agent_id: str) -> dict: + def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the memory of an agent (core memory + non-core statistics)""" # Get the agent object (loaded in memory) # TODO: use real user_id @@ -556,7 +563,7 @@ class SyncServer(LockingServer): return memory_obj - def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list: + def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of in-context messages in agent message queue""" # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) @@ -579,7 +586,7 @@ class SyncServer(LockingServer): return paginated_messages - def get_agent_config(self, user_id: str, agent_id: str) -> dict: + def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the config of an agent""" # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) @@ -602,7 +609,7 @@ class SyncServer(LockingServer): clean_base_config = clean_keys(base_config) return clean_base_config - def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict: + def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict: """Update the agents core memory block, return the new state""" # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) diff --git a/tests/test_client.py b/tests/test_client.py index 0422bfba..9109338b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,10 @@ +import uuid +import os + from memgpt import MemGPT +from memgpt.config import MemGPTConfig from memgpt import constants from memgpt.data_types import LLMConfig, EmbeddingConfig -import os from .utils import wipe_config @@ -52,13 +55,16 @@ def test_save_load(): assert test_agent_state is not None, "Run create_agent test first" assert test_agent_state_post_message is not None, "Run test_user_message test first" + config = MemGPTConfig.load() + user_id = uuid.UUID(config.anon_clientid) + # Create a new client (not thread safe), and load the same agent # The agent state inside should correspond to the initial state pre-message if os.getenv("OPENAI_API_KEY"): client2 = MemGPT(quickstart="openai") else: client2 = MemGPT(quickstart="memgpt_hosted") - client2_agent_obj = client2.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id) + client2_agent_obj = client2.server._get_or_load_agent(user_id=user_id, agent_id=test_agent_state.id) client2_agent_state = client2_agent_obj.to_agent_state() # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}" @@ -81,7 +87,7 @@ def test_save_load(): client3 = MemGPT(quickstart="openai") else: client3 = MemGPT(quickstart="memgpt_hosted") - client3_agent_obj = client3.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id) + client3_agent_obj = client3.server._get_or_load_agent(user_id=user_id, agent_id=test_agent_state.id) client3_agent_state = client3_agent_obj.to_agent_state() check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state)) diff --git a/tests/test_server.py b/tests/test_server.py index e97f21b3..c7fcefd4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,9 @@ +import uuid + import memgpt.utils as utils utils.DEBUG = True +from memgpt.config import MemGPTConfig from memgpt.server.server import SyncServer from .utils import wipe_config, wipe_memgpt_home @@ -8,12 +11,13 @@ from .utils import wipe_config, wipe_memgpt_home def test_server(): wipe_memgpt_home() - user_id = "NULL" - + config = MemGPTConfig.load() + user_id = uuid.UUID(config.anon_clientid) server = SyncServer() try: - server.user_message(user_id=user_id, agent_id="agent no exist", message="Hello?") + fake_agent_id = uuid.uuid4() + server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?") raise Exception("user_message call should have failed") except (KeyError, ValueError) as e: # Error is expected