fix: modify metadata presets functions (#1132)

This commit is contained in:
Sarah Wooders
2024-03-11 16:32:50 -07:00
committed by GitHub
parent d17719f19b
commit c508acb9dc
5 changed files with 36 additions and 12 deletions

View File

@@ -335,14 +335,14 @@ def create_autogen_memgpt_agent(
user = ms.get_user(user_id=user_id)
try:
preset_obj = ms.get_preset(preset_name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
preset_obj = ms.get_preset(name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
if preset_obj is None:
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(user.id, ms)
# try again
preset_obj = ms.get_preset(preset_name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
preset_obj = ms.get_preset(name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
if preset_obj is None:
print("Couldn't find presets in database, please run `memgpt configure`")
sys.exit(1)

View File

@@ -646,17 +646,23 @@ def run(
# create agent
try:
preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id)
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
human_obj = ms.get_human(human, user.id)
persona_obj = ms.get_persona(persona, user.id)
if preset_obj is None:
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(user.id, ms)
# try again
preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id)
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
if preset_obj is None:
typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED)
sys.exit(1)
if human_obj is None:
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
if persona_obj is None:
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
# Overwrite fields in the preset if they were specified
preset_obj.human = ms.get_human(human, user.id).text

View File

@@ -413,13 +413,13 @@ class MetadataStore:
@enforce_types
def get_preset(
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Optional[Preset]:
with self.session_maker() as session:
if preset_id:
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
elif preset_name and user_id:
results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all()
elif name and user_id:
results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all()
else:
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
if len(results) == 0:
@@ -637,6 +637,12 @@ class MetadataStore:
session.add(persona)
session.commit()
@enforce_types
def add_preset(self, preset: PresetModel):
with self.session_maker() as session:
session.add(preset)
session.commit()
@enforce_types
def get_human(self, name: str, user_id: uuid.UUID) -> str:
with self.session_maker() as session:
@@ -668,6 +674,12 @@ class MetadataStore:
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
return results
@enforce_types
def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
with self.session_maker() as session:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return results
@enforce_types
def delete_human(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
@@ -679,3 +691,9 @@ class MetadataStore:
with self.session_maker() as session:
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
session.commit()
@enforce_types
def delete_preset(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
session.commit()

View File

@@ -42,9 +42,9 @@ def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: Me
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
if ms.get_preset(user_id=user_id, preset_name=name) is not None:
if ms.get_preset(user_id=user_id, name=name) is not None:
printd(f"Preset '{name}' already exists for user '{user_id}'")
return ms.get_preset(user_id=user_id, preset_name=name)
return ms.get_preset(user_id=user_id, name=name)
preset = Preset(
user_id=user_id,
@@ -70,7 +70,7 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
if ms.get_preset(user_id=user_id, name=preset_name) is not None:
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
continue

View File

@@ -599,7 +599,7 @@ class SyncServer(LockingServer):
# self.ms.create_agent(agent_state)
try:
preset_obj = self.ms.get_preset(preset_name=preset if preset else self.config.preset, user_id=user_id)
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
@@ -669,7 +669,7 @@ class SyncServer(LockingServer):
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None
) -> Preset:
"""Get the preset"""
return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id)
return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id)
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
"""Convert AgentState to a dict for a JSON response"""