diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index fe091fde..ec8855cd 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -56,7 +56,7 @@ jobs: PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | pipx install poetry==1.8.2 - poetry install -E dev + poetry install -E dev -E postgres poetry run pytest -s tests/test_client.py poetry run pytest -s tests/test_concurrent_connections.py diff --git a/memgpt/agent.py b/memgpt/agent.py index f31f3891..52f70d21 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -3,7 +3,6 @@ import inspect import json import traceback import uuid -from pathlib import Path from typing import List, Optional, Tuple, Union, cast from tqdm import tqdm @@ -11,8 +10,6 @@ from tqdm import tqdm from memgpt.agent_store.storage import StorageConnector from memgpt.constants import ( CLI_WARNING_PREFIX, - CORE_MEMORY_HUMAN_CHAR_LIMIT, - CORE_MEMORY_PERSONA_CHAR_LIMIT, FIRST_MESSAGE_ATTEMPTS, JSON_ENSURE_ASCII, JSON_LOADS_STRICT, @@ -24,9 +21,7 @@ from memgpt.constants import ( from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage from memgpt.interface import AgentInterface from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error -from memgpt.memory import ArchivalMemory -from memgpt.memory import CoreMemory as InContextMemory -from memgpt.memory import RecallMemory, summarize_messages +from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages from memgpt.metadata import MetadataStore from memgpt.models import chat_completion_response from memgpt.models.pydantic_models import ToolModel @@ -41,7 +36,6 @@ from memgpt.utils import ( count_tokens, create_uuid_from_string, get_local_time, - get_schema_diff, get_tool_call_id, get_utc_time, is_utc_datetime, @@ -53,81 +47,17 @@ from memgpt.utils import ( ) from .errors import LLMError -from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets - - -def link_functions(function_schemas: list): - """Link function definitions to list of function schemas""" - - # need to dynamically link the functions - # the saved agent.functions will just have the schemas, but we need to - # go through the functions library and pull the respective python functions - - # Available functions is a mapping from: - # function_name -> { - # json_schema: schema - # python_function: function - # } - # agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create) - # [{'name': ..., 'description': ...}, {...}] - available_functions = load_all_function_sets() - linked_function_set = {} - for f_schema in function_schemas: - # Attempt to find the function in the existing function library - f_name = f_schema.get("name") - if f_name is None: - raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}") - linked_function = available_functions.get(f_name) - if linked_function is None: - # raise ValueError( - # f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}" - # ) - print( - f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}" - ) - continue - - # Once we find a matching function, make sure the schema is identical - if json.dumps(f_schema, ensure_ascii=JSON_ENSURE_ASCII) != json.dumps( - linked_function["json_schema"], ensure_ascii=JSON_ENSURE_ASCII - ): - # error_message = ( - # f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different." - # + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}" - # + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2, ensure_ascii=JSON_ENSURE_ASCII)}" - # ) - schema_diff = get_schema_diff(f_schema, linked_function["json_schema"]) - error_message = ( - f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n" - + "".join(schema_diff) - ) - - # NOTE to handle old configs, instead of erroring here let's just warn - # raise ValueError(error_message) - printd(error_message) - linked_function_set[f_name] = linked_function - return linked_function_set - - -def initialize_memory(ai_notes: Union[str, None], human_notes: Union[str, None]): - if ai_notes is None: - raise ValueError(ai_notes) - if human_notes is None: - raise ValueError(human_notes) - memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT) - memory.edit_persona(ai_notes) - memory.edit_human(human_notes) - return memory def construct_system_with_memory( system: str, - memory: InContextMemory, + memory: BaseMemory, memory_edit_timestamp: str, archival_memory: Optional[ArchivalMemory] = None, recall_memory: Optional[RecallMemory] = None, include_char_count: bool = True, ): + # TODO: modify this to be generalized full_system_message = "\n".join( [ system, @@ -136,12 +66,13 @@ def construct_system_with_memory( f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)", "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", - f'' if include_char_count else "", - memory.persona, - "", - f'' if include_char_count else "", - memory.human, - "", + str(memory), + # f'' if include_char_count else "", + # memory.persona, + # "", + # f'' if include_char_count else "", + # memory.human, + # "", ] ) return full_system_message @@ -150,7 +81,7 @@ def construct_system_with_memory( def initialize_message_sequence( model: str, system: str, - memory: InContextMemory, + memory: BaseMemory, archival_memory: Optional[ArchivalMemory] = None, recall_memory: Optional[RecallMemory] = None, memory_edit_timestamp: Optional[str] = None, @@ -195,13 +126,14 @@ class Agent(object): # agents can be created from providing agent_state agent_state: AgentState, tools: List[ToolModel], + # memory: BaseMemory, # extras messages_total: Optional[int] = None, # TODO remove? first_message_verify_mono: bool = True, # TODO move to config? ): - # tools for tool in tools: + assert tool, f"Tool is None - must be error in querying tool from DB" assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools" for tool_name in agent_state.tools: assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list" @@ -230,13 +162,8 @@ class Agent(object): self.system = self.agent_state.system # Initialize the memory object - # TODO: support more general memory types - if "persona" not in self.agent_state.state: # TODO: remove - raise ValueError(f"'persona' not found in provided AgentState") - if "human" not in self.agent_state.state: # TODO: remove - raise ValueError(f"'human' not found in provided AgentState") - self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"]) - printd("INITIALIZED MEMORY", self.memory.persona, self.memory.human) + self.memory = BaseMemory.load(self.agent_state.state["memory"]) + printd("Initialized memory object", self.memory) # Interface must implement: # - internal_monologue @@ -285,7 +212,7 @@ class Agent(object): m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) else: - # print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}") + printd(f"Agent.__init__ :: creating, state={agent_state.state['messages']}") init_messages = initialize_message_sequence( self.model, self.system, @@ -311,12 +238,10 @@ class Agent(object): # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) - # self.messages_total_init = self.messages_total self.messages_total_init = len(self._messages) - 1 printd(f"Agent initialized, self.messages_total={self.messages_total}") # Create the agent in the DB - # self.save() self.update_state() @property @@ -609,6 +534,10 @@ class Agent(object): heartbeat_request = False function_failed = False + # rebuild memory + # TODO: @charles please check this + self.rebuild_memory() + return messages, heartbeat_request, function_failed def step( @@ -770,6 +699,10 @@ class Agent(object): self._append_to_messages(all_new_messages) messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages + + # update state after each step + self.update_state() + return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage except Exception as e: @@ -913,6 +846,14 @@ class Agent(object): def rebuild_memory(self): """Rebuilds the system message with the latest memory object""" curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt + + # NOTE: This is a hacky way to check if the memory has changed + memory_repr = str(self.memory) + if memory_repr == curr_system_message["content"][-(len(memory_repr)) :]: + printd(f"Memory has not changed, not rebuilding system") + return + + # update memory (TODO: potentially update recall/archival stats seperately) new_system_message = initialize_message_sequence( self.model, self.system, @@ -922,136 +863,85 @@ class Agent(object): )[0] diff = united_diff(curr_system_message["content"], new_system_message["content"]) - printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") + if len(diff) > 0: # there was a diff + printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") - # Swap the system message out - self._swap_system_message( - Message.dict_to_message( - agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message + # Swap the system message out (only if there is a diff) + self._swap_system_message( + Message.dict_to_message( + agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message + ) + ) + assert self.messages[0]["content"] == new_system_message["content"], ( + self.messages[0]["content"], + new_system_message["content"], ) - ) - - # def to_agent_state(self) -> AgentState: - # # The state may have change since the last time we wrote it - # updated_state = { - # "persona": self.memory.persona, - # "human": self.memory.human, - # "system": self.system, - # "functions": self.functions, - # "messages": [str(msg.id) for msg in self._messages], - # } - - # agent_state = AgentState( - # name=self.agent_state.name, - # user_id=self.agent_state.user_id, - # persona=self.agent_state.persona, - # human=self.agent_state.human, - # llm_config=self.agent_state.llm_config, - # embedding_config=self.agent_state.embedding_config, - # preset=self.agent_state.preset, - # id=self.agent_state.id, - # created_at=self.agent_state.created_at, - # state=updated_state, - # ) - - # return agent_state def add_function(self, function_name: str) -> str: - if function_name in self.functions_python.keys(): - msg = f"Function {function_name} already loaded" - printd(msg) - return msg + # TODO: refactor + raise NotImplementedError + # if function_name in self.functions_python.keys(): + # msg = f"Function {function_name} already loaded" + # printd(msg) + # return msg - available_functions = load_all_function_sets() - if function_name not in available_functions.keys(): - raise ValueError(f"Function {function_name} not found in function library") + # available_functions = load_all_function_sets() + # if function_name not in available_functions.keys(): + # raise ValueError(f"Function {function_name} not found in function library") - self.functions.append(available_functions[function_name]["json_schema"]) - self.functions_python[function_name] = available_functions[function_name]["python_function"] + # self.functions.append(available_functions[function_name]["json_schema"]) + # self.functions_python[function_name] = available_functions[function_name]["python_function"] - msg = f"Added function {function_name}" - # self.save() - self.update_state() - printd(msg) - return msg + # msg = f"Added function {function_name}" + ## self.save() + # self.update_state() + # printd(msg) + # return msg def remove_function(self, function_name: str) -> str: - if function_name not in self.functions_python.keys(): - msg = f"Function {function_name} not loaded, ignoring" - printd(msg) - return msg + # TODO: refactor + raise NotImplementedError + # if function_name not in self.functions_python.keys(): + # msg = f"Function {function_name} not loaded, ignoring" + # printd(msg) + # return msg - # only allow removal of user defined functions - user_func_path = Path(USER_FUNCTIONS_DIR) - func_path = Path(inspect.getfile(self.functions_python[function_name])) - is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts + ## only allow removal of user defined functions + # user_func_path = Path(USER_FUNCTIONS_DIR) + # func_path = Path(inspect.getfile(self.functions_python[function_name])) + # is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts - if not is_subpath: - raise ValueError(f"Function {function_name} is not user defined and cannot be removed") + # if not is_subpath: + # raise ValueError(f"Function {function_name} is not user defined and cannot be removed") - self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name] - self.functions_python.pop(function_name) + # self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name] + # self.functions_python.pop(function_name) - msg = f"Removed function {function_name}" - # self.save() - self.update_state() - printd(msg) - return msg - - # def save(self): - # """Save agent state locally""" - - # new_agent_state = self.to_agent_state() - - # # without this, even after Agent.__init__, agent.config.state["messages"] will be None - # self.agent_state = new_agent_state - - # # Check if we need to create the agent - # if not self.ms.get_agent(agent_id=new_agent_state.id, user_id=new_agent_state.user_id, agent_name=new_agent_state.name): - # # print(f"Agent.save {new_agent_state.id} :: agent does not exist, creating...") - # self.ms.create_agent(agent=new_agent_state) - # # Otherwise, we should update the agent - # else: - # # print(f"Agent.save {new_agent_state.id} :: agent already exists, updating...") - # print(f"Agent.save {new_agent_state.id} :: preupdate:\n\tmessages={new_agent_state.state['messages']}") - # self.ms.update_agent(agent=new_agent_state) + # msg = f"Removed function {function_name}" + ## self.save() + # self.update_state() + # printd(msg) + # return msg def update_state(self) -> AgentState: - # updated_state = { - # "persona": self.memory.persona, - # "human": self.memory.human, - # "system": self.system, - # "functions": self.functions, - # "messages": [str(msg.id) for msg in self._messages], - # } memory = { "system": self.system, - "persona": self.memory.persona, - "human": self.memory.human, + "memory": self.memory.to_dict(), "messages": [str(msg.id) for msg in self._messages], # TODO: move out into AgentState.message_ids } - - # TODO: add this field - metadata = { # TODO - "human_name": self.agent_state.persona, - "persona_name": self.agent_state.human, - } - self.agent_state = AgentState( name=self.agent_state.name, user_id=self.agent_state.user_id, tools=self.agent_state.tools, system=self.system, - persona=self.agent_state.persona, # TODO: remove (stores persona_name) - human=self.agent_state.human, # TODO: remove (stores human_name) ## "model_state" llm_config=self.agent_state.llm_config, embedding_config=self.agent_state.embedding_config, - preset=self.agent_state.preset, id=self.agent_state.id, created_at=self.agent_state.created_at, ## "agent_state" state=memory, + _metadata=self.agent_state._metadata, ) return self.agent_state diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index d4220d45..c92de777 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -493,9 +493,7 @@ class PostgresStorageConnector(SQLStorageConnector): if isinstance(records[0], Passage): with self.engine.connect() as conn: db_records = [vars(record) for record in records] - # print("records", db_records) stmt = insert(self.db_model.__table__).values(db_records) - # print(stmt) if exists_ok: upsert_stmt = stmt.on_conflict_do_update( index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column @@ -594,9 +592,7 @@ class SQLLiteStorageConnector(SQLStorageConnector): if isinstance(records[0], Passage): with self.engine.connect() as conn: db_records = [vars(record) for record in records] - # print("records", db_records) stmt = insert(self.db_model.__table__).values(db_records) - # print(stmt) if exists_ok: upsert_stmt = stmt.on_conflict_do_update( index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index d4c535fc..c7367306 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -13,16 +13,19 @@ import requests import typer import memgpt.utils as utils +from memgpt import create_client from memgpt.agent import Agent, save_agent from memgpt.cli.cli_config import configure from memgpt.config import MemGPTConfig from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR from memgpt.credentials import MemGPTCredentials -from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, User +from memgpt.data_types import EmbeddingConfig, LLMConfig, User from memgpt.log import get_logger +from memgpt.memory import ChatMemory from memgpt.metadata import MetadataStore from memgpt.migrate import migrate_all_agents, migrate_all_sources from memgpt.server.constants import WS_DEFAULT_PORT +from memgpt.server.server import logger as server_logger # from memgpt.interface import CLIInterface as interface # for printing to terminal from memgpt.streaming_interface import ( @@ -391,7 +394,6 @@ def run( persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None, agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None, human: Annotated[Optional[str], typer.Option(help="Specify human")] = None, - preset: Annotated[Optional[str], typer.Option(help="Specify preset")] = None, # model flags model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None, model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None, @@ -427,8 +429,10 @@ def run( if debug: logger.setLevel(logging.DEBUG) + server_logger.setLevel(logging.DEBUG) else: logger.setLevel(logging.CRITICAL) + server_logger.setLevel(logging.CRITICAL) from memgpt.migrate import ( VERSION_CUTOFF, @@ -511,8 +515,6 @@ def run( # read user id from config ms = MetadataStore(config) user = create_default_user_or_exit(config, ms) - human = human if human else config.human - persona = persona if persona else config.persona # determine agent to use, if not provided if not yes and not agent: @@ -529,6 +531,8 @@ def run( # create agent config agent_state = ms.get_agent(agent_name=agent, user_id=user.id) if agent else None + human = human if human else config.human + persona = persona if persona else config.persona if agent and agent_state: # use existing agent typer.secho(f"\nšŸ” Using existing agent {agent}", fg=typer.colors.GREEN) # agent_config = AgentConfig.load(agent) @@ -540,14 +544,6 @@ def run( # printd("Index path:", agent_config.save_agent_index_dir()) # persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load # TODO: load prior agent state - if persona and persona != agent_state.persona: - typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing persona {agent_state.persona} with {persona}", fg=typer.colors.YELLOW) - agent_state.persona = persona - # raise ValueError(f"Cannot override {agent_state.name} existing persona {agent_state.persona} with {persona}") - if human and human != agent_state.human: - typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing human {agent_state.human} with {human}", fg=typer.colors.YELLOW) - agent_state.human = human - # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}") # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window) if model and model != agent_state.llm_config.model: @@ -582,7 +578,12 @@ def run( # Update the agent with any overrides ms.update_agent(agent_state) - tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools] + tools = [] + for tool_name in agent_state.tools: + tool = ms.get_tool(tool_name, agent_state.user_id) + if tool is None: + typer.secho(f"Couldn't find tool {tool_name} in database, please run `memgpt add tool`", fg=typer.colors.RED) + tools.append(tool) # create agent memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools) @@ -626,45 +627,30 @@ def run( # create agent try: - preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id) + client = create_client() human_obj = ms.get_human(human, user.id) persona_obj = ms.get_persona(persona, user.id) - if preset_obj is None: - # create preset records in metadata store - from memgpt.presets.presets import add_default_presets - - add_default_presets(user.id, ms) - # try again - preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id) - if preset_obj is None: - typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED) - sys.exit(1) if human_obj is None: typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED) if persona_obj is None: typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED) - # Overwrite fields in the preset if they were specified - preset_obj.human = ms.get_human(human, user.id).text - preset_obj.persona = ms.get_persona(persona, user.id).text + memory = ChatMemory(human=human_obj.text, persona=persona_obj.text) + metadata = {"human": human_obj.name, "persona": persona_obj.name} - typer.secho(f"-> šŸ¤– Using persona profile: '{preset_obj.persona_name}'", fg=typer.colors.WHITE) - typer.secho(f"-> šŸ§‘ Using human profile: '{preset_obj.human_name}'", fg=typer.colors.WHITE) + typer.secho(f"-> šŸ¤– Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE) + typer.secho(f"-> šŸ§‘ Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE) - agent_state = AgentState( + # add tools + agent_state = client.create_agent( name=agent_name, - user_id=user.id, - tools=list([schema["name"] for schema in preset_obj.functions_schema]), - system=preset_obj.system, - llm_config=llm_config, embedding_config=embedding_config, - human=preset_obj.human, - persona=preset_obj.persona, - preset=preset_obj.name, - state={"messages": None, "persona": preset_obj.persona, "human": preset_obj.human}, + llm_config=llm_config, + memory=memory, + metadata=metadata, ) typer.secho(f"-> šŸ› ļø {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE) - tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools] + tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools] memgpt_agent = Agent( interface=interface(), diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 188a7b30..3f73205b 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -38,8 +38,7 @@ from memgpt.local_llm.constants import ( ) from memgpt.local_llm.utils import get_available_wrappers from memgpt.metadata import MetadataStore -from memgpt.models.pydantic_models import HumanModel, PersonaModel -from memgpt.presets.presets import create_preset_from_file +from memgpt.models.pydantic_models import PersonaModel from memgpt.server.utils import shorten_key_middle app = typer.Typer() @@ -1085,18 +1084,12 @@ def configure(): else: ms.create_user(user) - # create preset records in metadata store - from memgpt.presets.presets import add_default_presets - - add_default_presets(user_id, ms) - class ListChoice(str, Enum): agents = "agents" humans = "humans" personas = "personas" sources = "sources" - presets = "presets" @app.command() @@ -1133,7 +1126,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]): elif arg == ListChoice.humans: """List all humans""" table.field_names = ["Name", "Text"] - for human in client.list_humans(user_id=user_id): + for human in client.list_humans(): table.add_row([human.name, human.text.replace("\n", "")[:100]]) print(table) elif arg == ListChoice.personas: @@ -1170,21 +1163,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]): ) print(table) - elif arg == ListChoice.presets: - """List all available presets""" - table.field_names = ["Name", "Description", "Sources", "Functions"] - for preset in ms.list_presets(user_id=user_id): - sources = ms.get_preset_sources(preset_id=preset.id) - table.add_row( - [ - preset.name, - preset.description, - ",".join([source.name for source in sources]), - # json.dumps(preset.functions_schema, indent=4) - ",\n".join([f["name"] for f in preset.functions_schema]), - ] - ) - print(table) else: raise ValueError(f"Unknown argument {arg}") @@ -1208,7 +1186,7 @@ def add( with open(filename, "r", encoding="utf-8") as f: text = f.read() if option == "persona": - persona = ms.get_persona(name=name, user_id=user_id) + persona = ms.get_persona(name=name) if persona: # config if user wants to overwrite if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask(): @@ -1220,7 +1198,7 @@ def add( ms.add_persona(persona) elif option == "human": - human = client.get_human(name=name, user_id=user_id) + human = client.get_human(name=name) if human: # config if user wants to overwrite if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask(): @@ -1228,11 +1206,7 @@ def add( human.text = text client.update_human(human) else: - human = HumanModel(name=name, text=text, user_id=user_id) - client.add_human(HumanModel(name=name, text=text, user_id=user_id)) - elif option == "preset": - assert filename, "Must specify filename for preset" - create_preset_from_file(filename, name, user_id, ms) + human = client.create_human(name=name, human=text) else: raise ValueError(f"Unknown kind {option}") @@ -1281,18 +1255,14 @@ def delete(option: str, name: str): ms.delete_agent(agent_id=agent.id) elif option == "human": - human = client.get_human(name=name, user_id=user_id) + human = client.get_human(name=name) assert human is not None, f"Human {name} does not exist" - client.delete_human(name=name, user_id=user_id) + client.delete_human(name=name) elif option == "persona": - persona = ms.get_persona(name=name, user_id=user_id) + persona = ms.get_persona(name=name) assert persona is not None, f"Persona {name} does not exist" - ms.delete_persona(name=name, user_id=user_id) - assert ms.get_persona(name=name, user_id=user_id) is None, f"Persona {name} still exists" - elif option == "preset": - preset = ms.get_preset(name=name, user_id=user_id) - assert preset is not None, f"Preset {name} does not exist" - ms.delete_preset(name=name, user_id=user_id) + ms.delete_persona(name=name) + assert ms.get_persona(name=name) is None, f"Persona {name} still exists" else: raise ValueError(f"Option {option} not implemented") diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index b914c7f7..d5d01fff 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -43,7 +43,6 @@ class Admin: def create_key(self, user_id: uuid.UUID, key_name: str): payload = {"user_id": str(user_id), "key_name": key_name} response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload) - print(response.json()) if response.status_code != 200: raise HTTPError(response.json()) return CreateAPIKeyResponse(**response.json()) @@ -53,7 +52,6 @@ class Admin: response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) if response.status_code != 200: raise HTTPError(response.json()) - print(response.text, response.status_code) return GetAPIKeysResponse(**response.json()).api_key_list def delete_key(self, api_key: str): @@ -114,6 +112,11 @@ class Admin: source_type = "python" json_schema["name"] + if "memory" in tags: + # special modifications to memory functions + # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory) + source_code = source_code.replace("self.memory", "self.memory.memory") + # create data data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema} CreateToolRequest(**data) # validate diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 047ab7da..d37aad03 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -6,23 +6,17 @@ from typing import Dict, List, Optional, Tuple, Union import requests from memgpt.config import MemGPTConfig -from memgpt.constants import BASE_TOOLS, DEFAULT_PRESET +from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET from memgpt.data_sources.connectors import DataConnector -from memgpt.data_types import ( - AgentState, - EmbeddingConfig, - LLMConfig, - Preset, - Source, - User, -) +from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, Preset, Source from memgpt.functions.functions import parse_source_code from memgpt.functions.schema_generator import generate_schema -from memgpt.metadata import MetadataStore +from memgpt.memory import BaseMemory, ChatMemory, get_memory_functions from memgpt.models.pydantic_models import ( HumanModel, JobModel, JobStatus, + LLMConfigModel, PersonaModel, PresetModel, SourceModel, @@ -32,6 +26,7 @@ from memgpt.server.rest_api.agents.command import CommandResponse from memgpt.server.rest_api.agents.config import GetAgentResponse from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse from memgpt.server.rest_api.agents.memory import ( + ArchivalMemoryObject, GetAgentArchivalMemoryResponse, GetAgentMemoryResponse, InsertAgentArchivalMemoryResponse, @@ -56,6 +51,7 @@ from memgpt.server.rest_api.sources.index import ListSourcesResponse # import pydantic response objects from memgpt.server.rest_api from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse from memgpt.server.server import SyncServer +from memgpt.utils import get_human_text def create_client(base_url: Optional[str] = None, token: Optional[str] = None): @@ -242,8 +238,6 @@ class RESTClient(AbstractClient): def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers) - print(response.text, response.status_code) - print(response) if response.status_code == 404: # not found error return False @@ -252,17 +246,24 @@ class RESTClient(AbstractClient): else: raise ValueError(f"Failed to check if agent exists: {response.text}") + def get_tool(self, tool_name: str): + response = requests.get(f"{self.base_url}/api/tools/{tool_name}", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to get tool: {response.text}") + return ToolModel(**response.json()) + def create_agent( self, name: Optional[str] = None, preset: Optional[str] = None, # TODO: this should actually be re-named preset_name - persona: Optional[str] = None, - human: Optional[str] = None, embedding_config: Optional[EmbeddingConfig] = None, llm_config: Optional[LLMConfig] = None, + # memory + memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)), # tools tools: Optional[List[str]] = None, include_base_tools: Optional[bool] = True, + metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, ) -> AgentState: """ Create an agent @@ -274,7 +275,6 @@ class RESTClient(AbstractClient): Returns: agent_state (AgentState): State of the the created agent. - """ if embedding_config or llm_config: raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API") @@ -286,14 +286,22 @@ class RESTClient(AbstractClient): if include_base_tools: tool_names += BASE_TOOLS + # add memory tools + memory_functions = get_memory_functions(memory) + for func_name, func in memory_functions.items(): + tool = self.create_tool(func, name=func_name, tags=["memory", "memgpt-base"], update=True) + tool_names.append(tool.name) + # TODO: distinguish between name and objects + # TODO: add metadata payload = { "config": { "name": name, "preset": preset, - "persona": persona, - "human": human, + "persona": memory.memory["persona"].value, + "human": memory.memory["human"].value, "function_names": tool_names, + "metadata": metadata, } } response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers) @@ -322,14 +330,12 @@ class RESTClient(AbstractClient): id=response.agent_state.id, name=response.agent_state.name, user_id=response.agent_state.user_id, - preset=response.agent_state.preset, - persona=response.agent_state.persona, - human=response.agent_state.human, llm_config=llm_config, embedding_config=embedding_config, state=response.agent_state.state, system=response.agent_state.system, tools=response.agent_state.tools, + _metadata=response.agent_state.metadata, # load datetime from timestampe created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc), ) @@ -352,25 +358,8 @@ class RESTClient(AbstractClient): response_obj = GetAgentResponse(**response.json()) return self.get_agent_response_to_state(response_obj) - ## presets - # def create_preset(self, preset: Preset) -> CreatePresetResponse: - # # TODO should the arg type here be PresetModel, not Preset? - # payload = CreatePresetsRequest( - # id=str(preset.id), - # name=preset.name, - # description=preset.description, - # system=preset.system, - # persona=preset.persona, - # human=preset.human, - # persona_name=preset.persona_name, - # human_name=preset.human_name, - # functions_schema=preset.functions_schema, - # ) - # response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers) - # assert response.status_code == 200, f"Failed to create preset: {response.text}" - # return CreatePresetResponse(**response.json()) - def get_preset(self, name: str) -> PresetModel: + # TODO: remove response = requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers) assert response.status_code == 200, f"Failed to get preset: {response.text}" return PresetModel(**response.json()) @@ -385,6 +374,7 @@ class RESTClient(AbstractClient): tools: Optional[List[ToolModel]] = None, default_tools: bool = True, ) -> PresetModel: + # TODO: remove """Create an agent preset :param name: Name of the preset @@ -406,7 +396,6 @@ class RESTClient(AbstractClient): schema = [] if tools: for tool in tools: - print("CUSOTM TOOL", tool.json_schema) schema.append(tool.json_schema) # include default tools @@ -427,9 +416,6 @@ class RESTClient(AbstractClient): human_name=human_name, functions_schema=schema, ) - print(schema) - print(human_name, persona_name, system_name, name) - print(payload.model_dump()) response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers) assert response.status_code == 200, f"Failed to create preset: {response.text}" return CreatePresetResponse(**response.json()).preset @@ -482,7 +468,6 @@ class RESTClient(AbstractClient): response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", json={"content": memory}, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to insert archival memory: {response.text}") - print(response.json()) return InsertAgentArchivalMemoryResponse(**response.json()) def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID): @@ -518,8 +503,6 @@ class RESTClient(AbstractClient): response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create human: {response.text}") - - print(response.json()) return HumanModel(**response.json()) def list_personas(self) -> ListPersonasResponse: @@ -531,9 +514,24 @@ class RESTClient(AbstractClient): response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create persona: {response.text}") - print(response.json()) return PersonaModel(**response.json()) + def get_persona(self, name: str) -> PersonaModel: + response = requests.get(f"{self.base_url}/api/personas/{name}", headers=self.headers) + if response.status_code == 404: + return None + elif response.status_code != 200: + raise ValueError(f"Failed to get persona: {response.text}") + return PersonaModel(**response.json()) + + def get_human(self, name: str) -> HumanModel: + response = requests.get(f"{self.base_url}/api/humans/{name}", headers=self.headers) + if response.status_code == 404: + return None + elif response.status_code != 200: + raise ValueError(f"Failed to get human: {response.text}") + return HumanModel(**response.json()) + # sources def list_sources(self): @@ -638,7 +636,7 @@ class RESTClient(AbstractClient): json_schema["name"] # create data - data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema} + data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema, "update": update} try: CreateToolRequest(**data) # validate data except Exception as e: @@ -694,23 +692,12 @@ class LocalClient(AbstractClient): else: self.user_id = uuid.UUID(config.anon_clientid) - # create user if does not exist - ms = MetadataStore(config) - self.user = User(id=self.user_id) - if ms.get_user(self.user_id): - # update user - ms.update_user(self.user) - else: - ms.create_user(self.user) - - # create preset records in metadata store - from memgpt.presets.presets import add_default_presets - - add_default_presets(self.user_id, ms) - self.interface = QueuingInterface(debug=debug) self.server = SyncServer(default_interface_factory=lambda: self.interface) + # create user if does not exist + self.server.create_user({"id": self.user_id}, exists_ok=True) + # messages def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse: self.interface.clear() @@ -740,14 +727,16 @@ class LocalClient(AbstractClient): def create_agent( self, name: Optional[str] = None, - preset: Optional[str] = None, # TODO: this should actually be re-named preset_name - persona: Optional[str] = None, - human: Optional[str] = None, + # model configs embedding_config: Optional[EmbeddingConfig] = None, llm_config: Optional[LLMConfig] = None, + # memory + memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)), # tools tools: Optional[List[str]] = None, include_base_tools: Optional[bool] = True, + # metadata + metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, ) -> AgentState: if name and self.agent_exists(agent_name=name): raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})") @@ -759,23 +748,35 @@ class LocalClient(AbstractClient): if include_base_tools: tool_names += BASE_TOOLS + # add memory tools + memory_functions = get_memory_functions(memory) + for func_name, func in memory_functions.items(): + tool = self.create_tool(func, name=func_name, tags=["memory", "memgpt-base"]) + tool_names.append(tool.name) + self.interface.clear() + + # create agent agent_state = self.server.create_agent( user_id=self.user_id, name=name, - preset=preset, - persona=persona, - human=human, + memory=memory, llm_config=llm_config, embedding_config=embedding_config, tools=tool_names, + metadata=metadata, ) return agent_state + def rename_agent(self, agent_id: uuid.UUID, new_name: str): + # TODO: check valid name + agent_state = self.server.rename_agent(user_id=self.user_id, agent_id=agent_id, new_agent_name=new_name) + return agent_state + def delete_agent(self, agent_id: uuid.UUID): self.server.delete_agent(user_id=self.user_id, agent_id=agent_id) - def get_agent_config(self, agent_id: str) -> AgentState: + def get_agent(self, agent_id: uuid.UUID) -> AgentState: self.interface.clear() return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id) @@ -803,6 +804,19 @@ class LocalClient(AbstractClient): # agent interactions + def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse: + self.interface.clear() + if role == "system": + usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message) + elif role == "user": + usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message) + else: + raise ValueError(f"Role {role} not supported") + if self.auto_save: + self.save() + else: + return UserMessageResponse(messages=self.interface.to_list(), usage=usage) + def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]: self.interface.clear() usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message) @@ -820,36 +834,37 @@ class LocalClient(AbstractClient): # archival memory - def get_agent_archival_memory( - self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 - ): - _, archival_json_records = self.server.get_agent_archival_cursor( - user_id=self.user_id, - agent_id=agent_id, - after=after, - before=before, - limit=limit, - ) - return archival_json_records - - # messages - # humans / personas - def list_humans(self, user_id: uuid.UUID): - return self.server.list_humans(user_id=user_id if user_id else self.user_id) + def create_human(self, name: str, human: str): + return self.server.add_human(HumanModel(name=name, text=human, user_id=self.user_id)) - def get_human(self, name: str, user_id: uuid.UUID): - return self.server.get_human(name=name, user_id=user_id) + def create_persona(self, name: str, persona: str): + return self.server.add_persona(PersonaModel(name=name, text=persona, user_id=self.user_id)) - def add_human(self, human: HumanModel): - return self.server.add_human(human=human) + def list_humans(self): + return self.server.list_humans(user_id=self.user_id if self.user_id else self.user_id) + + def get_human(self, name: str): + return self.server.get_human(name=name, user_id=self.user_id) def update_human(self, human: HumanModel): return self.server.update_human(human=human) - def delete_human(self, name: str, user_id: uuid.UUID): - return self.server.delete_human(name, user_id) + def delete_human(self, name: str): + return self.server.delete_human(name, self.user_id) + + def list_personas(self): + return self.server.list_personas(user_id=self.user_id) + + def get_persona(self, name: str): + return self.server.get_persona(name=name, user_id=self.user_id) + + def update_persona(self, persona: PersonaModel): + return self.server.update_persona(persona=persona) + + def delete_persona(self, name: str): + return self.server.delete_persona(name, self.user_id) # tools def create_tool( @@ -879,6 +894,11 @@ class LocalClient(AbstractClient): source_type = "python" tool_name = json_schema["name"] + if "memory" in tags: + # special modifications to memory functions + # self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory) + source_code = source_code.replace("self.memory", "self.memory.memory") + # check if already exists: existing_tool = self.server.ms.get_tool(tool_name, self.user_id) if existing_tool: @@ -924,3 +944,48 @@ class LocalClient(AbstractClient): def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id) + + def get_agent_archival_memory( + self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 + ): + self.interface.clear() + # TODO need to add support for non-postgres here + # chroma will throw: + # raise ValueError("Cannot run get_all_cursor with chroma") + _, archival_json_records = self.server.get_agent_archival_cursor( + user_id=self.user_id, + agent_id=agent_id, + after=after, + before=before, + limit=limit, + ) + archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_json_records] + return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects) + + def insert_archival_memory(self, agent_id: uuid.UUID, memory: str) -> GetAgentArchivalMemoryResponse: + memory_ids = self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory) + return InsertAgentArchivalMemoryResponse(ids=memory_ids) + + def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID): + self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id) + + def get_messages( + self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 + ) -> GetAgentMessagesResponse: + self.interface.clear() + [_, messages] = self.server.get_agent_recall_cursor( + user_id=self.user_id, agent_id=agent_id, before=before, limit=limit, reverse=True + ) + return GetAgentMessagesResponse(messages=messages) + + def list_models(self) -> ListModelsResponse: + + llm_config = LLMConfigModel( + model=self.server.server_llm_config.model, + model_endpoint=self.server.server_llm_config.model_endpoint, + model_endpoint_type=self.server.server_llm_config.model_endpoint_type, + model_wrapper=self.server.server_llm_config.model_wrapper, + context_window=self.server.server_llm_config.context_window, + ) + + return ListModelsResponse(models=[llm_config]) diff --git a/memgpt/config.py b/memgpt/config.py index 7e64a823..cfda1a70 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -38,7 +38,7 @@ class MemGPTConfig: anon_clientid: str = str(uuid.UUID(int=0)) # preset - preset: str = DEFAULT_PRESET + preset: str = DEFAULT_PRESET # TODO: rename to system prompt # persona parameters persona: str = DEFAULT_PERSONA diff --git a/memgpt/constants.py b/memgpt/constants.py index 86f8f423..df10ed67 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -24,8 +24,6 @@ DEFAULT_PRESET = "memgpt_chat" # Tools BASE_TOOLS = [ "send_message", - "core_memory_replace", - "core_memory_append", "pause_heartbeats", "conversation_search", "conversation_search_date", diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 1e3dcb24..80af9570 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -763,22 +763,14 @@ class AgentState: # system prompt system: str, # config - persona: str, # the filename where the persona was originally sourced from # TODO: remove - human: str, # the filename where the human was originally sourced from # TODO: remove llm_config: LLMConfig, embedding_config: EmbeddingConfig, - preset: str, # TODO: remove # (in-context) state contains: - # persona: str # the current persona text - # human: str # the current human text - # system: str, # system prompt (not required if initializing with a preset) - # functions: dict, # schema definitions ONLY (function code linked at runtime) - # messages: List[dict], # in-context messages id: Optional[uuid.UUID] = None, state: Optional[dict] = None, created_at: Optional[datetime] = None, # messages (TODO: implement this) - # _metadata: Optional[dict] = None, + _metadata: Optional[dict] = None, ): if id is None: self.id = uuid.uuid4() @@ -792,11 +784,8 @@ class AgentState: self.name = name assert self.name, f"AgentState name must be a non-empty string" self.user_id = user_id - self.preset = preset # The INITIAL values of the persona and human # The values inside self.state['persona'], self.state['human'] are the CURRENT values - self.persona = persona - self.human = human self.llm_config = llm_config self.embedding_config = embedding_config @@ -811,9 +800,10 @@ class AgentState: # system self.system = system + assert self.system is not None, f"Must provide system prompt, cannot be None" # metadata - # self._metadata = _metadata + self._metadata = _metadata class Source: @@ -866,6 +856,7 @@ class Token: class Preset(BaseModel): + # TODO: remove Preset name: str = Field(..., description="The name of the preset.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.") diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index dcae58bc..9c751d2c 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -56,39 +56,6 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]: pause_heartbeats.__doc__ = pause_heartbeats_docstring -def core_memory_append(self: Agent, name: str, content: str) -> Optional[str]: - """ - Append to the contents of core memory. - - Args: - name (str): Section of the memory to be edited (persona or human). - content (str): Content to write to the memory. All unicode (including emojis) are supported. - - Returns: - Optional[str]: None is always returned as this function does not produce a response. - """ - self.memory.edit_append(name, content) - self.rebuild_memory() - return None - - -def core_memory_replace(self: Agent, name: str, old_content: str, new_content: str) -> Optional[str]: - """ - Replace the contents of core memory. To delete memories, use an empty string for new_content. - - Args: - name (str): Section of the memory to be edited (persona or human). - old_content (str): String to replace. Must be an exact match. - new_content (str): Content to write to the memory. All unicode (including emojis) are supported. - - Returns: - Optional[str]: None is always returned as this function does not produce a response. - """ - self.memory.edit_replace(name, old_content, new_content) - self.rebuild_memory() - return None - - def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using case-insensitive string matching. diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py index c7a794da..d83fdc0f 100644 --- a/memgpt/functions/functions.py +++ b/memgpt/functions/functions.py @@ -2,7 +2,6 @@ import importlib import inspect import os import sys -import warnings from textwrap import dedent # remove indentation from types import ModuleType @@ -104,97 +103,3 @@ def load_function_file(filepath: str) -> dict: # load all functions in the module function_dict = load_function_set(module) return function_dict - - -def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict: - from memgpt.utils import printd - - # functions/examples/*.py - scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script - function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory - # List all .py files in the directory (excluding __init__.py) - example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"] - - # ~/.memgpt/functions/*.py - # create if missing - if not os.path.exists(USER_FUNCTIONS_DIR): - os.makedirs(USER_FUNCTIONS_DIR) - user_module_files = [f for f in os.listdir(USER_FUNCTIONS_DIR) if f.endswith(".py") and f != "__init__.py"] - - # combine them both (pull from both examples and user-provided) - # all_module_files = example_module_files + user_module_files - - # Add user_scripts_dir to sys.path - if USER_FUNCTIONS_DIR not in sys.path: - sys.path.append(USER_FUNCTIONS_DIR) - - schemas_and_functions = {} - for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]: - for file in module_files: - tags = [] - module_name = file[:-3] # Remove '.py' from filename - if dir_path == USER_FUNCTIONS_DIR: - # For user scripts, adjust the module name appropriately - module_full_path = os.path.join(dir_path, file) - printd(f"Loading user function set from '{module_full_path}'") - try: - spec = importlib.util.spec_from_file_location(module_name, module_full_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - except ModuleNotFoundError as e: - # Handle missing module imports - missing_package = str(e).split("'")[1] # Extract the name of the missing package - printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!") - printd( - f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT." - ) - continue - except SyntaxError as e: - # Handle syntax errors in the module - printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}") - continue - except Exception as e: - # Handle other general exceptions - printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}") - continue - else: - # For built-in scripts, use the existing method - full_module_name = f"memgpt.functions.function_sets.{module_name}" - tags.append(f"memgpt-{module_name}") - try: - module = importlib.import_module(full_module_name) - except Exception as e: - # Handle other general exceptions - printd(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}") - continue - - try: - # Load the function set - function_set = load_function_set(module) - # Add the metadata tags - for k, v in function_set.items(): - # print(function_set) - v["tags"] = tags - schemas_and_functions[module_name] = function_set - except ValueError as e: - err = f"Error loading function set '{module_name}': {e}" - printd(err) - warnings.warn(err) - - if merge: - # Put all functions from all sets into the same level dict - merged_functions = {} - for set_name, function_set in schemas_and_functions.items(): - for function_name, function_info in function_set.items(): - if function_name in merged_functions: - err_msg = f"Duplicate function name '{function_name}' found in function set '{set_name}'" - if ignore_duplicates: - warnings.warn(err_msg, category=UserWarning, stacklevel=2) - else: - raise ValueError(err_msg) - else: - merged_functions[function_name] = function_info - return merged_functions - else: - # Nested dict where the top level is organized by the function set name - return schemas_and_functions diff --git a/memgpt/memory.py b/memgpt/memory.py index a3f6b3f9..f405ebd4 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -3,6 +3,8 @@ import uuid from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Union +from pydantic import BaseModel, validator + from memgpt.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC from memgpt.data_types import AgentState, Message, Passage from memgpt.embeddings import embedding_model, parse_and_chunk_text, query_embedding @@ -16,96 +18,214 @@ from memgpt.utils import ( validate_date_format, ) -# from llama_index import Document -# from llama_index.node_parser import SimpleNodeParser + +class MemoryModule(BaseModel): + """Base class for memory modules""" + + description: Optional[str] = None + limit: int = 2000 + value: Optional[Union[List[str], str]] = None + + def __setattr__(self, name, value): + """Run validation if self.value is updated""" + super().__setattr__(name, value) + if name == "value": + # run validation + self.__class__.validate(self.dict(exclude_unset=True)) + + @validator("value", always=True) + def check_value_length(cls, v, values): + if v is not None: + # Fetching the limit from the values dictionary + limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set + + # Check if the value exceeds the limit + if isinstance(v, str): + length = len(v) + elif isinstance(v, list): + length = sum(len(item) for item in v) + else: + raise ValueError("Value must be either a string or a list of strings.") + + if length > limit: + error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})." + # TODO: add archival memory error? + raise ValueError(error_msg) + return v + + def __len__(self): + return len(str(self)) + + def __str__(self) -> str: + if isinstance(self.value, list): + return ",".join(self.value) + elif isinstance(self.value, str): + return self.value + else: + return "" -class CoreMemory(object): - """Held in-context inside the system message +class BaseMemory: - Core Memory: Refers to the system block, which provides essential, foundational context to the AI. - This includes the persona information, essential user details, - and any other baseline data you deem necessary for the AI's basic functioning. - """ - - def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True): - self.persona = persona - self.human = human - self.persona_char_limit = persona_char_limit - self.human_char_limit = human_char_limit - - # affects the error message the AI will see on overflow inserts - self.archival_memory_exists = archival_memory_exists - - def __repr__(self) -> str: - return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}" - - def to_dict(self): - return { - "persona": self.persona, - "human": self.human, - } + def __init__(self): + self.memory = {} @classmethod - def load(cls, state): - return cls(state["persona"], state["human"]) + def load(cls, state: dict): + """Load memory from dictionary object""" + obj = cls() + for key, value in state.items(): + obj.memory[key] = MemoryModule(**value) + return obj - def edit_persona(self, new_persona): - if self.persona_char_limit and len(new_persona) > self.persona_char_limit: - error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})." - if self.archival_memory_exists: - error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again." - raise ValueError(error_msg) + def __str__(self) -> str: + """Representation of the memory in-context""" + section_strs = [] + for section, module in self.memory.items(): + section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n') + return "\n".join(section_strs) - self.persona = new_persona - return len(self.persona) + def to_dict(self): + """Convert to dictionary representation""" + return {key: value.dict() for key, value in self.memory.items()} - def edit_human(self, new_human): - if self.human_char_limit and len(new_human) > self.human_char_limit: - error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})." - if self.archival_memory_exists: - error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again." - raise ValueError(error_msg) - self.human = new_human - return len(self.human) +class ChatMemory(BaseMemory): - def edit(self, field, content): - if field == "persona": - return self.edit_persona(content) - elif field == "human": - return self.edit_human(content) - else: - raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') + def __init__(self, persona: str, human: str, limit: int = 2000): + self.memory = { + "persona": MemoryModule(name="persona", value=persona, limit=limit), + "human": MemoryModule(name="human", value=human, limit=limit), + } - def edit_append(self, field, content, sep="\n"): - if field == "persona": - new_content = self.persona + sep + content - return self.edit_persona(new_content) - elif field == "human": - new_content = self.human + sep + content - return self.edit_human(new_content) - else: - raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') + def core_memory_append(self, name: str, content: str) -> Optional[str]: + """ + Append to the contents of core memory. - def edit_replace(self, field, old_content, new_content): - if len(old_content) == 0: - raise ValueError("old_content cannot be an empty string (must specify old_content to replace)") + Args: + name (str): Section of the memory to be edited (persona or human). + content (str): Content to write to the memory. All unicode (including emojis) are supported. - if field == "persona": - if old_content in self.persona: - new_persona = self.persona.replace(old_content, new_content) - return self.edit_persona(new_persona) - else: - raise ValueError("Content not found in persona (make sure to use exact string)") - elif field == "human": - if old_content in self.human: - new_human = self.human.replace(old_content, new_content) - return self.edit_human(new_human) - else: - raise ValueError("Content not found in human (make sure to use exact string)") - else: - raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + self.memory[name].value += "\n" + content + return None + + def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]: + """ + Replace the contents of core memory. To delete memories, use an empty string for new_content. + + Args: + name (str): Section of the memory to be edited (persona or human). + old_content (str): String to replace. Must be an exact match. + new_content (str): Content to write to the memory. All unicode (including emojis) are supported. + + Returns: + Optional[str]: None is always returned as this function does not produce a response. + """ + self.memory[name].value = self.memory[name].value.replace(old_content, new_content) + return None + + +def get_memory_functions(cls: BaseMemory) -> List[callable]: + """Get memory functions for a memory class""" + functions = {} + for func_name in dir(cls): + if func_name.startswith("_") or func_name in ["load", "to_dict"]: # skip base functions + continue + func = getattr(cls, func_name) + if callable(func): + functions[func_name] = func + return functions + + +# class CoreMemory(object): +# """Held in-context inside the system message +# +# Core Memory: Refers to the system block, which provides essential, foundational context to the AI. +# This includes the persona information, essential user details, +# and any other baseline data you deem necessary for the AI's basic functioning. +# """ +# +# def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True): +# self.persona = persona +# self.human = human +# self.persona_char_limit = persona_char_limit +# self.human_char_limit = human_char_limit +# +# # affects the error message the AI will see on overflow inserts +# self.archival_memory_exists = archival_memory_exists +# +# def __repr__(self) -> str: +# return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}" +# +# def to_dict(self): +# return { +# "persona": self.persona, +# "human": self.human, +# } +# +# @classmethod +# def load(cls, state): +# return cls(state["persona"], state["human"]) +# +# def edit_persona(self, new_persona): +# if self.persona_char_limit and len(new_persona) > self.persona_char_limit: +# error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})." +# if self.archival_memory_exists: +# error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again." +# raise ValueError(error_msg) +# +# self.persona = new_persona +# return len(self.persona) +# +# def edit_human(self, new_human): +# if self.human_char_limit and len(new_human) > self.human_char_limit: +# error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})." +# if self.archival_memory_exists: +# error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again." +# raise ValueError(error_msg) +# +# self.human = new_human +# return len(self.human) +# +# def edit(self, field, content): +# if field == "persona": +# return self.edit_persona(content) +# elif field == "human": +# return self.edit_human(content) +# else: +# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') +# +# def edit_append(self, field, content, sep="\n"): +# if field == "persona": +# new_content = self.persona + sep + content +# return self.edit_persona(new_content) +# elif field == "human": +# new_content = self.human + sep + content +# return self.edit_human(new_content) +# else: +# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') +# +# def edit_replace(self, field, old_content, new_content): +# if len(old_content) == 0: +# raise ValueError("old_content cannot be an empty string (must specify old_content to replace)") +# +# if field == "persona": +# if old_content in self.persona: +# new_persona = self.persona.replace(old_content, new_content) +# return self.edit_persona(new_persona) +# else: +# raise ValueError("Content not found in persona (make sure to use exact string)") +# elif field == "human": +# if old_content in self.human: +# new_human = self.human.replace(old_content, new_content) +# return self.edit_human(new_human) +# else: +# raise ValueError("Content not found in human (make sure to use exact string)") +# else: +# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') def _format_summary_history(message_history: List[Message]): diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 29491df7..25137a32 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -175,10 +175,7 @@ class AgentModel(Base): id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) user_id = Column(CommonUUID, nullable=False) name = Column(String, nullable=False) - persona = Column(String) - human = Column(String) system = Column(String) - preset = Column(String) created_at = Column(DateTime(timezone=True), server_default=func.now()) # configs @@ -187,6 +184,7 @@ class AgentModel(Base): # state state = Column(JSON) + _metadata = Column(JSON) # tools tools = Column(JSON) @@ -199,15 +197,13 @@ class AgentModel(Base): id=self.id, user_id=self.user_id, name=self.name, - persona=self.persona, - human=self.human, - preset=self.preset, created_at=self.created_at, llm_config=self.llm_config, embedding_config=self.embedding_config, state=self.state, tools=self.tools, system=self.system, + _metadata=self._metadata, ) @@ -739,17 +735,21 @@ class MetadataStore: @enforce_types def add_human(self, human: HumanModel): with self.session_maker() as session: + if self.get_human(human.name, human.user_id): + raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}") session.add(human) session.commit() @enforce_types def add_persona(self, persona: PersonaModel): with self.session_maker() as session: + if self.get_persona(persona.name, persona.user_id): + raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}") session.add(persona) session.commit() @enforce_types - def add_preset(self, preset: PresetModel): + def add_preset(self, preset: PresetModel): # TODO: remove with self.session_maker() as session: session.add(preset) session.commit() diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 2e3ec8b6..24ea2af6 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -98,9 +98,6 @@ class AgentStateModel(BaseModel): created_at: int = Field(..., description="The unix timestamp of when the agent was created.") # preset information - preset: str = Field(..., description="The preset used by the agent.") - persona: str = Field(..., description="The persona used by the agent.") - human: str = Field(..., description="The human used by the agent.") tools: List[str] = Field(..., description="The tools used by the agent.") system: str = Field(..., description="The system prompt used by the agent.") # functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.") @@ -111,6 +108,7 @@ class AgentStateModel(BaseModel): # agent state state: Optional[Dict] = Field(None, description="The state of the agent.") + metadata: Optional[Dict] = Field(None, description="The metadata of the agent.") class CoreMemory(BaseModel): diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index 126b2dd6..372195f1 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -2,23 +2,14 @@ import importlib import inspect import os import uuid -from typing import List -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from memgpt.data_types import AgentState, Preset -from memgpt.functions.functions import load_all_function_sets, load_function_set +from memgpt.functions.functions import load_function_set from memgpt.interface import AgentInterface from memgpt.metadata import MetadataStore from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel -from memgpt.presets.utils import load_all_presets, load_yaml_file -from memgpt.prompts import gpt_system -from memgpt.utils import ( - get_human_text, - get_persona_text, - list_human_files, - list_persona_files, - printd, -) +from memgpt.presets.utils import load_all_presets +from memgpt.utils import list_human_files, list_persona_files, printd available_presets = load_all_presets() preset_options = list(available_presets.keys()) @@ -88,145 +79,13 @@ def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore): printd(f"Human '{name}' already exists for user '{user_id}'") continue human = HumanModel(name=name, text=text, user_id=user_id) + print(human, user_id) ms.add_human(human) -def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: MetadataStore) -> Preset: - preset_config = load_yaml_file(filename) - preset_system_prompt = preset_config["system_prompt"] - preset_function_set_names = preset_config["functions"] - functions_schema = generate_functions_json(preset_function_set_names) - - if ms.get_preset(user_id=user_id, name=name) is not None: - printd(f"Preset '{name}' already exists for user '{user_id}'") - return ms.get_preset(user_id=user_id, name=name) - - preset = Preset( - user_id=user_id, - name=name, - system=gpt_system.get_system_text(preset_system_prompt), - persona=get_persona_text(DEFAULT_PERSONA), - human=get_human_text(DEFAULT_HUMAN), - persona_name=DEFAULT_PERSONA, - human_name=DEFAULT_HUMAN, - functions_schema=functions_schema, - ) - ms.create_preset(preset) - return preset - - -def load_preset(preset_name: str, user_id: uuid.UUID): - preset_config = available_presets[preset_name] - preset_system_prompt = preset_config["system_prompt"] - preset_function_set_names = preset_config["functions"] - functions_schema = generate_functions_json(preset_function_set_names) - - preset = Preset( - user_id=user_id, - name=preset_name, - system=gpt_system.get_system_text(preset_system_prompt), - persona=get_persona_text(DEFAULT_PERSONA), - persona_name=DEFAULT_PERSONA, - human=get_human_text(DEFAULT_HUMAN), - human_name=DEFAULT_HUMAN, - functions_schema=functions_schema, - ) - return preset - - -def add_default_presets(user_id: uuid.UUID, ms: MetadataStore): - """Add the default presets to the metadata store""" - # make sure humans/personas added - add_default_humans_and_personas(user_id=user_id, ms=ms) - - # make sure base functions added - # TODO: pull from functions instead - add_default_tools(user_id=user_id, ms=ms) - - # add default presets - for preset_name in preset_options: - if ms.get_preset(user_id=user_id, name=preset_name) is not None: - printd(f"Preset '{preset_name}' already exists for user '{user_id}'") - continue - - preset = load_preset(preset_name, user_id) - ms.create_preset(preset) - - -def generate_functions_json(preset_functions: List[str]): - """ - Generate JSON schema for the functions based on what is locally available. - - TODO: store function definitions in the DB, instead of locally - """ - # Available functions is a mapping from: - # function_name -> { - # json_schema: schema - # python_function: function - # } - available_functions = load_all_function_sets() - # Filter down the function set based on what the preset requested - preset_function_set = {} - for f_name in preset_functions: - if f_name not in available_functions: - raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}") - preset_function_set[f_name] = available_functions[f_name] - assert len(preset_functions) == len(preset_function_set) - preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()] - printd(f"Available functions:\n", list(preset_function_set.keys())) - return preset_function_set_schemas - - # def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): def create_agent_from_preset( agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True ): """Initialize a new agent from a preset (combination of system + function)""" raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead") - - # Input validation - if agent_state.persona is None: - raise ValueError(f"'persona' not specified in AgentState (required)") - if agent_state.human is None: - raise ValueError(f"'human' not specified in AgentState (required)") - if agent_state.preset is None: - raise ValueError(f"'preset' not specified in AgentState (required)") - if not (agent_state.state == {} or agent_state.state is None): - raise ValueError(f"'state' must be uninitialized (empty)") - - assert preset is not None, "preset cannot be none" - preset_name = agent_state.preset - assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'" - persona = agent_state.persona - human = agent_state.human - model = agent_state.llm_config.model - - from memgpt.agent import Agent - - # available_presets = load_all_presets() - # if preset_name not in available_presets: - # raise ValueError(f"Preset '{preset_name}.yaml' not found") - # preset = available_presets[preset_name] - # preset_system_prompt = preset["system_prompt"] - # preset_function_set_names = preset["functions"] - # preset_function_set_schemas = generate_functions_json(preset_function_set_names) - # Override the following in the AgentState: - # persona: str # the current persona text - # human: str # the current human text - # system: str, # system prompt (not required if initializing with a preset) - # functions: dict, # schema definitions ONLY (function code linked at runtime) - # messages: List[dict], # in-context messages - agent_state.state = { - "persona": get_persona_text(persona) if persona_is_file else persona, - "human": get_human_text(human) if human_is_file else human, - "system": preset.system, - "functions": preset.functions_schema, - "messages": None, - } - - return Agent( - agent_state=agent_state, - interface=interface, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, - ) diff --git a/memgpt/server/rest_api/admin/users.py b/memgpt/server/rest_api/admin/users.py index c14d858a..5659a8d4 100644 --- a/memgpt/server/rest_api/admin/users.py +++ b/memgpt/server/rest_api/admin/users.py @@ -89,9 +89,6 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): try: server.ms.create_user(new_user) - # initialize default presets automatically for user - server.initialize_default_presets(new_user.id) - # make sure we can retrieve the user from the DB too new_user_ret = server.ms.get_user(new_user.id) if new_user_ret is None: diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index 2050404b..12332b93 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -79,15 +79,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, id=agent_state.id, name=agent_state.name, user_id=agent_state.user_id, - preset=agent_state.preset, - persona=agent_state.persona, - human=agent_state.human, llm_config=llm_config, embedding_config=embedding_config, state=agent_state.state, created_at=int(agent_state.created_at.timestamp()), tools=agent_state.tools, system=agent_state.system, + metadata=agent_state._metadata, ), last_run_at=None, # TODO sources=attached_sources, @@ -125,9 +123,6 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, id=agent_state.id, name=agent_state.name, user_id=agent_state.user_id, - preset=agent_state.preset, - persona=agent_state.persona, - human=agent_state.human, llm_config=llm_config, embedding_config=embedding_config, state=agent_state.state, diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index 96306843..73b11e89 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel, Field from memgpt.constants import BASE_TOOLS +from memgpt.memory import ChatMemory from memgpt.models.pydantic_models import ( AgentStateModel, EmbeddingConfigModel, @@ -69,8 +70,11 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p human = request.config["human"] if "human" in request.config else None persona_name = request.config["persona_name"] if "persona_name" in request.config else None persona = request.config["persona"] if "persona" in request.config else None - preset = request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset + request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset tool_names = request.config["function_names"] + metadata = request.config["metadata"] if "metadata" in request.config else {} + metadata["human"] = human_name + metadata["persona"] = persona_name # TODO: remove this -- should be added based on create agent fields if isinstance(tool_names, str): # TODO: fix this on clinet side? @@ -82,56 +86,54 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p tool_names.append(name) assert isinstance(tool_names, list), "Tool names must be a list of strings." + # TODO: eventually remove this - should support general memory at the REST endpoint + memory = ChatMemory(persona=persona, human=human) + try: agent_state = server.create_agent( user_id=user_id, # **request.config # TODO turn into a pydantic model name=request.config["name"], - preset=preset, - persona_name=persona_name, - human_name=human_name, - persona=persona, - human=human, + memory=memory, + # persona_name=persona_name, + # human_name=human_name, + # persona=persona, + # human=human, # llm_config=LLMConfigModel( # model=request.config['model'], # ) # tools tools=tool_names, + metadata=metadata, # function_names=request.config["function_names"].split(",") if "function_names" in request.config else None, ) llm_config = LLMConfigModel(**vars(agent_state.llm_config)) embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) - # TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line - # TODO: remove - preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id) - return CreateAgentResponse( agent_state=AgentStateModel( id=agent_state.id, name=agent_state.name, user_id=agent_state.user_id, - preset=agent_state.preset, - persona=agent_state.persona, - human=agent_state.human, llm_config=llm_config, embedding_config=embedding_config, state=agent_state.state, created_at=int(agent_state.created_at.timestamp()), tools=tool_names, system=agent_state.system, + metadata=agent_state._metadata, ), - preset=PresetModel( - name=preset.name, - id=preset.id, - user_id=preset.user_id, - description=preset.description, - created_at=preset.created_at, - system=preset.system, - persona=preset.persona, - human=preset.human, - functions_schema=preset.functions_schema, + preset=PresetModel( # TODO: remove (placeholder to avoid breaking frontend) + name="dummy_preset", + id=agent_state.id, + user_id=agent_state.user_id, + description="", + created_at=agent_state.created_at, + system=agent_state.system, + persona="", + human="", + functions_schema=[], ), ) except Exception as e: diff --git a/memgpt/server/rest_api/humans/index.py b/memgpt/server/rest_api/humans/index.py index a7f012a9..2b1f7c0c 100644 --- a/memgpt/server/rest_api/humans/index.py +++ b/memgpt/server/rest_api/humans/index.py @@ -2,7 +2,7 @@ import uuid from functools import partial from typing import List -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel, Field from memgpt.models.pydantic_models import HumanModel @@ -46,4 +46,24 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p server.ms.add_human(new_human) return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id) + @router.delete("/humans/{human_name}", tags=["humans"], response_model=HumanModel) + async def delete_human( + human_name: str, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + human = server.ms.delete_human(human_name, user_id=user_id) + return human + + @router.get("/humans/{human_name}", tags=["humans"], response_model=HumanModel) + async def get_human( + human_name: str, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + human = server.ms.get_human(human_name, user_id=user_id) + if human is None: + raise HTTPException(status_code=404, detail="Human not found") + return human + return router diff --git a/memgpt/server/rest_api/personas/index.py b/memgpt/server/rest_api/personas/index.py index b8f2503c..14c82fdc 100644 --- a/memgpt/server/rest_api/personas/index.py +++ b/memgpt/server/rest_api/personas/index.py @@ -2,7 +2,7 @@ import uuid from functools import partial from typing import List -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, HTTPException from pydantic import BaseModel, Field from memgpt.models.pydantic_models import PersonaModel @@ -47,4 +47,24 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, server.ms.add_persona(new_persona) return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id) + @router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel) + async def delete_persona( + persona_name: str, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + persona = server.ms.delete_persona(persona_name, user_id=user_id) + return persona + + @router.get("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel) + async def get_persona( + persona_name: str, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + persona = server.ms.get_persona(persona_name, user_id=user_id) + if persona is None: + raise HTTPException(status_code=404, detail="Persona not found") + return persona + return router diff --git a/memgpt/server/rest_api/tools/index.py b/memgpt/server/rest_api/tools/index.py index b564dc72..05525eab 100644 --- a/memgpt/server/rest_api/tools/index.py +++ b/memgpt/server/rest_api/tools/index.py @@ -22,6 +22,7 @@ class CreateToolRequest(BaseModel): source_code: str = Field(..., description="The source code of the function.") source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.") tags: Optional[List[str]] = Field(None, description="Metadata tags.") + update: Optional[bool] = Field(False, description="Update the tool if it already exists.") class CreateToolResponse(BaseModel): @@ -87,8 +88,10 @@ def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterfac source_type=request.source_type, tags=request.tags, user_id=user_id, + exists_ok=request.update, ) except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}") + print(e) + raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}, exists_ok={request.update}") return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 49a6c1b2..80be4230 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -15,8 +15,6 @@ import memgpt.server.utils as server_utils import memgpt.system as system from memgpt.agent import Agent, save_agent from memgpt.agent_store.storage import StorageConnector, TableType - -# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin from memgpt.cli.cli_config import get_model_options from memgpt.config import MemGPTConfig from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT @@ -37,6 +35,7 @@ from memgpt.data_types import ( from memgpt.interface import AgentInterface # abstract from memgpt.interface import CLIInterface # for printing to terminal from memgpt.log import get_logger +from memgpt.memory import BaseMemory from memgpt.metadata import MetadataStore from memgpt.models.chat_completion_response import UsageStatistics from memgpt.models.pydantic_models import ( @@ -44,10 +43,14 @@ from memgpt.models.pydantic_models import ( HumanModel, MemGPTUsageStatistics, PassageModel, + PersonaModel, PresetModel, SourceModel, ToolModel, ) + +# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin +from memgpt.prompts import gpt_system from memgpt.utils import create_random_username logger = get_logger(__name__) @@ -212,7 +215,7 @@ class SyncServer(LockingServer): # Initialize the connection to the DB self.config = MemGPTConfig.load() - print(f"server :: loading configuration from '{self.config.config_path}'") + logger.info(f"loading configuration from '{self.config.config_path}'") assert self.config.persona is not None, "Persona must be set in the config" assert self.config.human is not None, "Human must be set in the config" @@ -274,14 +277,9 @@ class SyncServer(LockingServer): self.ms.update_user(user) else: self.ms.create_user(user) - presets.add_default_presets(user_id, self.ms) - # NOTE: removed, since server should be multi-user - ## Create the default user - # base_user_id = uuid.UUID(self.config.anon_clientid) - # if not self.ms.get_user(user_id=base_user_id): - # base_user = User(id=base_user_id) - # self.ms.create_user(base_user) + # add global default tools + presets.add_default_tools(None, self.ms) def save_agents(self): """Saves all the agents that are in the in-memory object store""" @@ -336,7 +334,14 @@ class SyncServer(LockingServer): # Instantiate an agent object using the state retrieved logger.info(f"Creating an agent object") - tool_objs = [self.ms.get_tool(name, user_id) for name in agent_state.tools] # get tool objects + tool_objs = [] + for name in agent_state.tools: + tool_obj = self.ms.get_tool(name, user_id) + if not tool_obj: + logger.exception(f"Tool {name} does not exist for user {user_id}") + raise ValueError(f"Tool {name} does not exist for user {user_id}") + tool_objs.append(tool_obj) + memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) # Add the agent to the in-memory store and return its reference @@ -350,11 +355,11 @@ class SyncServer(LockingServer): def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent: """Check if the agent is in-memory, then load""" - logger.info(f"Checking for agent user_id={user_id} agent_id={agent_id}") + logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}") # TODO: consider disabling loading cached agents due to potential concurrency issues memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id) if not memgpt_agent: - logger.info(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") + logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id) return memgpt_agent @@ -426,7 +431,6 @@ class SyncServer(LockingServer): # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) - # print("AGENT", memgpt_agent.agent_state.id, memgpt_agent.agent_state.user_id) if command.lower() == "exit": # exit not supported on server.py @@ -657,49 +661,58 @@ class SyncServer(LockingServer): def create_user( self, user_config: Optional[Union[dict, User]] = {}, + exists_ok: bool = False, ): """Create a new user using a config""" if not isinstance(user_config, dict): raise ValueError(f"user_config must be provided as a dictionary") + if "id" in user_config: + existing_user = self.ms.get_user(user_id=user_config["id"]) + if existing_user: + if exists_ok: + presets.add_default_humans_and_personas(existing_user.id, self.ms) + return existing_user + else: + raise ValueError(f"User with ID {existing_user.id} already exists") + user = User( id=user_config["id"] if "id" in user_config else None, ) self.ms.create_user(user) logger.info(f"Created new user from config: {user}") + + # add default for the user + presets.add_default_humans_and_personas(user.id, self.ms) + return user def create_agent( self, user_id: uuid.UUID, tools: List[str], # list of tool names (handles) to include - # system: str, # system prompt + memory: BaseMemory, + system: Optional[str] = None, metadata: Optional[dict] = {}, # includes human/persona names name: Optional[str] = None, - preset: Optional[str] = None, # TODO: remove eventually # model config llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, # interface interface: Union[AgentInterface, None] = None, - # TODO: refactor this to be a more general memory configuration - system: Optional[str] = None, # prompt value - persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value - human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value - persona_name: Optional[str] = None, # TODO: remove - human_name: Optional[str] = None, # TODO: remove ) -> AgentState: """Create a new agent using a config""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if interface is None: - # interface = self.default_interface interface = self.default_interface_factory() - # if persistence_manager is None: - # persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config) + # system prompt (get default if None) + if system is None: + system = gpt_system.get_system_text(self.config.preset) + # create agent name if name is None: name = create_random_username() @@ -708,90 +721,34 @@ class SyncServer(LockingServer): if not user: raise ValueError(f"cannot find user with associated client id: {user_id}") - # NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation - # TODO: fix this db dependency and remove - # self.ms.#create_agent(agent_state) - - # TODO modify to do creation via preset try: - preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id) - preset_override = False - assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist" - logger.debug(f"Attempting to create agent from preset:\n{preset_obj}") - - # system prompt - if system is None: - system = preset_obj.system - else: - preset_obj.system = system - preset_override = True - - # Overwrite fields in the preset if they were specified - if human is not None and human != preset_obj.human: - preset_override = True - preset_obj.human = human - # This is a check for a common bug where users were providing filenames instead of values - # try: - # get_human_text(human) - # raise ValueError(human) - # raise UserWarning( - # f"It looks like there is a human file named {human} - did you mean to pass the file contents to the `human` arg?" - # ) - # except: - # pass - if persona is not None: - preset_override = True - preset_obj.persona = persona - # try: - # get_persona_text(persona) - # raise ValueError(persona) - # raise UserWarning( - # f"It looks like there is a persona file named {persona} - did you mean to pass the file contents to the `persona` arg?" - # ) - # except: - # pass - if human_name is not None and human_name != preset_obj.human_name: - preset_override = True - preset_obj.human_name = human_name - if persona_name is not None and persona_name != preset_obj.persona_name: - preset_override = True - preset_obj.persona_name = persona_name - + # model configuration llm_config = llm_config if llm_config else self.server_llm_config embedding_config = embedding_config if embedding_config else self.server_embedding_config - # get tools + # get tools + make sure they exist tool_objs = [] for tool_name in tools: tool_obj = self.ms.get_tool(tool_name, user_id=user_id) - assert tool_obj is not None, f"Tool {tool_name} does not exist" + assert tool_obj, f"Tool {tool_name} does not exist" tool_objs.append(tool_obj) - # If the user overrode any parts of the preset, we need to create a new preset to refer back to - if preset_override: - # Change the name and uuid - preset_obj = Preset.clone(preset_obj=preset_obj) - # Then write out to the database for storage - self.ms.create_preset(preset=preset_obj) - + # TODO: add metadata agent_state = AgentState( name=name, user_id=user_id, - persona=preset_obj.persona_name, # TODO: remove - human=preset_obj.human_name, # TODO: remove tools=tools, # name=id for tools llm_config=llm_config, embedding_config=embedding_config, system=system, - preset=preset, # TODO: remove - state={"persona": preset_obj.persona, "human": preset_obj.human, "system": system, "messages": None}, + state={"system": system, "messages": None, "memory": memory.to_dict()}, + _metadata=metadata, ) agent = Agent( interface=interface, agent_state=agent_state, tools=tool_objs, - # embedding_config=embedding_config, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False, ) @@ -805,10 +762,11 @@ class SyncServer(LockingServer): logger.exception(f"Failed to delete_agent:\n{delete_e}") raise e + # save agent save_agent(agent, self.ms) - logger.info(f"Created new agent from config: {agent}") + # return AgentState return agent.agent_state def delete_agent( @@ -872,8 +830,8 @@ class SyncServer(LockingServer): agent_config = { "id": agent_state.id, "name": agent_state.name, - "human": agent_state.human, - "persona": agent_state.persona, + "human": agent_state._metadata.get("human", None), + "persona": agent_state._metadata.get("persona", None), "created_at": agent_state.created_at.isoformat(), } return agent_config @@ -905,8 +863,8 @@ class SyncServer(LockingServer): # TODO hack for frontend, remove # (top level .persona is persona_name, and nested memory.persona is the state) # TODO: eventually modify this to be contained in the metadata - return_dict["persona"] = agent_state.human - return_dict["human"] = agent_state.persona + return_dict["persona"] = agent_state._metadata.get("persona", None) + return_dict["human"] = agent_state._metadata.get("human", None) # Add information about tools # TODO memgpt_agent should really have a field of List[ToolModel] @@ -918,10 +876,7 @@ class SyncServer(LockingServer): recall_memory = memgpt_agent.persistence_manager.recall_memory archival_memory = memgpt_agent.persistence_manager.archival_memory memory_obj = { - "core_memory": { - "persona": core_memory.persona, - "human": core_memory.human, - }, + "core_memory": {section: module.value for (section, module) in core_memory.memory.items()}, "recall_memory": len(recall_memory) if recall_memory is not None else None, "archival_memory": len(archival_memory) if archival_memory is not None else None, } @@ -941,12 +896,31 @@ class SyncServer(LockingServer): # Sort agents by "last_run" in descending order, most recent first agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True) - logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}") + logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}") return { "num_agents": len(agents_states), "agents": agents_states_dicts, } + def list_personas(self, user_id: uuid.UUID): + return self.ms.list_personas(user_id=user_id) + + def get_persona(self, name: str, user_id: uuid.UUID): + return self.ms.get_persona(name=name, user_id=user_id) + + def add_persona(self, persona: PersonaModel): + name = persona.name + user_id = persona.user_id + self.ms.add_persona(persona=persona) + persona = self.ms.get_persona(name=name, user_id=user_id) + return persona + + def update_persona(self, persona: PersonaModel): + return self.ms.update_persona(persona=persona) + + def delete_persona(self, name: str, user_id: uuid.UUID): + return self.ms.delete_persona(name=name, user_id=user_id) + def list_humans(self, user_id: uuid.UUID): return self.ms.list_humans(user_id=user_id) @@ -954,7 +928,11 @@ class SyncServer(LockingServer): return self.ms.get_human(name=name, user_id=user_id) def add_human(self, human: HumanModel): - return self.ms.add_human(human=human) + name = human.name + user_id = human.user_id + self.ms.add_human(human=human) + human = self.ms.get_human(name=name, user_id=user_id) + return human def update_human(self, human: HumanModel): return self.ms.update_human(human=human) @@ -983,11 +961,9 @@ class SyncServer(LockingServer): recall_memory = memgpt_agent.persistence_manager.recall_memory archival_memory = memgpt_agent.persistence_manager.archival_memory + # NOTE memory_obj = { - "core_memory": { - "persona": core_memory.persona, - "human": core_memory.human, - }, + "core_memory": {key: value.value for key, value in core_memory.memory.items()}, "recall_memory": len(recall_memory) if recall_memory is not None else None, "archival_memory": len(archival_memory) if archival_memory is not None else None, } @@ -1245,23 +1221,18 @@ class SyncServer(LockingServer): new_core_memory = old_core_memory.copy() modified = False - if "persona" in new_memory_contents and new_memory_contents["persona"] is not None: - new_persona = new_memory_contents["persona"] - if old_core_memory["persona"] != new_persona: - new_core_memory["persona"] = new_persona - memgpt_agent.memory.edit_persona(new_persona) - modified = True - - if "human" in new_memory_contents and new_memory_contents["human"] is not None: - new_human = new_memory_contents["human"] - if old_core_memory["human"] != new_human: - new_core_memory["human"] = new_human - memgpt_agent.memory.edit_human(new_human) + for key, value in new_memory_contents.items(): + if value is None: + continue + if key in old_core_memory and old_core_memory[key] != value: + memgpt_agent.memory.memory[key].value = value # update agent memory modified = True # If we modified the memory contents, we need to rebuild the memory block inside the system message if modified: memgpt_agent.rebuild_memory() + # save agent + save_agent(memgpt_agent, self.ms) return { "old_core_memory": old_core_memory, @@ -1290,6 +1261,7 @@ class SyncServer(LockingServer): logger.exception(f"Failed to update agent name with:\n{str(e)}") raise ValueError(f"Failed to update agent name in database") + assert isinstance(memgpt_agent.agent_state.id, uuid.UUID) return memgpt_agent.agent_state def delete_user(self, user_id: uuid.UUID): @@ -1504,7 +1476,7 @@ class SyncServer(LockingServer): tool (ToolModel): Tool object """ name = json_schema["name"] - tool = self.ms.get_tool(name) + tool = self.ms.get_tool(name, user_id=user_id) if tool: # check if function already exists if exists_ok: # update existing tool @@ -1514,7 +1486,7 @@ class SyncServer(LockingServer): tool.source_type = source_type self.ms.update_tool(tool) else: - raise ValueError(f"Tool with name {name} already exists.") + raise ValueError(f"[server] Tool with name {name} already exists.") else: # create new tool tool = ToolModel( diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py deleted file mode 100644 index b9ab60eb..00000000 --- a/tests/test_agent_function_update.py +++ /dev/null @@ -1,121 +0,0 @@ -import inspect -import json -import os -import uuid - -import pytest - -from memgpt import constants, create_client -from memgpt.functions.functions import USER_FUNCTIONS_DIR -from memgpt.models import chat_completion_response -from memgpt.utils import assistant_function_to_tool -from tests import TEST_MEMGPT_CONFIG -from tests.utils import create_config, wipe_config - - -def hello_world(self) -> str: - """Test function for agent to gain access to - - Returns: - str: A message for the world - """ - return "hello, world!" - - -@pytest.fixture(scope="module") -def agent(): - """Create a test agent that we can call functions on""" - wipe_config() - global client - if os.getenv("OPENAI_API_KEY"): - create_config("openai") - else: - create_config("memgpt_hosted") - - # create memgpt client - client = create_client() - - # ensure user exists - user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) - if not client.server.get_user(user_id=user_id): - client.server.create_user({"id": user_id}) - - agent_state = client.create_agent() - - return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) - - -@pytest.fixture(scope="module") -def hello_world_function(): - with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f: - f.write(inspect.getsource(hello_world)) - - -@pytest.fixture(scope="module") -def ai_function_call(): - return chat_completion_response.Message( - **assistant_function_to_tool( - { - "role": "assistant", - "content": "I will now call hello world", - "function_call": { - "name": "hello_world", - "arguments": json.dumps({}), - }, - } - ) - ) - - return - - -def test_add_function_happy(agent, hello_world_function, ai_function_call): - agent.add_function("hello_world") - - assert "hello_world" in [f_schema["name"] for f_schema in agent.functions] - assert "hello_world" in agent.functions_python.keys() - - msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call) - content = json.loads(msgs[-1].to_openai_dict()["content"], strict=constants.JSON_LOADS_STRICT) - assert content["message"] == "hello, world!" - assert content["status"] == "OK" - assert not function_failed - - -def test_add_function_already_loaded(agent, hello_world_function): - agent.add_function("hello_world") - # no exception for duplicate loading - agent.add_function("hello_world") - - -def test_add_function_not_exist(agent): - # pytest assert exception - with pytest.raises(ValueError): - agent.add_function("non_existent") - - -def test_remove_function_happy(agent, hello_world_function): - agent.add_function("hello_world") - - # ensure function is loaded - assert "hello_world" in [f_schema["name"] for f_schema in agent.functions] - assert "hello_world" in agent.functions_python.keys() - - agent.remove_function("hello_world") - - assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions] - assert "hello_world" not in agent.functions_python.keys() - - -def test_remove_function_not_exist(agent): - # do not raise error - agent.remove_function("non_existent") - - -def test_remove_base_function_fails(agent): - with pytest.raises(ValueError): - agent.remove_function("send_message") - - -if __name__ == "__main__": - pytest.main(["-vv", os.path.abspath(__file__)]) diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index e5c49f6b..a1d8143b 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,11 +1,9 @@ import os -import uuid import pytest import memgpt.functions.function_sets.base as base_functions from memgpt import create_client -from tests import TEST_MEMGPT_CONFIG from .utils import create_config, wipe_config @@ -28,8 +26,7 @@ def agent_obj(): agent_state = client.create_agent() global agent_obj - user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) - agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) + agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id) yield agent_obj client.delete_agent(agent_obj.agent_state.id) diff --git a/tests/test_client.py b/tests/test_client.py index 68b4c3b3..9e855e62 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -68,10 +68,7 @@ def run_server(): # Fixture to create clients with different configurations @pytest.fixture( - params=[ # whether to use REST API server - {"server": True}, - # {"server": False} # TODO: add when implemented - ], + params=[{"server": True}, {"server": False}], # whether to use REST API server scope="module", ) def client(request): @@ -90,8 +87,8 @@ def client(request): # create user via admin client admin = Admin(server_url, test_server_token) response = admin.create_user(test_user_id) # Adjust as per your client's method - response.user_id token = response.api_key + else: # use local client (no server) token = None @@ -109,7 +106,7 @@ def client(request): # Fixture for test agent @pytest.fixture(scope="module") def agent(client): - agent_state = client.create_agent(name=test_agent_name, preset=test_preset_name) + agent_state = client.create_agent(name=test_agent_name) print("AGENT ID", agent_state.id) yield agent_state @@ -123,11 +120,11 @@ def test_agent(client, agent): # test client.rename_agent new_name = "RenamedTestAgent" client.rename_agent(agent_id=agent.id, new_name=new_name) - renamed_agent = client.get_agent(agent_id=str(agent.id)) + renamed_agent = client.get_agent(agent_id=agent.id) assert renamed_agent.name == new_name, "Agent renaming failed" # test client.delete_agent and client.agent_exists - delete_agent = client.create_agent(name="DeleteTestAgent", preset=test_preset_name) + delete_agent = client.create_agent(name="DeleteTestAgent") assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed" client.delete_agent(agent_id=delete_agent.id) assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" @@ -140,7 +137,7 @@ def test_memory(client, agent): print("MEMORY", memory_response) updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"} - client.update_agent_core_memory(agent_id=str(agent.id), new_memory_contents=updated_memory) + client.update_agent_core_memory(agent_id=agent.id, new_memory_contents=updated_memory) updated_memory_response = client.get_agent_memory(agent_id=agent.id) assert ( updated_memory_response.core_memory.human == updated_memory["human"] @@ -152,10 +149,10 @@ def test_agent_interactions(client, agent): _reset_config() message = "Hello, agent!" - message_response = client.user_message(agent_id=str(agent.id), message=message) + message_response = client.user_message(agent_id=agent.id, message=message) command = "/memory" - command_response = client.run_command(agent_id=str(agent.id), command=command) + command_response = client.run_command(agent_id=agent.id, command=command) print("command", command_response) @@ -197,11 +194,15 @@ def test_humans_personas(client, agent): print("PERSONAS", personas_response) persona_name = "TestPersona" + if client.get_persona(persona_name): + client.delete_persona(persona_name) persona = client.create_persona(name=persona_name, persona="Persona text") assert persona.name == persona_name assert persona.text == "Persona text", "Creating persona failed" human_name = "TestHuman" + if client.get_human(human_name): + client.delete_human(human_name) human = client.create_human(name=human_name, human="Human text") assert human.name == human_name assert human.text == "Human text", "Creating human failed" @@ -285,55 +286,55 @@ def test_sources(client, agent): client.delete_source(source.id) -def test_presets(client, agent): - _reset_config() - - # new_preset = Preset( - # # user_id=client.user_id, - # name="pytest_test_preset", - # description="DUMMY_DESCRIPTION", - # system="DUMMY_SYSTEM", - # persona="DUMMY_PERSONA", - # persona_name="DUMMY_PERSONA_NAME", - # human="DUMMY_HUMAN", - # human_name="DUMMY_HUMAN_NAME", - # functions_schema=[ - # { - # "name": "send_message", - # "json_schema": { - # "name": "send_message", - # "description": "Sends a message to the human user.", - # "parameters": { - # "type": "object", - # "properties": { - # "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."} - # }, - # "required": ["message"], - # }, - # }, - # "tags": ["memgpt-base"], - # "source_type": "python", - # "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n', - # } - # ], - # ) - - ## List all presets and make sure the preset is NOT in the list - # all_presets = client.list_presets() - # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) - # Create a preset - new_preset = client.create_preset(name="pytest_test_preset") - - # List all presets and make sure the preset is in the list - all_presets = client.list_presets() - assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets) - - # Delete the preset - client.delete_preset(preset_id=new_preset.id) - - # List all presets and make sure the preset is NOT in the list - all_presets = client.list_presets() - assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) +# def test_presets(client, agent): +# _reset_config() +# +# # new_preset = Preset( +# # # user_id=client.user_id, +# # name="pytest_test_preset", +# # description="DUMMY_DESCRIPTION", +# # system="DUMMY_SYSTEM", +# # persona="DUMMY_PERSONA", +# # persona_name="DUMMY_PERSONA_NAME", +# # human="DUMMY_HUMAN", +# # human_name="DUMMY_HUMAN_NAME", +# # functions_schema=[ +# # { +# # "name": "send_message", +# # "json_schema": { +# # "name": "send_message", +# # "description": "Sends a message to the human user.", +# # "parameters": { +# # "type": "object", +# # "properties": { +# # "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."} +# # }, +# # "required": ["message"], +# # }, +# # }, +# # "tags": ["memgpt-base"], +# # "source_type": "python", +# # "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n', +# # } +# # ], +# # ) +# +# ## List all presets and make sure the preset is NOT in the list +# # all_presets = client.list_presets() +# # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) +# # Create a preset +# new_preset = client.create_preset(name="pytest_test_preset") +# +# # List all presets and make sure the preset is in the list +# all_presets = client.list_presets() +# assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets) +# +# # Delete the preset +# client.delete_preset(preset_id=new_preset.id) +# +# # List all presets and make sure the preset is NOT in the list +# all_presets = client.list_presets() +# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) # def test_tools(client, agent): diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index d60ed69a..f39e059c 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -29,14 +29,11 @@ def run_llm_endpoint(filename): agent_state = AgentState( name="test_agent", tools=[tool.name for tool in load_module_tools()], - system="", - persona="", - human="", - preset="memgpt_chat", embedding_config=embedding_config, llm_config=llm_config, user_id=uuid.UUID(int=1), - state={"persona": "", "human": "", "messages": None}, + state={"persona": "", "human": "", "messages": None, "memory": {}}, + system="", ) agent = Agent( interface=None, diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 00000000..6a6065cd --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,65 @@ +import pytest + +# Import the classes here, assuming the above definitions are in a module named memory_module +from memgpt.memory import BaseMemory, ChatMemory + + +@pytest.fixture +def sample_memory(): + return ChatMemory(persona="Chat Agent", human="User") + + +def test_create_chat_memory(): + """Test creating an instance of ChatMemory""" + chat_memory = ChatMemory(persona="Chat Agent", human="User") + assert chat_memory.memory["persona"].value == "Chat Agent" + assert chat_memory.memory["human"].value == "User" + + +def test_dump_memory_as_json(sample_memory): + """Test dumping ChatMemory as JSON compatible dictionary""" + memory_dict = sample_memory.to_dict() + assert isinstance(memory_dict, dict) + assert "persona" in memory_dict + assert memory_dict["persona"]["value"] == "Chat Agent" + + +def test_load_memory_from_json(sample_memory): + """Test loading ChatMemory from a JSON compatible dictionary""" + memory_dict = sample_memory.to_dict() + print(memory_dict) + new_memory = BaseMemory.load(memory_dict) + assert new_memory.memory["persona"].value == "Chat Agent" + assert new_memory.memory["human"].value == "User" + + +# def test_memory_functionality(sample_memory): +# """Test memory modification functions""" +# # Get memory functions +# functions = get_memory_functions(ChatMemory) +# # Test core_memory_append function +# append_func = functions['core_memory_append'] +# print("FUNCTIONS", functions) +# env = {} +# env.update(globals()) +# for tool in functions: +# # WARNING: name may not be consistent? +# exec(tool.source_code, env) +# +# print(exec) +# +# append_func(sample_memory, 'persona', " is a test.") +# assert sample_memory.memory['persona'].value == "Chat Agent\n is a test." +# # Test core_memory_replace function +# replace_func = functions['core_memory_replace'] +# replace_func(sample_memory, 'persona', " is a test.", " was a test.") +# assert sample_memory.memory['persona'].value == "Chat Agent\n was a test." + + +def test_memory_limit_validation(sample_memory): + """Test exceeding memory limit""" + with pytest.raises(ValueError): + ChatMemory(persona="x" * 3000, human="y" * 3000) + + with pytest.raises(ValueError): + sample_memory.memory["persona"].value = "x" * 3000 diff --git a/tests/test_metadata_store.py b/tests/test_metadata_store.py deleted file mode 100644 index 3ccf6c91..00000000 --- a/tests/test_metadata_store.py +++ /dev/null @@ -1,139 +0,0 @@ -import pytest - -from memgpt.agent import Agent, save_agent -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET -from memgpt.data_types import AgentState, LLMConfig, Source, User -from memgpt.metadata import MetadataStore -from memgpt.models.pydantic_models import HumanModel, PersonaModel -from memgpt.presets.presets import add_default_presets -from memgpt.settings import settings -from memgpt.utils import get_human_text, get_persona_text -from tests import TEST_MEMGPT_CONFIG - - -# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"]) -@pytest.mark.parametrize("storage_connector", ["sqlite"]) -def test_storage(storage_connector): - if storage_connector == "postgres": - TEST_MEMGPT_CONFIG.archival_storage_uri = settings.pg_uri - TEST_MEMGPT_CONFIG.recall_storage_uri = settings.pg_uri - TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" - TEST_MEMGPT_CONFIG.recall_storage_type = "postgres" - if storage_connector == "sqlite": - TEST_MEMGPT_CONFIG.recall_storage_type = "local" - - ms = MetadataStore(TEST_MEMGPT_CONFIG) - - # users - user_1 = User() - user_2 = User() - ms.create_user(user_1) - ms.create_user(user_2) - - # test adding default humans/personas/presets - # add_default_humans_and_personas(user_id=user_1.id, ms=ms) - # add_default_humans_and_personas(user_id=user_2.id, ms=ms) - ms.add_human(human=HumanModel(name="test_human", text="This is a test human")) - ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona")) - add_default_presets(user_id=user_1.id, ms=ms) - add_default_presets(user_id=user_2.id, ms=ms) - assert len(ms.list_humans(user_id=user_1.id)) > 0, ms.list_humans(user_id=user_1.id) - assert len(ms.list_personas(user_id=user_1.id)) > 0, ms.list_personas(user_id=user_1.id) - - # generate data - agent_1 = AgentState( - user_id=user_1.id, - name="agent_1", - preset=DEFAULT_PRESET, - persona=DEFAULT_PERSONA, - human=DEFAULT_HUMAN, - llm_config=TEST_MEMGPT_CONFIG.default_llm_config, - embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, - ) - source_1 = Source(user_id=user_1.id, name="source_1") - - # test creation - ms.create_agent(agent_1) - ms.create_source(source_1) - - # test listing - len(ms.list_agents(user_id=user_1.id)) == 1 - len(ms.list_agents(user_id=user_2.id)) == 0 - len(ms.list_sources(user_id=user_1.id)) == 1 - len(ms.list_sources(user_id=user_2.id)) == 0 - - # test agent_state saving - agent_state = ms.get_agent(agent_1.id).state - assert agent_state == {}, agent_state # when created via create_agent, it should be empty - - from memgpt.presets.presets import add_default_presets - - add_default_presets(user_1.id, ms) - preset_obj = ms.get_preset(name=DEFAULT_PRESET, user_id=user_1.id) - from memgpt.interface import CLIInterface as interface # for printing to terminal - - # Overwrite fields in the preset if they were specified - preset_obj.human = get_human_text(DEFAULT_HUMAN) - preset_obj.persona = get_persona_text(DEFAULT_PERSONA) - - # Create the agent - agent = Agent( - interface=interface(), - created_by=user_1.id, - name="agent_test_agent_state", - preset=preset_obj, - llm_config=config.default_llm_config, - embedding_config=config.default_embedding_config, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=( - True if (config.default_llm_config.model is not None and "gpt-4" in config.default_llm_config.model) else False - ), - ) - agent_with_agent_state = agent.agent_state - save_agent(agent=agent, ms=ms) - - agent_state = ms.get_agent(agent_with_agent_state.id).state - assert agent_state is not None, agent_state # when created via create_agent_from_preset, it should be non-empty - - # test: updating - - # test: update JSON-stored LLMConfig class - print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config) - llm_config = ms.get_agent(agent_1.id).llm_config - assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}" - assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}" - llm_config.model = "gpt3.5-turbo" - agent_1.llm_config = llm_config - ms.update_agent(agent_1) - assert ms.get_agent(agent_1.id).llm_config.model == "gpt3.5-turbo", f"Updated LLMConfig to {ms.get_agent(agent_1.id).llm_config.model}" - - # test attaching sources - len(ms.list_attached_sources(agent_id=agent_1.id)) == 0 - ms.attach_source(user_1.id, agent_1.id, source_1.id) - len(ms.list_attached_sources(agent_id=agent_1.id)) == 1 - - # test: detaching sources - ms.detach_source(agent_1.id, source_1.id) - len(ms.list_attached_sources(agent_id=agent_1.id)) == 0 - - # test getting - ms.get_user(user_1.id) - ms.get_agent(agent_1.id) - ms.get_source(source_1.id) - - # test api keys - api_key = ms.create_api_key(user_id=user_1.id) - print("api_key=", api_key.token, api_key.user_id) - api_key_result = ms.get_api_key(api_key=api_key.token) - assert api_key.token == api_key_result.token, (api_key, api_key_result) - user_result = ms.get_user_from_api_key(api_key=api_key.token) - assert user_1.id == user_result.id, (user_1, user_result) - all_keys_for_user = ms.get_all_api_keys_for_user(user_id=user_1.id) - assert len(all_keys_for_user) > 0, all_keys_for_user - ms.delete_api_key(api_key=api_key.token) - - # test deletion - ms.delete_user(user_1.id) - ms.delete_user(user_2.id) - ms.delete_agent(agent_1.id) - ms.delete_source(source_1.id) diff --git a/tests/test_server.py b/tests/test_server.py index a49a57f4..9b79f451 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,11 +5,12 @@ import pytest from dotenv import load_dotenv import memgpt.utils as utils +from memgpt.constants import BASE_TOOLS utils.DEBUG = True from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials -from memgpt.presets.presets import load_module_tools +from memgpt.memory import ChatMemory from memgpt.server.server import SyncServer from memgpt.settings import settings @@ -59,8 +60,6 @@ def user_id(server): user = server.create_user() print(f"Created user\n{user.id}") - # initialize with default presets - server.initialize_default_presets(user.id) yield user.id # cleanup @@ -71,10 +70,7 @@ def user_id(server): def agent_id(server, user_id): # create agent agent_state = server.create_agent( - user_id=user_id, - name="test_agent", - preset="memgpt_chat", - tools=[tool.name for tool in load_module_tools()], + user_id=user_id, name="test_agent", tools=BASE_TOOLS, memory=ChatMemory(human="I am Chad", persona="I love testing") ) print(f"Created agent\n{agent_state}") yield agent_state.id @@ -170,8 +166,21 @@ def test_get_recall_memory(server, user_id, agent_id): cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1) assert len(messages_4) == 1 + print("MESSAGES") + for m in messages_3: + print(m["id"], m["role"]) + if m["role"] == "assistant": + print(m["text"]) + print("------------") + # test in-context message ids + all_messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1000) + print("num messages", len(all_messages)) in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id) + print(in_context_ids) + for m in messages_3: + if str(m["id"]) not in [str(i) for i in in_context_ids]: + print("missing", m["id"], m["role"]) assert len(in_context_ids) == len(messages_3) assert isinstance(in_context_ids[0], uuid.UUID) message_ids = [m["id"] for m in messages_3] diff --git a/tests/test_storage.py b/tests/test_storage.py index fe411ff6..d6feb1ea 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -7,13 +7,12 @@ from sqlalchemy.ext.declarative import declarative_base from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.config import MemGPTConfig -from memgpt.constants import MAX_EMBEDDING_DIM +from memgpt.constants import BASE_TOOLS, MAX_EMBEDDING_DIM from memgpt.credentials import MemGPTCredentials from memgpt.data_types import AgentState, Message, Passage, User from memgpt.embeddings import embedding_model, query_embedding from memgpt.metadata import MetadataStore from memgpt.settings import settings -from memgpt.utils import get_human_text, get_persona_text from tests import TEST_MEMGPT_CONFIG from tests.utils import create_config, wipe_config @@ -185,15 +184,10 @@ def test_storage( user_id=user_id, name="agent_1", id=agent_1_id, - preset=TEST_MEMGPT_CONFIG.preset, - # persona_name=TEST_MEMGPT_CONFIG.persona, - # human_name=TEST_MEMGPT_CONFIG.human, - persona=get_persona_text(TEST_MEMGPT_CONFIG.persona), - human=get_human_text(TEST_MEMGPT_CONFIG.human), llm_config=TEST_MEMGPT_CONFIG.default_llm_config, embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, system="", - tools=[], + tools=BASE_TOOLS, state={ "persona": "", "human": "", diff --git a/tests/test_tools.py b/tests/test_tools.py index 9d03ab88..b0ac1b62 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -11,6 +11,7 @@ from memgpt.agent import Agent from memgpt.config import MemGPTConfig from memgpt.constants import DEFAULT_PRESET from memgpt.credentials import MemGPTCredentials +from memgpt.memory import ChatMemory from memgpt.settings import settings from tests.utils import create_config @@ -69,7 +70,7 @@ def run_server(): # Fixture to create clients with different configurations @pytest.fixture( params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented - # params=[{"server": True}], # whether to use REST API server # TODO: add when implemented + # params=[{"server": False}], # whether to use REST API server # TODO: add when implemented scope="module", ) def admin_client(request): @@ -179,10 +180,9 @@ def test_create_agent_tool(client): str: The agent that was deleted. """ - self.memory.human = "" - self.memory.persona = "" - self.rebuild_memory() - print("UPDATED MEMORY", self.memory.human, self.memory.persona) + self.memory.memory["human"].value = "" + self.memory.memory["persona"].value = "" + print("UPDATED MEMORY", self.memory.memory) return None # TODO: test attaching and using function on agent @@ -190,7 +190,8 @@ def test_create_agent_tool(client): print(f"Created tool", tool.name) # create agent with tool - agent = client.create_agent(name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you") + memory = ChatMemory(human="I am a human", persona="You must clear your memory if the human instructs you") + agent = client.create_agent(name=test_agent_name, tools=[tool.name], memory=memory) assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}" # initial memory @@ -207,6 +208,7 @@ def test_create_agent_tool(client): print(response) # updated memory + print("Query agent memory") updated_memory = client.get_agent_memory(agent.id) human = updated_memory.core_memory.human persona = updated_memory.core_memory.persona