diff --git a/main.py b/main.py index 636c6bf6..46e3618c 100644 --- a/main.py +++ b/main.py @@ -19,8 +19,8 @@ import memgpt.humans.humans as humans from memgpt.persistence_manager import InMemoryStateManager as persistence_manager FLAGS = flags.FLAGS -flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona") -flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Specify human") +flags.DEFINE_string("persona", default=None, required=False, help="Specify persona") +flags.DEFINE_string("human", default=None, required=False, help="Specify human") flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model") flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence") flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output") @@ -43,7 +43,14 @@ async def main(): logging.getLogger().setLevel(logging.DEBUG) print("Running... [exit by typing 'exit']") - memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager()) + if FLAGS.model != constants.DEFAULT_MEMGPT_MODEL: + print(f"Warning - you are running MemGPT with {FLAGS.model}, which is not officially supported (yet). Expect bugs!") + + # Moved defaults out of FLAGS so that we can dynamically select the default persona based on model + chosen_human = FLAGS.human if FLAGS.human is not None else humans.DEFAULT + chosen_persona = FLAGS.persona if FLAGS.persona is not None else (personas.GPT35_DEFAULT if 'gpt-3.5' in flags.MODEL else personas.DEFAULT) + + memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(chosen_persona), humans.get_human_text(chosen_human), interface, persistence_manager()) print_messages = interface.print_messages await print_messages(memgpt_agent.messages) diff --git a/memgpt/personas/personas.py b/memgpt/personas/personas.py index 50576493..a0d1ffb5 100644 --- a/memgpt/personas/personas.py +++ b/memgpt/personas/personas.py @@ -1,7 +1,7 @@ import os DEFAULT = 'sam' -GPT35_DEFAULT = 'sam' +GPT35_DEFAULT = 'sam_simple_pov_gpt35' def get_persona_text(key=DEFAULT):