From c508acb9dc6142febdc5d6d0f6ef6a2e4427f9d8 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Mon, 11 Mar 2024 16:32:50 -0700 Subject: [PATCH] fix: modify metadata presets functions (#1132) --- memgpt/autogen/memgpt_agent.py | 4 ++-- memgpt/cli/cli.py | 10 ++++++++-- memgpt/metadata.py | 24 +++++++++++++++++++++--- memgpt/presets/presets.py | 6 +++--- memgpt/server/server.py | 4 ++-- 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index 91e3cd8a..e36483a4 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -335,14 +335,14 @@ def create_autogen_memgpt_agent( user = ms.get_user(user_id=user_id) try: - preset_obj = ms.get_preset(preset_name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id) + preset_obj = ms.get_preset(name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id) if preset_obj is None: # create preset records in metadata store from memgpt.presets.presets import add_default_presets add_default_presets(user.id, ms) # try again - preset_obj = ms.get_preset(preset_name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id) + preset_obj = ms.get_preset(name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id) if preset_obj is None: print("Couldn't find presets in database, please run `memgpt configure`") sys.exit(1) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 8d6c5794..57769572 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -646,17 +646,23 @@ def run( # create agent try: - preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id) + preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id) + human_obj = ms.get_human(human, user.id) + persona_obj = ms.get_persona(persona, user.id) if preset_obj is None: # create preset records in metadata store from memgpt.presets.presets import add_default_presets add_default_presets(user.id, ms) # try again - preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id) + preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id) if preset_obj is None: typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED) sys.exit(1) + if human_obj is None: + typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED) + if persona_obj is None: + typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED) # Overwrite fields in the preset if they were specified preset_obj.human = ms.get_human(human, user.id).text diff --git a/memgpt/metadata.py b/memgpt/metadata.py index e6d2fce4..50a49870 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -413,13 +413,13 @@ class MetadataStore: @enforce_types def get_preset( - self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None + self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None ) -> Optional[Preset]: with self.session_maker() as session: if preset_id: results = session.query(PresetModel).filter(PresetModel.id == preset_id).all() - elif preset_name and user_id: - results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all() + elif name and user_id: + results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all() else: raise ValueError("Must provide either preset_id or (preset_name and user_id)") if len(results) == 0: @@ -637,6 +637,12 @@ class MetadataStore: session.add(persona) session.commit() + @enforce_types + def add_preset(self, preset: PresetModel): + with self.session_maker() as session: + session.add(preset) + session.commit() + @enforce_types def get_human(self, name: str, user_id: uuid.UUID) -> str: with self.session_maker() as session: @@ -668,6 +674,12 @@ class MetadataStore: results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all() return results + @enforce_types + def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]: + with self.session_maker() as session: + results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all() + return results + @enforce_types def delete_human(self, name: str, user_id: uuid.UUID): with self.session_maker() as session: @@ -679,3 +691,9 @@ class MetadataStore: with self.session_maker() as session: session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete() session.commit() + + @enforce_types + def delete_preset(self, name: str, user_id: uuid.UUID): + with self.session_maker() as session: + session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete() + session.commit() diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index 810f415f..5505900a 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -42,9 +42,9 @@ def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: Me preset_function_set_names = preset_config["functions"] functions_schema = generate_functions_json(preset_function_set_names) - if ms.get_preset(user_id=user_id, preset_name=name) is not None: + if ms.get_preset(user_id=user_id, name=name) is not None: printd(f"Preset '{name}' already exists for user '{user_id}'") - return ms.get_preset(user_id=user_id, preset_name=name) + return ms.get_preset(user_id=user_id, name=name) preset = Preset( user_id=user_id, @@ -70,7 +70,7 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore): preset_function_set_names = preset_config["functions"] functions_schema = generate_functions_json(preset_function_set_names) - if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None: + if ms.get_preset(user_id=user_id, name=preset_name) is not None: printd(f"Preset '{preset_name}' already exists for user '{user_id}'") continue diff --git a/memgpt/server/server.py b/memgpt/server/server.py index c8b1f514..9911f405 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -599,7 +599,7 @@ class SyncServer(LockingServer): # self.ms.create_agent(agent_state) try: - preset_obj = self.ms.get_preset(preset_name=preset if preset else self.config.preset, user_id=user_id) + preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id) assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist" logger.debug(f"Attempting to create agent from preset:\n{preset_obj}") @@ -669,7 +669,7 @@ class SyncServer(LockingServer): self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None ) -> Preset: """Get the preset""" - return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id) + return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id) def _agent_state_to_config(self, agent_state: AgentState) -> dict: """Convert AgentState to a dict for a JSON response"""