diff --git a/memgpt/humans/humans.py b/memgpt/humans/humans.py index 2d325438..534ecdbb 100644 --- a/memgpt/humans/humans.py +++ b/memgpt/humans/humans.py @@ -3,7 +3,7 @@ import os DEFAULT = 'cs_phd' def get_human_text(key=DEFAULT): - filename = f'{key}.txt' + filename = key if key.endswith('.txt') else f'{key}.txt' file_path = os.path.join(os.path.dirname(__file__), 'examples', filename) if os.path.exists(file_path): diff --git a/memgpt/personas/personas.py b/memgpt/personas/personas.py index 1eb74315..d86dd270 100644 --- a/memgpt/personas/personas.py +++ b/memgpt/personas/personas.py @@ -4,9 +4,9 @@ DEFAULT = 'sam' def get_persona_text(key=DEFAULT): - filename = f'{key}.txt' + filename = key if key.endswith('.txt') else f'{key}.txt' file_path = os.path.join(os.path.dirname(__file__), 'examples', filename) - + if os.path.exists(file_path): with open(file_path, 'r') as file: return file.read().strip()