import json
import uuid
import warnings
from abc import abstractmethod
from datetime import datetime
from functools import wraps
from threading import Lock
from typing import Callable, List, Optional, Tuple, Union

from fastapi import HTTPException

import memgpt.constants as constants
import memgpt.presets.presets as presets
import memgpt.server.utils as server_utils
import memgpt.system as system
from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_config import get_model_options
from memgpt.config import MemGPTConfig
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.credentials import MemGPTCredentials
from memgpt.data_sources.connectors import DataConnector, load_data
from memgpt.data_types import (
    AgentState,
    EmbeddingConfig,
    LLMConfig,
    Message,
    Preset,
    Source,
    Token,
    User,
)
from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema

# TODO use custom interface
from memgpt.interface import AgentInterface  # abstract
from memgpt.interface import CLIInterface  # for printing to terminal
from memgpt.log import get_logger
from memgpt.memory import BaseMemory, get_memory_functions
from memgpt.metadata import MetadataStore
from memgpt.models.chat_completion_response import UsageStatistics
from memgpt.models.pydantic_models import (
    DocumentModel,
    HumanModel,
    MemGPTUsageStatistics,
    PassageModel,
    PersonaModel,
    PresetModel,
    SourceModel,
    ToolModel,
)

# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.prompts import gpt_system
from memgpt.utils import create_random_username

logger = get_logger(__name__)


class Server(object):
    """Abstract server class that supports multi-agent multi-user"""

    @abstractmethod
    def list_agents(self, user_id: uuid.UUID) -> dict:
        """List all available agents to a user"""
        raise NotImplementedError

    @abstractmethod
    def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
        """Paginated query of in-context messages in agent message queue"""
        raise NotImplementedError

    @abstractmethod
    def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
        """Return the memory of an agent (core memory + non-core statistics)"""
        raise NotImplementedError

    @abstractmethod
    def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
        """Return the config of an agent"""
        raise NotImplementedError

    @abstractmethod
    def get_server_config(self, user_id: uuid.UUID) -> dict:
        """Return the base config"""
        raise NotImplementedError

    @abstractmethod
    def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict:
        """Update the agents core memory block, return the new state"""
        raise NotImplementedError

    @abstractmethod
    def create_agent(
        self,
        user_id: uuid.UUID,
        agent_config: Union[dict, AgentState],
        interface: Union[AgentInterface, None],
        # persistence_manager: Union[PersistenceManager, None],
    ) -> str:
        """Create a new agent using a config"""
        raise NotImplementedError

    @abstractmethod
    def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
        """Process a message from the user, internally calls step"""
        raise NotImplementedError

    @abstractmethod
    def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
        """Process a message from the system, internally calls step"""
        raise NotImplementedError

    @abstractmethod
    def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
        """Run a command on the agent, e.g. /memory

        May return a string with a message generated by the command
        """
        raise NotImplementedError
class LockingServer(Server):
    """Basic support for concurrency protections (all requests that modify an agent lock the agent until the operation is complete)"""

    # Locks for each agent
    _agent_locks = {}

    @staticmethod
    def agent_lock_decorator(func: Callable) -> Callable:
        """Serialize access to an agent: a concurrent request on a busy agent_id fails fast with HTTP 423."""

        @wraps(func)
        def wrapper(self, user_id: uuid.UUID, agent_id: uuid.UUID, *args, **kwargs):
            # Initialize the lock for the agent_id if it doesn't exist.
            # setdefault is a single atomic dict operation, so two threads cannot
            # both install a fresh Lock (the old check-then-create pattern could race).
            lock = self._agent_locks.setdefault(agent_id, Lock())

            # Check if the agent is currently locked (non-blocking so busy agents reject immediately)
            if not lock.acquire(blocking=False):
                raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.")

            try:
                # Execute the function while holding the lock
                return func(self, user_id, agent_id, *args, **kwargs)
            finally:
                # Release the lock
                lock.release()

        return wrapper

    # @agent_lock_decorator
    def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
        raise NotImplementedError

    # @agent_lock_decorator
    def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
        raise NotImplementedError


