fix: modify metadata presets functions (#1132)
This commit is contained in:
@@ -335,14 +335,14 @@ def create_autogen_memgpt_agent(
|
||||
user = ms.get_user(user_id=user_id)
|
||||
|
||||
try:
|
||||
preset_obj = ms.get_preset(preset_name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
|
||||
preset_obj = ms.get_preset(name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
|
||||
if preset_obj is None:
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(user.id, ms)
|
||||
# try again
|
||||
preset_obj = ms.get_preset(preset_name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
|
||||
preset_obj = ms.get_preset(name=agent_config["preset"] if "preset" in agent_config else config.preset, user_id=user.id)
|
||||
if preset_obj is None:
|
||||
print("Couldn't find presets in database, please run `memgpt configure`")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -646,17 +646,23 @@ def run(
|
||||
|
||||
# create agent
|
||||
try:
|
||||
preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id)
|
||||
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
|
||||
human_obj = ms.get_human(human, user.id)
|
||||
persona_obj = ms.get_persona(persona, user.id)
|
||||
if preset_obj is None:
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(user.id, ms)
|
||||
# try again
|
||||
preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id)
|
||||
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
|
||||
if preset_obj is None:
|
||||
typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
if human_obj is None:
|
||||
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
|
||||
if persona_obj is None:
|
||||
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
preset_obj.human = ms.get_human(human, user.id).text
|
||||
|
||||
@@ -413,13 +413,13 @@ class MetadataStore:
|
||||
|
||||
@enforce_types
|
||||
def get_preset(
|
||||
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
|
||||
self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
|
||||
) -> Optional[Preset]:
|
||||
with self.session_maker() as session:
|
||||
if preset_id:
|
||||
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
|
||||
elif preset_name and user_id:
|
||||
results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all()
|
||||
elif name and user_id:
|
||||
results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all()
|
||||
else:
|
||||
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
|
||||
if len(results) == 0:
|
||||
@@ -637,6 +637,12 @@ class MetadataStore:
|
||||
session.add(persona)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def add_preset(self, preset: PresetModel):
|
||||
with self.session_maker() as session:
|
||||
session.add(preset)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def get_human(self, name: str, user_id: uuid.UUID) -> str:
|
||||
with self.session_maker() as session:
|
||||
@@ -668,6 +674,12 @@ class MetadataStore:
|
||||
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
def delete_human(self, name: str, user_id: uuid.UUID):
|
||||
with self.session_maker() as session:
|
||||
@@ -679,3 +691,9 @@ class MetadataStore:
|
||||
with self.session_maker() as session:
|
||||
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_preset(self, name: str, user_id: uuid.UUID):
|
||||
with self.session_maker() as session:
|
||||
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
|
||||
session.commit()
|
||||
|
||||
@@ -42,9 +42,9 @@ def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: Me
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
if ms.get_preset(user_id=user_id, preset_name=name) is not None:
|
||||
if ms.get_preset(user_id=user_id, name=name) is not None:
|
||||
printd(f"Preset '{name}' already exists for user '{user_id}'")
|
||||
return ms.get_preset(user_id=user_id, preset_name=name)
|
||||
return ms.get_preset(user_id=user_id, name=name)
|
||||
|
||||
preset = Preset(
|
||||
user_id=user_id,
|
||||
@@ -70,7 +70,7 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
|
||||
if ms.get_preset(user_id=user_id, name=preset_name) is not None:
|
||||
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
|
||||
|
||||
@@ -599,7 +599,7 @@ class SyncServer(LockingServer):
|
||||
# self.ms.create_agent(agent_state)
|
||||
|
||||
try:
|
||||
preset_obj = self.ms.get_preset(preset_name=preset if preset else self.config.preset, user_id=user_id)
|
||||
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
|
||||
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
|
||||
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
|
||||
|
||||
@@ -669,7 +669,7 @@ class SyncServer(LockingServer):
|
||||
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None
|
||||
) -> Preset:
|
||||
"""Get the preset"""
|
||||
return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id)
|
||||
return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id)
|
||||
|
||||
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
|
||||
"""Convert AgentState to a dict for a JSON response"""
|
||||
|
||||
Reference in New Issue
Block a user