* autogenerate openapi file on server startup * added endpoint for paginated retrieval of in-context agent messages * missing diff * added ability to pass system messages via message endpoint * patched bad depends into queries to fix the param info not showing up in get requests, fixed some bad copy paste
616 lines
26 KiB
Python
616 lines
26 KiB
Python
from abc import abstractmethod
|
|
from typing import Union, Callable
|
|
import json
|
|
import logging
|
|
from threading import Lock
|
|
from functools import wraps
|
|
from fastapi import HTTPException
|
|
|
|
from memgpt.system import package_user_message
|
|
from memgpt.config import MemGPTConfig
|
|
from memgpt.agent import Agent
|
|
import memgpt.system as system
|
|
import memgpt.constants as constants
|
|
from memgpt.cli.cli import attach
|
|
|
|
# from memgpt.agent_store.storage import StorageConnector
|
|
from memgpt.metadata import MetadataStore
|
|
import memgpt.presets.presets as presets
|
|
import memgpt.utils as utils
|
|
import memgpt.server.utils as server_utils
|
|
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
|
|
from memgpt.data_types import Source, Passage, Document, User, AgentState
|
|
|
|
# TODO use custom interface
|
|
from memgpt.interface import CLIInterface # for printing to terminal
|
|
from memgpt.interface import AgentInterface # abstract
|
|
|
|
# Module-level logger, named after this module; the server classes in this file log through it
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Server:
    """Abstract interface for a multi-user, multi-agent server.

    Each operation below is only a stub that raises ``NotImplementedError``;
    concrete servers (e.g. ``SyncServer``) supply the real behavior. The class
    does not use ``ABCMeta``, so the ``@abstractmethod`` markers document intent
    but do not prevent instantiation.
    """

    @abstractmethod
    def list_agents(self, user_id: str) -> dict:
        """Describe the agents that ``user_id`` can access."""
        raise NotImplementedError

    @abstractmethod
    def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
        """Return one page (``start`` / ``count``) of the agent's in-context messages."""
        raise NotImplementedError

    @abstractmethod
    def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
        """Return the agent's core memory together with non-core memory statistics."""
        raise NotImplementedError

    @abstractmethod
    def get_agent_config(self, user_id: str, agent_id: str) -> dict:
        """Return the configuration of the given agent."""
        raise NotImplementedError

    @abstractmethod
    def get_server_config(self, user_id: str) -> dict:
        """Return the server's base configuration."""
        raise NotImplementedError

    @abstractmethod
    def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
        """Apply ``new_memory_contents`` to the agent's core memory block and return the resulting state."""
        raise NotImplementedError

    @abstractmethod
    def create_agent(
        self,
        user_id: str,
        agent_config: Union[dict, AgentState],
        interface: Union[AgentInterface, None],
        persistence_manager: Union[PersistenceManager, None],
    ) -> str:
        """Build a brand-new agent from ``agent_config``."""
        raise NotImplementedError

    @abstractmethod
    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Hand a user-authored message to the agent; implementations drive it via step."""
        raise NotImplementedError

    @abstractmethod
    def system_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Hand a system-authored message to the agent; implementations drive it via step."""
        raise NotImplementedError

    @abstractmethod
    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Execute an agent command such as ``/memory``.

        Returns:
            An optional message string produced by the command, else None.
        """
        raise NotImplementedError
|
|
|
|
|
|
class LockingServer(Server):
    """Basic support for concurrency protections.

    All requests that modify an agent lock the agent until the operation is
    complete. Locking is non-blocking: if the agent is already busy, the request
    is rejected with HTTP 423 instead of waiting.
    """

    # Locks for each agent, keyed by agent_id.
    # NOTE: this is a class attribute, so the lock table is shared by every
    # LockingServer (and subclass) instance in the process.
    _agent_locks = {}

    @staticmethod
    def agent_lock_decorator(func: Callable) -> Callable:
        """Wrap a ``(self, user_id, agent_id, ...)`` method so it runs holding the agent's lock.

        Args:
            func: the method to protect; its second and third parameters must be
                ``user_id`` and ``agent_id``.

        Returns:
            The wrapped method.

        Raises:
            HTTPException: status 423 if another call currently holds the lock for
                the same ``agent_id``.
        """

        @wraps(func)
        def wrapper(self, user_id: str, agent_id: str, *args, **kwargs):
            # Look up (or install) the agent's lock in a single dict call. A separate
            # "if missing: insert" check lets two concurrent first requests each install
            # and acquire their own Lock, defeating the protection. dict.setdefault is
            # a single C-level operation in CPython.
            lock = self._agent_locks.setdefault(agent_id, Lock())

            # Non-blocking acquire: reject concurrent requests rather than queueing them
            if not lock.acquire(blocking=False):
                raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.")

            try:
                return func(self, user_id, agent_id, *args, **kwargs)
            finally:
                # Release the exact lock object acquired above, not whatever the dict
                # holds for agent_id now
                lock.release()

        return wrapper

    # NOTE(review): using the staticmethod object directly as a decorator inside the
    # class body relies on staticmethod being callable, which is Python 3.10+.
    @agent_lock_decorator
    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        raise NotImplementedError

    @agent_lock_decorator
    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        raise NotImplementedError
|
|
|
|
|
|
# TODO actually use "user_id" for something
|
|
class SyncServer(LockingServer):
    """Simple single-threaded / blocking server process.

    Agents are loaded on demand from the metadata store into ``self.active_agents``
    and are stepped synchronously in the calling thread.
    """

    def __init__(
        self,
        chaining: bool = True,
        max_chaining_steps: Union[int, None] = None,
        # default_interface_cls: AgentInterface = CLIInterface,
        default_interface: Union[AgentInterface, None] = None,
        # default_persistence_manager_cls: PersistenceManager = LocalStateManager,
    ):
        """Server process holds in-memory agents that are being run.

        Args:
            chaining: keep stepping an agent after a step that requested a heartbeat,
                had a failed function call, or produced a token warning.
            max_chaining_steps: when chaining, stop once this many extra steps have run
                after the first one; None means no limit.
            default_interface: interface given to agents that are loaded or created
                without one. When omitted, a new CLIInterface is created for this server.
        """
        # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
        self.active_agents = []

        # chaining = whether or not to run again if request_heartbeat=true
        self.chaining = chaining

        # if chaining == true, what's the max number of times we'll chain before yielding?
        # none = no limit, can go on forever
        self.max_chaining_steps = max_chaining_steps

        # The default interface that will get assigned to agents ON LOAD.
        # Built here rather than as a `CLIInterface()` default argument, which would be
        # evaluated once at class-definition time and shared by every server instance.
        # self.default_interface_cls = default_interface_cls
        self.default_interface = default_interface if default_interface is not None else CLIInterface()

        # The default persistence manager that will get assigned to agents ON CREATION
        # self.default_persistence_manager_cls = default_persistence_manager_cls

        # Initialize the connection to the DB
        self.config = MemGPTConfig()
        self.ms = MetadataStore(self.config)

    def save_agents(self):
        """Save every agent in the in-memory object store.

        Failures are logged per agent so that one failing agent does not stop the rest.
        """
        for agent_d in self.active_agents:
            try:
                agent_d["agent"].save()
                logger.info(f"Saved agent {agent_d['agent_id']}")
            except Exception as e:
                logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}")

    def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
        """Get the agent object from the in-memory object store (None if it is not loaded)."""
        for d in self.active_agents:
            if d["user_id"] == user_id and d["agent_id"] == agent_id:
                return d["agent"]
        return None

    def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None:
        """Put an agent object inside the in-memory object store.

        Raises:
            KeyError: if an agent with the same (user_id, agent_id) is already loaded.
        """
        # Make sure the agent doesn't already exist
        if self._get_agent(user_id=user_id, agent_id=agent_id) is not None:
            raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded")
        # Add Agent instance to the in-memory list
        self.active_agents.append(
            {
                "user_id": user_id,
                "agent_id": agent_id,
                "agent": agent_obj,
            }
        )

    def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Union[Agent, None]:
        """Load a saved agent from the metadata store into the in-memory store.

        Args:
            interface: interface to attach to the agent; defaults to ``self.default_interface``.

        Returns:
            The loaded Agent, or None if loading failed. Errors (including an unknown
            agent_id) are logged rather than raised; ``_get_or_load_agent`` turns the
            None into a KeyError.
        """
        # If an interface isn't specified, use the default
        if interface is None:
            interface = self.default_interface

        try:
            agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
            if not agent_state:
                raise ValueError(f"agent_id {agent_id} does not exist")

            # Instantiate an agent object using the state retrieved
            memgpt_agent = Agent(agent_state=agent_state, interface=interface)

            # Add the agent to the in-memory store and return its reference
            self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent)
            return memgpt_agent

        except Exception as e:
            logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}")
            return None

    def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent:
        """Return the in-memory agent, loading it from the metadata store if needed.

        Raises:
            KeyError: if the agent is not in memory and could not be loaded.
        """
        memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
        if memgpt_agent is None:
            memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
        if memgpt_agent is None:
            # _load_agent logs and swallows its errors. Fail here with a clear error
            # instead of handing None to callers that immediately dereference it.
            raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded")
        return memgpt_agent

    def _step(self, user_id: str, agent_id: str, input_message: str) -> None:
        """Send the input message through the agent, chaining extra steps as configured.

        After the first step, another step is run (when chaining is enabled and the
        step limit is not reached) if the agent raised a token warning, a function call
        failed, or the agent requested a heartbeat.

        Raises:
            KeyError: if the agent cannot be found or loaded.
        """
        logger.debug(f"Got input message: {input_message}")

        # Get the agent object (loaded in memory); raises KeyError if it can't be loaded
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

        logger.debug("Starting agent step")
        no_verify = True
        next_input_message = input_message
        counter = 0
        while True:
            _new_messages, heartbeat_request, function_failed, token_warning = memgpt_agent.step(
                next_input_message, first_message=False, skip_verify=no_verify
            )
            counter += 1

            # Chain stops
            if not self.chaining:
                logger.debug("No chaining, stopping after one step")
                break
            elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
                logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
                break
            # Chain handlers
            elif token_warning:
                next_input_message = system.get_token_limit_warning()
                continue  # always chain
            elif function_failed:
                next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
                continue  # always chain
            elif heartbeat_request:
                next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
                continue  # always chain
            # MemGPT no-op / yield
            else:
                break

        memgpt_agent.interface.step_yield()
        logger.debug("Finished agent step")

    def _command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Process a CLI command (given without its leading "/").

        Returns:
            A message string for commands that produce one (currently only "memory"),
            otherwise None. Unrecognized commands are ignored.

        Raises:
            ValueError: for commands not supported on the server ("exit", "wipe") and
                for malformed or invalid "attach" requests.
            KeyError: if the agent cannot be found or loaded.
        """
        logger.debug(f"Got command: {command}")

        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

        # Lower-cased copy for matching; `command` keeps its original case for argument text
        cmd = command.lower()

        if cmd == "exit":
            # exit not supported on server.py
            raise ValueError(command)

        elif cmd == "save" or cmd == "savechat":
            memgpt_agent.save()

        elif cmd == "attach" or cmd.startswith("attach "):
            # Different from CLI, we extract the data source name from the command.
            # (Matching only the bare word "attach" meant command[1] never existed.)
            command = command.strip().split()
            try:
                # Data sources are identified by name, so keep the argument as a string
                data_source = command[1]
            except IndexError:
                raise ValueError(command)

            # Imported here because the module-level import is commented out in this file
            # NOTE(review): presumably to avoid an import cycle at module load - confirm
            from memgpt.agent_store.storage import StorageConnector

            # TODO: check if agent already has it
            data_source_options = StorageConnector.list_loaded_data()
            if len(data_source_options) == 0:
                raise ValueError('No sources available. You must load a source with "memgpt load ..." before running /attach.')
            elif data_source not in data_source_options:
                raise ValueError(f"Invalid data source name: {data_source} (options={data_source_options})")
            else:
                # attach new data
                attach(memgpt_agent.config.name, data_source)

                # update agent config
                memgpt_agent.config.attach_data_source(data_source)

                # reload agent with new data source
                # TODO: maybe make this less ugly...
                memgpt_agent.persistence_manager.archival_memory.storage = StorageConnector.get_storage_connector(
                    agent_config=memgpt_agent.config
                )

        elif cmd == "dump" or cmd.startswith("dump "):
            # Check if there's an additional argument that's an integer
            command = command.strip().split()
            amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
            if amount == 0:
                memgpt_agent.interface.print_messages(memgpt_agent.messages, dump=True)
            else:
                memgpt_agent.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)

        elif cmd == "dumpraw":
            memgpt_agent.interface.print_messages_raw(memgpt_agent.messages)

        elif cmd == "memory":
            ret_str = (
                f"\nDumping memory contents:\n"
                + f"\n{str(memgpt_agent.memory)}"
                + f"\n{str(memgpt_agent.persistence_manager.archival_memory)}"
                + f"\n{str(memgpt_agent.persistence_manager.recall_memory)}"
            )
            return ret_str

        elif cmd == "pop" or cmd.startswith("pop "):
            # Check if there's an additional argument that's an integer
            command = command.strip().split()
            pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
            n_messages = len(memgpt_agent.messages)
            MIN_MESSAGES = 2
            if n_messages <= MIN_MESSAGES:
                logger.info(f"Agent only has {n_messages} messages in stack, none left to pop")
            elif n_messages - pop_amount < MIN_MESSAGES:
                logger.info(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
            else:
                logger.info(f"Popping last {pop_amount} messages from stack")
                for _ in range(min(pop_amount, len(memgpt_agent.messages))):
                    memgpt_agent.messages.pop()

        elif cmd == "retry":
            # TODO this needs to also modify the persistence manager
            logger.info(f"Retrying for another answer")
            messages = memgpt_agent.messages
            # Find the most recent user message (searching from the end)
            last_user_idx = next(
                (i for i in range(len(messages) - 1, -1, -1) if messages[i].get("role") == "user"),
                None,
            )
            if last_user_idx is None:
                # Nothing to retry; do not strip the history (which would also drop the system message)
                logger.warning("No user message found to retry")
            else:
                # we want to pop up to the last user message and send it again
                user_message = messages[last_user_idx].get("content")
                while len(messages) > last_user_idx:
                    messages.pop()
                # Re-send it so the agent actually produces a new answer
                self._step(user_id=user_id, agent_id=agent_id, input_message=user_message)

        elif cmd == "rethink" or cmd.startswith("rethink "):
            # TODO this needs to also modify the persistence manager
            if len(command) < len("rethink "):
                logger.warning("Missing text after the command")
            else:
                # Index 0 is skipped by the range, so the first message is never edited
                for x in range(len(memgpt_agent.messages) - 1, 0, -1):
                    if memgpt_agent.messages[x].get("role") == "assistant":
                        text = command[len("rethink ") :].strip()
                        memgpt_agent.messages[x].update({"content": text})
                        break

        elif cmd == "rewrite" or cmd.startswith("rewrite "):
            # TODO this needs to also modify the persistence manager
            if len(command) < len("rewrite "):
                logger.warning("Missing text after the command")
            else:
                for x in range(len(memgpt_agent.messages) - 1, 0, -1):
                    if memgpt_agent.messages[x].get("role") == "assistant":
                        text = command[len("rewrite ") :].strip()
                        # Replace the "message" argument of the assistant's function call
                        args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"))
                        args["message"] = text
                        memgpt_agent.messages[x].get("function_call").update({"arguments": json.dumps(args)})
                        break

        # No skip options
        elif cmd == "wipe":
            # wipe not supported on server.py
            raise ValueError(command)

        elif cmd == "heartbeat":
            input_message = system.get_heartbeat()
            self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)

        elif cmd == "memorywarning":
            input_message = system.get_token_limit_warning()
            self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)

    @LockingServer.agent_lock_decorator
    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Process an incoming user message and feed it through the MemGPT agent.

        Raises:
            ValueError: if ``message`` is not a non-empty string or starts with "/".
        """
        # Basic input sanitization
        if not isinstance(message, str) or len(message) == 0:
            raise ValueError(f"Invalid input: '{message}'")

        # If the input begins with a command prefix, reject
        elif message.startswith("/"):
            raise ValueError(f"Invalid input: '{message}'")

        # Else, process it as a user message to be fed to the agent
        else:
            # Package the user message first
            packaged_user_message = package_user_message(user_message=message)
            # Run the agent state forward
            self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)

    @LockingServer.agent_lock_decorator
    def system_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Process an incoming system message and feed it through the MemGPT agent.

        Raises:
            ValueError: if ``message`` is not a non-empty string or starts with "/".
        """
        # Basic input sanitization
        if not isinstance(message, str) or len(message) == 0:
            raise ValueError(f"Invalid input: '{message}'")

        # If the input begins with a command prefix, reject
        elif message.startswith("/"):
            raise ValueError(f"Invalid input: '{message}'")

        # Else, process it as a system message to be fed to the agent
        else:
            # Package the system message first. It is reached through the imported `system`
            # module; the bare name `package_system_message` is never imported in this file.
            packaged_system_message = system.package_system_message(system_message=message)
            # Run the agent state forward
            self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message)

    @LockingServer.agent_lock_decorator
    def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
        """Run a "/"-prefixed command on the agent.

        Returns:
            The command's message string, or None (also None if ``command`` lacks the "/" prefix).
        """
        # If the input begins with a command prefix, attempt to process it as a command
        if command.startswith("/"):
            if len(command) > 1:
                command = command[1:]  # strip the prefix
            return self._command(user_id=user_id, agent_id=agent_id, command=command)

    def create_agent(
        self,
        user_id: str,
        agent_config: dict,
        interface: Union[AgentInterface, None] = None,
        # persistence_manager: Union[PersistenceManager, None] = None,
    ) -> AgentState:
        """Create a new agent using a config.

        Fields missing from ``agent_config`` fall back to the user's defaults; a random
        name is generated when "name" is absent.

        Returns:
            The new agent's ``config``.

        Raises:
            ValueError: if ``agent_config`` is not a dict.
        """
        # Initialize the agent based on the provided configuration
        if not isinstance(agent_config, dict):
            raise ValueError(f"agent_config must be provided as a dictionary")

        if interface is None:
            # interface = self.default_interface_cls()
            interface = self.default_interface

        # if persistence_manager is None:
        #     persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)

        # TODO actually use the user_id that was passed into the server
        USER_ID = self.config.anon_clientid
        # Fetch the user, creating it on first use
        user = self.ms.get_user(user_id=USER_ID)
        if not user:
            user = User(id=USER_ID)
            self.ms.create_user(user)

        agent_state = AgentState(
            user_id=user.id,
            name=agent_config["name"] if "name" in agent_config else utils.create_random_username(),
            preset=agent_config["preset"] if "preset" in agent_config else user.default_preset,
            # TODO we need to allow passing raw persona/human text via the server request
            persona=agent_config["persona"] if "persona" in agent_config else user.default_persona,
            human=agent_config["human"] if "human" in agent_config else user.default_human,
            llm_config=agent_config["llm_config"] if "llm_config" in agent_config else user.default_llm_config,
            embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else user.default_embedding_config,
        )
        agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
        # TODO where should we handle saving of the AgentState?
        agent.save()
        logger.info(f"Created new agent from config: {agent}")

        return agent.config

    def delete_agent(
        self,
        user_id: str,
        agent_id: str,
    ):
        """Delete an agent from the metadata store and evict any loaded copy from memory.

        Does nothing if the (currently anonymous) user has no agent with this id.
        """
        # Make sure the user owns the agent
        # TODO use real user_id
        USER_ID = self.config.anon_clientid
        agent = self.ms.get_agent(agent_id=agent_id, user_id=USER_ID)
        if agent is not None:
            self.ms.delete_agent(agent_id=agent_id)
            # Drop the in-memory copy too, otherwise later requests would keep using
            # the deleted agent via _get_agent
            self.active_agents[:] = [d for d in self.active_agents if d["agent_id"] != agent_id]

    def list_agents(self, user_id: str) -> dict:
        """List all available agents to a user (count and names)."""
        # TODO actually use the user_id that was passed into the server
        USER_ID = self.config.anon_clientid
        agents_list = self.ms.list_agents(user_id=USER_ID)
        return {"num_agents": len(agents_list), "agent_names": [state.name for state in agents_list]}

    def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
        """Return the memory of an agent (core memory + non-core statistics).

        "recall_memory" and "archival_memory" hold the ``len()`` of each store, or None
        if the store is absent.
        """
        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

        core_memory = memgpt_agent.memory
        recall_memory = memgpt_agent.persistence_manager.recall_memory
        archival_memory = memgpt_agent.persistence_manager.archival_memory

        memory_obj = {
            "core_memory": {
                "persona": core_memory.persona,
                "human": core_memory.human,
            },
            "recall_memory": len(recall_memory) if recall_memory is not None else None,
            "archival_memory": len(archival_memory) if archival_memory is not None else None,
        }

        return memory_obj

    def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
        """Paginated query of in-context messages in agent message queue.

        Messages are returned newest-first: ``start`` is an offset from the most recent
        message and at most ``count`` messages are returned.

        Raises:
            ValueError: if ``start`` or ``count`` is negative.
            IndexError: if ``start`` is past the last message.
        """
        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

        if start < 0 or count < 0:
            raise ValueError("Start and count values should be non-negative")

        # Reverse the list to make it in reverse chronological order
        reversed_messages = memgpt_agent.messages[::-1]

        # Check if start is within the range of the list
        if start >= len(reversed_messages):
            raise IndexError("Start index is out of range")

        # Slicing clamps the end index to the list length
        return reversed_messages[start : start + count]

    def get_agent_config(self, user_id: str, agent_id: str) -> dict:
        """Return the config of an agent (the attribute dict of ``agent.config``)."""
        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        agent_config = vars(memgpt_agent.config)

        return agent_config

    def get_server_config(self, user_id: Union[str, None] = None) -> dict:
        """Return a copy of the base config with key-like fields shortened.

        Fields named "key", or containing "_key", are passed through
        ``server_utils.shorten_key_middle``.

        Args:
            user_id: accepted for signature compatibility with ``Server.get_server_config``;
                currently unused.
        """
        # TODO: do we need a seperate server config?
        base_config = vars(self.config)

        def clean_keys(config):
            # Shallow copy so the live config object is not modified
            config_copy = config.copy()
            for k, v in config.items():
                if k == "key" or "_key" in k:
                    config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
            return config_copy

        clean_base_config = clean_keys(base_config)
        return clean_base_config

    @LockingServer.agent_lock_decorator
    def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
        """Update the agent's core memory block and return the old/new state.

        Runs under the agent lock, like the other agent-mutating entry points, so it
        cannot interleave with an in-progress step or command on the same agent.

        Args:
            new_memory_contents: may contain "persona" and/or "human"; None values are ignored.

        Returns:
            {"old_core_memory": ..., "new_core_memory": ..., "modified": bool}
        """
        # Get the agent object (loaded in memory)
        memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

        old_core_memory = self.get_agent_memory(user_id=user_id, agent_id=agent_id)["core_memory"]
        new_core_memory = old_core_memory.copy()

        modified = False
        if "persona" in new_memory_contents and new_memory_contents["persona"] is not None:
            new_persona = new_memory_contents["persona"]
            if old_core_memory["persona"] != new_persona:
                new_core_memory["persona"] = new_persona
                memgpt_agent.memory.edit_persona(new_persona)
                modified = True

        if "human" in new_memory_contents and new_memory_contents["human"] is not None:
            new_human = new_memory_contents["human"]
            if old_core_memory["human"] != new_human:
                new_core_memory["human"] = new_human
                memgpt_agent.memory.edit_human(new_human)
                modified = True

        # If we modified the memory contents, we need to rebuild the memory block inside the system message
        if modified:
            memgpt_agent.rebuild_memory()

        return {
            "old_core_memory": old_core_memory,
            "new_core_memory": new_core_memory,
            "modified": modified,
        }
|