class SyncServer(LockingServer):
    """Simple single-threaded / blocking server process"""

    def __init__(
        self,
        chaining: bool = True,
        # BUGFIX: was annotated `bool`, but this is a step-count limit (None = unlimited)
        max_chaining_steps: Optional[int] = None,
        default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
        # default_interface: AgentInterface = CLIInterface(),
        # default_persistence_manager_cls: PersistenceManager = LocalStateManager,
        # auth_mode: str = "none",  # "none", "jwt", "external"
    ):
        """Server process holds in-memory agents that are being run"""
        # Server supports several auth modes:
        #   "none": no authentication, trust the incoming requests to have access to the user_id being modified
        #   "jwt_local": clients send bearer JWT tokens which decode to user_ids; tokens are generated
        #       by the server process (using pyJWT) and stored in a database table
        #   "jwt_external": clients still send bearer JWT tokens, but token generation and validation is
        #       handled by an external service (ie the server calls 'external.decode(token)' to get the user_id)

        # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
        self.active_agents = []

        # chaining = whether or not to run again if request_heartbeat=true
        self.chaining = chaining

        # if chaining == true, what's the max number of times we'll chain before yielding?
        # None = no limit, can go on forever
        self.max_chaining_steps = max_chaining_steps

        # The default interface that will get assigned to agents ON LOAD
        self.default_interface_factory = default_interface_factory

        # Initialize the connection to the DB
        self.config = MemGPTConfig.load()
        logger.debug(f"loading configuration from '{self.config.config_path}'")
        assert self.config.persona is not None, "Persona must be set in the config"
        assert self.config.human is not None, "Human must be set in the config"

        # TODO figure out how to handle credentials for the server
        self.credentials = MemGPTCredentials.load()

        # Generate default LLM/Embedding configs for the server
        # TODO: we may also want to do the same thing with default persona/human/etc.
        self.server_llm_config = LLMConfig(
            model=self.config.default_llm_config.model,
            model_endpoint_type=self.config.default_llm_config.model_endpoint_type,
            model_endpoint=self.config.default_llm_config.model_endpoint,
            model_wrapper=self.config.default_llm_config.model_wrapper,
            context_window=self.config.default_llm_config.context_window,
        )
        self.server_embedding_config = EmbeddingConfig(
            embedding_endpoint_type=self.config.default_embedding_config.embedding_endpoint_type,
            embedding_endpoint=self.config.default_embedding_config.embedding_endpoint,
            embedding_dim=self.config.default_embedding_config.embedding_dim,
            embedding_model=self.config.default_embedding_config.embedding_model,
            embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
        )
        assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)

        # Initialize the metadata store
        self.ms = MetadataStore(self.config)

        # pre-fill database (users, presets, humans, personas)
        # TODO: figure out how to handle default users (server is technically multi-user)
        user_id = uuid.UUID(self.config.anon_clientid)
        user = User(id=user_id)
        if self.ms.get_user(user_id):
            # update user
            self.ms.update_user(user)
        else:
            self.ms.create_user(user)

        # add global default tools (for admin)
        presets.add_default_tools(user_id, self.ms)
        presets.add_default_humans_and_personas(user_id, self.ms)

    def save_agents(self):
        """Saves all the agents that are in the in-memory object store"""
        for agent_d in self.active_agents:
            try:
                save_agent(agent_d["agent"], self.ms)
                logger.info(f"Saved agent {agent_d['agent_id']}")
            except Exception as e:
                logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}")
{agent_d['agent_id']}:\n{e}") def _get_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Union[Agent, None]: """Get the agent object from the in-memory object store""" for d in self.active_agents: if d["user_id"] == str(user_id) and d["agent_id"] == str(agent_id): return d["agent"] return None def _add_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, agent_obj: Agent) -> None: """Put an agent object inside the in-memory object store""" # Make sure the agent doesn't already exist if self._get_agent(user_id=user_id, agent_id=agent_id) is not None: # Can be triggered on concucrent request, so don't throw a full error # raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded") logger.exception(f"Agent (user={user_id}, agent={agent_id}) is already loaded") return # Add Agent instance to the in-memory list self.active_agents.append( { "user_id": str(user_id), "agent_id": str(agent_id), "agent": agent_obj, } ) def _load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, interface: Union[AgentInterface, None] = None) -> Agent: """Loads a saved agent into memory (if it doesn't exist, throw an error)""" assert isinstance(user_id, uuid.UUID), user_id assert isinstance(agent_id, uuid.UUID), agent_id # If an interface isn't specified, use the default if interface is None: interface = self.default_interface_factory() try: logger.info(f"Grabbing agent user_id={user_id} agent_id={agent_id} from database") agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id) if not agent_state: logger.exception(f"agent_id {agent_id} does not exist") raise ValueError(f"agent_id {agent_id} does not exist") # print(f"server._load_agent :: load got agent state {agent_id}, messages = {agent_state.state['messages']}") # Instantiate an agent object using the state retrieved logger.info(f"Creating an agent object") tool_objs = [] for name in agent_state.tools: tool_obj = self.ms.get_tool(name, user_id) if not tool_obj: logger.exception(f"Tool {name} does not 
exist for user {user_id}") raise ValueError(f"Tool {name} does not exist for user {user_id}") tool_objs.append(tool_obj) memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) # Add the agent to the in-memory store and return its reference logger.info(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}") self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent) return memgpt_agent except Exception as e: logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}") raise def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent: """Check if the agent is in-memory, then load""" logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}") # TODO: consider disabling loading cached agents due to potential concurrency issues memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id) if not memgpt_agent: logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id) return memgpt_agent def _step( self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: Union[str, Message], timestamp: Optional[datetime] ) -> MemGPTUsageStatistics: """Send the input message through the agent""" logger.debug(f"Got input message: {input_message}") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) if memgpt_agent is None: raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") # Determine whether or not to token stream based on the capability of the interface token_streaming = memgpt_agent.interface.streaming_mode if hasattr(memgpt_agent.interface, "streaming_mode") else False logger.debug(f"Starting agent step") no_verify = True next_input_message = input_message counter = 0 total_usage = UsageStatistics() step_count = 0 while True: new_messages, heartbeat_request, function_failed, 
token_warning, usage = memgpt_agent.step( next_input_message, first_message=False, skip_verify=no_verify, return_dicts=False, stream=token_streaming, timestamp=timestamp, ) step_count += 1 total_usage += usage counter += 1 memgpt_agent.interface.step_complete() # Chain stops if not self.chaining: logger.debug("No chaining, stopping after one step") break elif self.max_chaining_steps is not None and counter > self.max_chaining_steps: logger.debug(f"Hit max chaining steps, stopping after {counter} steps") break # Chain handlers elif token_warning: next_input_message = system.get_token_limit_warning() continue # always chain elif function_failed: next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) continue # always chain elif heartbeat_request: next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) continue # always chain # MemGPT no-op / yield else: break memgpt_agent.interface.step_yield() logger.debug(f"Finished agent step") # save updated state save_agent(memgpt_agent, self.ms) return MemGPTUsageStatistics(**total_usage.dict(), step_count=step_count) def _command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: """Process a CLI command""" logger.debug(f"Got command: {command}") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) if command.lower() == "exit": # exit not supported on server.py raise ValueError(command) elif command.lower() == "save" or command.lower() == "savechat": save_agent(memgpt_agent, self.ms) elif command.lower() == "attach": # Different from CLI, we extract the data source name from the command command = command.strip().split() try: data_source = int(command[1]) except: raise ValueError(command) # attach data to agent from source source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) memgpt_agent.attach_source(data_source, source_connector, 
self.ms) elif command.lower() == "dump" or command.lower().startswith("dump "): # Check if there's an additional argument that's an integer command = command.strip().split() amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 if amount == 0: memgpt_agent.interface.print_messages(memgpt_agent.messages, dump=True) else: memgpt_agent.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) elif command.lower() == "dumpraw": memgpt_agent.interface.print_messages_raw(memgpt_agent.messages) elif command.lower() == "memory": ret_str = ( f"\nDumping memory contents:\n" + f"\n{str(memgpt_agent.memory)}" + f"\n{str(memgpt_agent.persistence_manager.archival_memory)}" + f"\n{str(memgpt_agent.persistence_manager.recall_memory)}" ) return ret_str elif command.lower() == "pop" or command.lower().startswith("pop "): # Check if there's an additional argument that's an integer command = command.strip().split() pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3 n_messages = len(memgpt_agent.messages) MIN_MESSAGES = 2 if n_messages <= MIN_MESSAGES: logger.info(f"Agent only has {n_messages} messages in stack, none left to pop") elif n_messages - pop_amount < MIN_MESSAGES: logger.info(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}") else: logger.info(f"Popping last {pop_amount} messages from stack") for _ in range(min(pop_amount, len(memgpt_agent.messages))): memgpt_agent.messages.pop() elif command.lower() == "retry": # TODO this needs to also modify the persistence manager logger.info(f"Retrying for another answer") while len(memgpt_agent.messages) > 0: if memgpt_agent.messages[-1].get("role") == "user": # we want to pop up to the last user message and send it again memgpt_agent.messages[-1].get("content") memgpt_agent.messages.pop() break memgpt_agent.messages.pop() elif command.lower() == "rethink" or 
command.lower().startswith("rethink "): # TODO this needs to also modify the persistence manager if len(command) < len("rethink "): logger.warning("Missing text after the command") else: for x in range(len(memgpt_agent.messages) - 1, 0, -1): if memgpt_agent.messages[x].get("role") == "assistant": text = command[len("rethink ") :].strip() memgpt_agent.messages[x].update({"content": text}) break elif command.lower() == "rewrite" or command.lower().startswith("rewrite "): # TODO this needs to also modify the persistence manager if len(command) < len("rewrite "): logger.warning("Missing text after the command") else: for x in range(len(memgpt_agent.messages) - 1, 0, -1): if memgpt_agent.messages[x].get("role") == "assistant": text = command[len("rewrite ") :].strip() args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"), strict=JSON_LOADS_STRICT) args["message"] = text memgpt_agent.messages[x].get("function_call").update( {"arguments": json.dumps(args, ensure_ascii=JSON_ENSURE_ASCII)} ) break # No skip options elif command.lower() == "wipe": # exit not supported on server.py raise ValueError(command) elif command.lower() == "heartbeat": input_message = system.get_heartbeat() self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) elif command.lower() == "memorywarning": input_message = system.get_token_limit_warning() self._step(user_id=user_id, agent_id=agent_id, input_message=input_message) # @LockingServer.agent_lock_decorator def user_message( self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message], timestamp: Optional[datetime] = None, ) -> MemGPTUsageStatistics: """Process an incoming user message and feed it through the MemGPT agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Basic input sanitization if 
isinstance(message, str): if len(message) == 0: raise ValueError(f"Invalid input: '{message}'") # If the input begins with a command prefix, reject elif message.startswith("/"): raise ValueError(f"Invalid input: '{message}'") packaged_user_message = system.package_user_message( user_message=message, time=timestamp.isoformat() if timestamp else None, ) # NOTE: eventually deprecate and only allow passing Message types # Convert to a Message object message = Message( user_id=user_id, agent_id=agent_id, role="user", text=packaged_user_message, created_at=timestamp, # name=None, # TODO handle name via API ) # TODO: I don't think this does anything because all we care about is packaged_user_message which only exists if message is str if isinstance(message, Message): # Can't have a null text field if len(message.text) == 0 or message.text is None: raise ValueError(f"Invalid input: '{message.text}'") # If the input begins with a command prefix, reject elif message.text.startswith("/"): raise ValueError(f"Invalid input: '{message.text}'") if timestamp: # Override the timestamp with what the caller provided message.created_at = timestamp else: raise TypeError(f"Invalid input: '{message}' - type {type(message)}") # Run the agent state forward usage = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message, timestamp=timestamp) return usage # @LockingServer.agent_lock_decorator def system_message( self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message], timestamp: Optional[datetime] = None, ) -> MemGPTUsageStatistics: """Process an incoming system message and feed it through the MemGPT agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Basic input sanitization if isinstance(message, str): if len(message) == 0: raise ValueError(f"Invalid 
input: '{message}'") # If the input begins with a command prefix, reject elif message.startswith("/"): raise ValueError(f"Invalid input: '{message}'") packaged_system_message = system.package_system_message(system_message=message) # NOTE: eventually deprecate and only allow passing Message types # Convert to a Message object message = Message( user_id=user_id, agent_id=agent_id, role="user", text=packaged_system_message, # name=None, # TODO handle name via API ) if isinstance(message, Message): # Can't have a null text field if len(message.text) == 0 or message.text is None: raise ValueError(f"Invalid input: '{message.text}'") # If the input begins with a command prefix, reject elif message.text.startswith("/"): raise ValueError(f"Invalid input: '{message.text}'") else: raise TypeError(f"Invalid input: '{message}' - type {type(message)}") if timestamp: # Override the timestamp with what the caller provided message.created_at = timestamp # Run the agent state forward return self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message, timestamp=None) # @LockingServer.agent_lock_decorator def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[MemGPTUsageStatistics, None]: """Run a command on the agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # If the input begins with a command prefix, attempt to process it as a command if command.startswith("/"): if len(command) > 1: command = command[1:] # strip the prefix return self._command(user_id=user_id, agent_id=agent_id, command=command) def create_user( self, user_config: Optional[Union[dict, User]] = {}, exists_ok: bool = False, ): """Create a new user using a config""" if not isinstance(user_config, dict): raise ValueError(f"user_config must be provided as a dictionary") if 
"id" in user_config: existing_user = self.ms.get_user(user_id=user_config["id"]) if existing_user: if exists_ok: presets.add_default_humans_and_personas(existing_user.id, self.ms) return existing_user else: raise ValueError(f"User with ID {existing_user.id} already exists") user = User( id=user_config["id"] if "id" in user_config else None, ) self.ms.create_user(user) logger.info(f"Created new user from config: {user}") # add default for the user presets.add_default_humans_and_personas(user.id, self.ms) presets.add_default_tools(None, self.ms) return user def create_agent( self, user_id: uuid.UUID, tools: List[str], # list of tool names (handles) to include memory: BaseMemory, system: Optional[str] = None, metadata: Optional[dict] = {}, # includes human/persona names name: Optional[str] = None, # model config llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: """Create a new agent using a config""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if interface is None: interface = self.default_interface_factory() # system prompt (get default if None) if system is None: system = gpt_system.get_system_text(self.config.preset) # create agent name if name is None: name = create_random_username() logger.debug(f"Attempting to find user: {user_id}") user = self.ms.get_user(user_id=user_id) if not user: raise ValueError(f"cannot find user with associated client id: {user_id}") try: # model configuration llm_config = llm_config if llm_config else self.server_llm_config embedding_config = embedding_config if embedding_config else self.server_embedding_config # get tools + make sure they exist tool_objs = [] for tool_name in tools: tool_obj = self.ms.get_tool(tool_name, user_id=user_id) assert tool_obj, f"Tool {tool_name} does not exist" tool_objs.append(tool_obj) # make sure memory tools are added # TODO: 
remove this - eventually memory tools need to be added when the memory is created # this is duplicated with logic on the client-side memory_functions = get_memory_functions(memory) for func_name, func in memory_functions.items(): if func_name in tools: # tool already added continue source_code = parse_source_code(func) json_schema = generate_schema(func, func_name) source_type = "python" tags = ["memory", "memgpt-base"] tool = self.create_tool( user_id=user_id, json_schema=json_schema, source_code=source_code, source_type=source_type, tags=tags, exists_ok=True ) tool_objs.append(tool) tools.append(tool.name) # TODO: add metadata agent_state = AgentState( name=name, user_id=user_id, tools=tools, # name=id for tools llm_config=llm_config, embedding_config=embedding_config, system=system, state={"system": system, "messages": None, "memory": memory.to_dict()}, _metadata=metadata, ) agent = Agent( interface=interface, agent_state=agent_state, tools=tool_objs, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False, ) # FIXME: this is a hacky way to get the system prompts injected into agent into the DB # self.ms.update_agent(agent.agent_state) except Exception as e: logger.exception(e) try: self.ms.delete_agent(agent_id=agent.agent_state.id) except Exception as delete_e: logger.exception(f"Failed to delete_agent:\n{delete_e}") raise e # save agent save_agent(agent, self.ms) logger.info(f"Created new agent from config: {agent}") # return AgentState return agent.agent_state def delete_agent( self, user_id: uuid.UUID, agent_id: uuid.UUID, ): # TODO: delete agent data if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # TODO: Make sure the user owns the agent agent = 
self.ms.get_agent(agent_id=agent_id, user_id=user_id) if agent is not None: self.ms.delete_agent(agent_id=agent_id) def delete_preset(self, user_id: uuid.UUID, preset_id: uuid.UUID) -> Preset: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") # first get the preset by name preset = self.get_preset(preset_id=preset_id, user_id=user_id) if preset is None: raise ValueError(f"Could not find preset_id {preset_id}") # then delete via name # TODO allow delete-by-id, eg via server.delete_preset function self.ms.delete_preset(name=preset.name, user_id=user_id) return preset def initialize_default_presets(self, user_id: uuid.UUID): """Add default preset options into the metadata store""" presets.add_default_presets(user_id, self.ms) def create_preset(self, preset: Preset): """Create a new preset using a config""" if preset.user_id is not None and self.ms.get_user(user_id=preset.user_id) is None: raise ValueError(f"User user_id={preset.user_id} does not exist") self.ms.create_preset(preset) return preset def get_preset( self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None ) -> Preset: """Get the preset""" return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id) def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]: # TODO update once we strip Preset in favor of PresetModel presets = self.ms.list_presets(user_id=user_id) presets = [PresetModel(**vars(p)) for p in presets] return presets def _agent_state_to_config(self, agent_state: AgentState) -> dict: """Convert AgentState to a dict for a JSON response""" assert agent_state is not None agent_config = { "id": agent_state.id, "name": agent_state.name, "human": agent_state._metadata.get("human", None), "persona": agent_state._metadata.get("persona", None), "created_at": agent_state.created_at.isoformat(), } return agent_config def list_agents( self, user_id: uuid.UUID, ) 
-> List[AgentState]: """List all available agents to a user""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") agents_states = self.ms.list_agents(user_id=user_id) return agents_states # TODO make return type pydantic def list_agents_legacy( self, user_id: uuid.UUID, ) -> dict: """List all available agents to a user""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") agents_states = self.ms.list_agents(user_id=user_id) agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states] # TODO add a get_message_obj_from_message_id(...) function # this would allow grabbing Message.created_by without having to load the agent object # all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific self.ms.list_tools() for agent_state, return_dict in zip(agents_states, agents_states_dicts): # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) # TODO remove this eventually when return type get pydanticfied # this is to add persona_name and human_name so that the columns in UI can populate # TODO hack for frontend, remove # (top level .persona is persona_name, and nested memory.persona is the state) # TODO: eventually modify this to be contained in the metadata return_dict["persona"] = agent_state._metadata.get("persona", None) return_dict["human"] = agent_state._metadata.get("human", None) # Add information about tools # TODO memgpt_agent should really have a field of List[ToolModel] # then we could just pull that field and return it here # return_dict["tools"] = [tool for tool in all_available_tools if tool.json_schema in memgpt_agent.functions] # get tool info from agent state tools = [] for tool_name in agent_state.tools: tool = self.ms.get_tool(tool_name, user_id) tools.append(tool) return_dict["tools"] = tools # Add information about 
memory (raw core, size of recall, size of archival) core_memory = memgpt_agent.memory recall_memory = memgpt_agent.persistence_manager.recall_memory archival_memory = memgpt_agent.persistence_manager.archival_memory memory_obj = { "core_memory": {section: module.value for (section, module) in core_memory.memory.items()}, "recall_memory": len(recall_memory) if recall_memory is not None else None, "archival_memory": len(archival_memory) if archival_memory is not None else None, } return_dict["memory"] = memory_obj # Add information about last run # NOTE: 'last_run' is just the timestamp on the latest message in the buffer # Retrieve the Message object via the recall storage or by directly access _messages last_msg_obj = memgpt_agent._messages[-1] return_dict["last_run"] = last_msg_obj.created_at # Add information about attached sources sources_ids = self.ms.list_attached_sources(agent_id=agent_state.id) sources = [self.ms.get_source(source_id=s_id) for s_id in sources_ids] return_dict["sources"] = [vars(s) for s in sources] # Sort agents by "last_run" in descending order, most recent first agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True) logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}") return { "num_agents": len(agents_states), "agents": agents_states_dicts, } def list_personas(self, user_id: uuid.UUID): return self.ms.list_personas(user_id=user_id) def get_persona(self, name: str, user_id: uuid.UUID): return self.ms.get_persona(name=name, user_id=user_id) def add_persona(self, persona: PersonaModel): name = persona.name user_id = persona.user_id self.ms.add_persona(persona=persona) persona = self.ms.get_persona(name=name, user_id=user_id) return persona def update_persona(self, persona: PersonaModel): return self.ms.update_persona(persona=persona) def delete_persona(self, name: str, user_id: uuid.UUID): return self.ms.delete_persona(name=name, user_id=user_id) def list_humans(self, user_id: uuid.UUID): return 
self.ms.list_humans(user_id=user_id) def get_human(self, name: str, user_id: uuid.UUID): return self.ms.get_human(name=name, user_id=user_id) def add_human(self, human: HumanModel): name = human.name user_id = human.user_id self.ms.add_human(human=human) human = self.ms.get_human(name=name, user_id=user_id) return human def update_human(self, human: HumanModel): return self.ms.update_human(human=human) def delete_human(self, name: str, user_id: uuid.UUID): return self.ms.delete_human(name, user_id) def get_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID): """Get the agent state""" return self.ms.get_agent(agent_id=agent_id, user_id=user_id) def get_user(self, user_id: uuid.UUID) -> User: """Get the user""" return self.ms.get_user(user_id=user_id) def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the memory of an agent (core memory + non-core statistics)""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) core_memory = memgpt_agent.memory recall_memory = memgpt_agent.persistence_manager.recall_memory archival_memory = memgpt_agent.persistence_manager.archival_memory # NOTE memory_obj = { "core_memory": {key: value.value for key, value in core_memory.memory.items()}, "recall_memory": len(recall_memory) if recall_memory is not None else None, "archival_memory": len(archival_memory) if archival_memory is not None else None, } return memory_obj def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]: """Get the message ids of the in-context messages in the agent's memory""" # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) 
return [m.id for m in memgpt_agent._messages] def get_agent_message(self, agent_id: uuid.UUID, message_id: uuid.UUID) -> Message: """Get message based on agent and message ID""" agent_state = self.ms.get_agent(agent_id=agent_id) if agent_state is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") user_id = agent_state.user_id # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) message = memgpt_agent.persistence_manager.recall_memory.storage.get(message_id=message_id) return message def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of all messages in agent message queue""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) if start < 0 or count < 0: raise ValueError("Start and count values should be non-negative") if start + count < len(memgpt_agent._messages): # messages can be returned from whats in memory # Reverse the list to make it in reverse chronological order reversed_messages = memgpt_agent._messages[::-1] # Check if start is within the range of the list if start >= len(reversed_messages): raise IndexError("Start index is out of range") # Calculate the end index, ensuring it does not exceed the list length end_index = min(start + count, len(reversed_messages)) # Slice the list for pagination messages = reversed_messages[start:end_index] # Convert to json # Add a tag indicating in-context or not json_messages = [{**record.to_json(), "in_context": True} for record in messages] else: # need to access persistence manager for additional messages db_iterator = 
memgpt_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start) # get a single page of messages # TODO: handle stop iteration page = next(db_iterator, []) # return messages in reverse chronological order messages = sorted(page, key=lambda x: x.created_at, reverse=True) # Convert to json # Add a tag indicating in-context or not json_messages = [record.to_json() for record in messages] in_context_message_ids = [str(m.id) for m in memgpt_agent._messages] for d in json_messages: d["in_context"] = True if str(d["id"]) in in_context_message_ids else False return json_messages def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of all messages in agent archival memory""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) # iterate over records db_iterator = memgpt_agent.persistence_manager.archival_memory.storage.get_all_paginated(page_size=count, offset=start) # get a single page of messages page = next(db_iterator, []) json_passages = [vars(record) for record in page] return json_passages def get_agent_archival_cursor( self, user_id: uuid.UUID, agent_id: uuid.UUID, after: Optional[uuid.UUID] = None, before: Optional[uuid.UUID] = None, limit: Optional[int] = 100, order_by: Optional[str] = "created_at", reverse: Optional[bool] = False, ): if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, 
agent_id=agent_id) # iterate over recorde cursor, records = memgpt_agent.persistence_manager.archival_memory.storage.get_all_cursor( after=after, before=before, limit=limit, order_by=order_by, reverse=reverse ) json_records = [vars(record) for record in records] return cursor, json_records def get_all_archival_memories(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> list: # TODO deprecate (not safe to be returning an unbounded list) if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) # Assume passages records = memgpt_agent.persistence_manager.archival_memory.storage.get_all() return [dict(id=str(r.id), contents=r.text) for r in records] def insert_archival_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, memory_contents: str) -> uuid.UUID: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) # Insert into archival memory # memory_id = uuid.uuid4() passage_ids = memgpt_agent.persistence_manager.archival_memory.insert(memory_string=memory_contents, return_ids=True) return [str(p_id) for p_id in passage_ids] def delete_archival_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, memory_id: uuid.UUID): if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) 
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) # Delete by ID # TODO check if it exists first, and throw error if not memgpt_agent.persistence_manager.archival_memory.storage.delete({"id": memory_id}) def get_agent_recall_cursor( self, user_id: uuid.UUID, agent_id: uuid.UUID, after: Optional[uuid.UUID] = None, before: Optional[uuid.UUID] = None, limit: Optional[int] = 100, order_by: Optional[str] = "created_at", order: Optional[str] = "asc", reverse: Optional[bool] = False, ) -> Tuple[uuid.UUID, List[dict]]: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) # iterate over records cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor( after=after, before=before, limit=limit, order_by=order_by, reverse=reverse ) json_records = [record.to_json() for record in records] # TODO: mark what is in-context versus not return cursor, json_records def get_agent_config(self, user_id: uuid.UUID, agent_id: Optional[uuid.UUID], agent_name: Optional[str] = None) -> AgentState: """Return the config of an agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if agent_id: if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") else: agent_state = self.ms.get_agent(agent_name=agent_name, user_id=user_id) if agent_state is None: raise ValueError(f"Agent agent_name={agent_name} does not exist") agent_id = agent_state.id # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) return memgpt_agent.agent_state def get_server_config(self, 
include_defaults: bool = False) -> dict: """Return the base config""" def clean_keys(config): config_copy = config.copy() for k, v in config.items(): if k == "key" or "_key" in k: config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5) return config_copy # TODO: do we need a seperate server config? base_config = vars(self.config) clean_base_config = clean_keys(base_config) clean_base_config_default_llm_config_dict = vars(clean_base_config["default_llm_config"]) clean_base_config_default_embedding_config_dict = vars(clean_base_config["default_embedding_config"]) clean_base_config["default_llm_config"] = clean_base_config_default_llm_config_dict clean_base_config["default_embedding_config"] = clean_base_config_default_embedding_config_dict response = {"config": clean_base_config} if include_defaults: default_config = vars(MemGPTConfig()) clean_default_config = clean_keys(default_config) clean_default_config["default_llm_config"] = clean_base_config_default_llm_config_dict clean_default_config["default_embedding_config"] = clean_base_config_default_embedding_config_dict response["defaults"] = clean_default_config return response def get_available_models(self) -> list: """Poll the LLM endpoint for a list of available models""" credentials = MemGPTCredentials().load() try: model_options = get_model_options( credentials=credentials, model_endpoint_type=self.config.default_llm_config.model_endpoint_type, model_endpoint=self.config.default_llm_config.model_endpoint, ) return model_options except Exception as e: logger.exception(f"Failed to get list of available models from LLM endpoint:\n{str(e)}") raise def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict: """Update the agents core memory block, return the new state""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise 
ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) old_core_memory = self.get_agent_memory(user_id=user_id, agent_id=agent_id)["core_memory"] new_core_memory = old_core_memory.copy() modified = False for key, value in new_memory_contents.items(): if value is None: continue if key in old_core_memory and old_core_memory[key] != value: memgpt_agent.memory.memory[key].value = value # update agent memory modified = True # If we modified the memory contents, we need to rebuild the memory block inside the system message if modified: memgpt_agent.rebuild_memory() # save agent save_agent(memgpt_agent, self.ms) return { "old_core_memory": old_core_memory, "new_core_memory": new_core_memory, "modified": modified, } def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> AgentState: """Update the name of the agent in the database""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) current_name = memgpt_agent.agent_state.name if current_name == new_agent_name: raise ValueError(f"New name ({new_agent_name}) is the same as the current name") try: memgpt_agent.agent_state.name = new_agent_name self.ms.update_agent(agent=memgpt_agent.agent_state) except Exception as e: logger.exception(f"Failed to update agent name with:\n{str(e)}") raise ValueError(f"Failed to update agent name in database") assert isinstance(memgpt_agent.agent_state.id, uuid.UUID) return memgpt_agent.agent_state def delete_user(self, user_id: uuid.UUID): # TODO: delete user pass def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID): """Delete an agent in 
the database""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") # Verify that the agent exists and is owned by the user agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id) if not agent_state: raise ValueError(f"Could not find agent_id={agent_id} under user_id={user_id}") if agent_state.user_id != user_id: raise ValueError(f"Could not authorize agent_id={agent_id} with user_id={user_id}") # First, if the agent is in the in-memory cache we should remove it # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts try: self.active_agents = [d for d in self.active_agents if str(d["agent_id"]) != str(agent_id)] except Exception as e: logger.exception(f"Failed to delete agent {agent_id} from cache via ID with:\n{str(e)}") raise ValueError(f"Failed to delete agent {agent_id} from cache") # Next, attempt to delete it from the actual database try: self.ms.delete_agent(agent_id=agent_id) except Exception as e: logger.exception(f"Failed to delete agent {agent_id} via ID with:\n{str(e)}") raise ValueError(f"Failed to delete agent {agent_id} in database") def authenticate_user(self) -> uuid.UUID: # TODO: Implement actual authentication to enable multi user setup return uuid.UUID(MemGPTConfig.load().anon_clientid) def api_key_to_user(self, api_key: str) -> uuid.UUID: """Decode an API key to a user""" user = self.ms.get_user_from_api_key(api_key=api_key) if user is None: raise HTTPException(status_code=403, detail="Invalid credentials") else: return user.id def create_api_key_for_user(self, user_id: uuid.UUID) -> Token: """Create a new API key for a user""" token = self.ms.create_api_key(user_id=user_id) return token def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields """Create a new data source""" source = Source( 
name=name, user_id=user_id, embedding_model=self.config.default_embedding_config.embedding_model, embedding_dim=self.config.default_embedding_config.embedding_dim, ) self.ms.create_source(source) assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}" return source def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID): """Delete a data source""" source = self.ms.get_source(source_id=source_id, user_id=user_id) self.ms.delete_source(source_id) # delete data from passage store passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) passage_store.delete({"data_source": source.name}) # TODO: delete data from agent passage stores (?) def load_data( self, user_id: uuid.UUID, connector: DataConnector, source_name: str, ) -> Tuple[int, int]: """Load data from a DataConnector into a source for a specified user_id""" # TODO: this should be implemented as a batch job or at least async, since it may take a long time # load data from a data source into the document store source = self.ms.get_source(source_name=source_name, user_id=user_id) if source is None: raise ValueError(f"Data source {source_name} does not exist for user {user_id}") # get the data connectors passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) # TODO: add document store support document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) # load data into the document store passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store) return passage_count, document_count def attach_source_to_agent( self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None, ): # attach a data source to an agent data_source = self.ms.get_source(source_id=source_id, user_id=user_id, 
source_name=source_name) if data_source is None: raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}") # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) # load agent agent = self._get_or_load_agent(user_id, agent_id) # attach source to agent agent.attach_source(data_source.name, source_connector, self.ms) return data_source def detach_source_from_agent( self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None, ): # TODO: remove all passages coresponding to source from agent's archival memory raise NotImplementedError def list_attached_sources(self, agent_id: uuid.UUID): # list all attached sources to an agent return self.ms.list_attached_sources(agent_id) def list_data_source_passages(self, user_id: uuid.UUID, source_id: uuid.UUID) -> List[PassageModel]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] def list_data_source_documents(self, user_id: uuid.UUID, source_id: uuid.UUID) -> List[DocumentModel]: warnings.warn("list_data_source_documents is not yet implemented, returning empty list.", category=UserWarning) return [] def list_all_sources(self, user_id: uuid.UUID) -> List[SourceModel]: """List all sources (w/ extra metadata) belonging to a user""" sources = self.ms.list_sources(user_id=user_id) # TODO don't unpack here, instead list_sources should return a SourceModel sources = [ SourceModel( name=source.name, description=None, # TODO: actually store descriptions user_id=source.user_id, id=source.id, embedding_config=self.server_embedding_config, created_at=source.created_at, ) for source in sources ] # Add extra metadata to the sources sources_with_metadata = [] for source in sources: # count number of passages passage_conn = 
StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) num_passages = passage_conn.size({"data_source": source.name}) # TODO: add when documents table implemented ## count number of documents # document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) # num_documents = document_conn.size({"data_source": source.name}) num_documents = 0 agent_ids = self.ms.list_attached_agents(source_id=source.id) # add the agent name information attached_agents = [ { "id": str(a_id), "name": self.ms.get_agent(user_id=user_id, agent_id=a_id).name, } for a_id in agent_ids ] # Overwrite metadata field, should be empty anyways source.metadata_ = dict( num_documents=num_documents, num_passages=num_passages, attached_agents=attached_agents, ) sources_with_metadata.append(source) return sources_with_metadata def create_tool( self, json_schema: dict, source_code: str, source_type: str, tags: Optional[List[str]] = None, exists_ok: Optional[bool] = True, user_id: Optional[uuid.UUID] = None, ) -> ToolModel: # TODO: add other fields """Create a new tool Args: TODO Returns: tool (ToolModel): Tool object """ if tags and "memory" in tags: # special modifications to memory functions # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory) source_code = source_code.replace("self.memory", "self.memory.memory") # check if already exists: tool_name = json_schema["name"] existing_tool = self.ms.get_tool(tool_name, user_id) if existing_tool: if exists_ok: # update existing tool existing_tool.source_code = source_code existing_tool.source_type = source_type existing_tool.tags = tags existing_tool.json_schema = json_schema self.ms.update_tool(existing_tool) return self.ms.get_tool(tool_name, user_id) else: raise ValueError(f"Tool {tool_name} already exists and update=False") tool = ToolModel( name=tool_name, source_code=source_code, source_type=source_type, tags=tags, 
json_schema=json_schema, user_id=user_id ) self.ms.add_tool(tool) return self.ms.get_tool(tool_name, user_id) def delete_tool(self, name: str): """Delete a tool""" self.ms.delete_tool(name)