diff --git a/memgpt/__main__.py b/memgpt/__main__.py index 2310408d..89f11424 100644 --- a/memgpt/__main__.py +++ b/memgpt/__main__.py @@ -1,2 +1,3 @@ from .main import app + app() diff --git a/memgpt/agent.py b/memgpt/agent.py index 6e5de9f7..85ce0df3 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -11,10 +11,15 @@ from .system import get_heartbeat, get_login_event, package_function_response, p from .memory import CoreMemory as Memory, summarize_messages from .openai_tools import acompletions_with_backoff as acreate from .utils import get_local_time, parse_json, united_diff, printd, count_tokens -from .constants import \ - FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, \ - MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_TOKENS, \ - CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT +from .constants import ( + FIRST_MESSAGE_ATTEMPTS, + MAX_PAUSE_HEARTBEATS, + MESSAGE_CHATGPT_FUNCTION_MODEL, + MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, + MESSAGE_SUMMARY_WARNING_TOKENS, + CORE_MEMORY_HUMAN_CHAR_LIMIT, + CORE_MEMORY_PERSONA_CHAR_LIMIT, +) def initialize_memory(ai_notes, human_notes): @@ -28,52 +33,57 @@ def initialize_memory(ai_notes, human_notes): return memory -def construct_system_with_memory( - system, memory, memory_edit_timestamp, - archival_memory=None, recall_memory=None - ): - full_system_message = "\n".join([ - system, - "\n", - f"### Memory [last modified: {memory_edit_timestamp}", - f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", - f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)", - "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", - "", - memory.persona, - "", - "", - memory.human, - "", - ]) +def construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=None, recall_memory=None): + full_system_message = "\n".join( + [ + system, + "\n", + f"### Memory [last modified: {memory_edit_timestamp}", + f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", + f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)", + "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", + "", + memory.persona, + "", + "", + memory.human, + "", + ] + ) return full_system_message def initialize_message_sequence( - model, - system, - memory, - archival_memory=None, - recall_memory=None, - memory_edit_timestamp=None, - include_initial_boot_message=True, - ): + model, + system, + memory, + archival_memory=None, + recall_memory=None, + memory_edit_timestamp=None, + include_initial_boot_message=True, +): if memory_edit_timestamp is None: memory_edit_timestamp = get_local_time() - full_system_message = construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory) + full_system_message = construct_system_with_memory( + system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory + ) first_user_message = get_login_event() # event letting MemGPT know the user just logged in if include_initial_boot_message: - if 'gpt-3.5' in model: - initial_boot_messages = get_initial_boot_messages('startup_with_send_message_gpt35') + if "gpt-3.5" in model: + initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35") else: - initial_boot_messages = get_initial_boot_messages('startup_with_send_message') - messages = [ - {"role": "system", "content": full_system_message}, - ] + initial_boot_messages + [ - {"role": "user", "content": first_user_message}, - ] + initial_boot_messages = get_initial_boot_messages("startup_with_send_message") + messages = ( + [ + {"role": "system", "content": full_system_message}, + ] + + initial_boot_messages + + [ + {"role": "user", "content": first_user_message}, + ] + ) else: messages = [ @@ -85,11 +95,11 @@ def initialize_message_sequence( async def get_ai_reply_async( - model, - message_sequence, - functions, - function_call="auto", - ): + model, + message_sequence, + functions, + function_call="auto", +): """Base call to GPT API w/ functions""" try: @@ -101,11 +111,11 @@ async def get_ai_reply_async( ) # special case for 'length' - if response.choices[0].finish_reason == 'length': - raise Exception('Finish reason was length (maximum context length)') + if response.choices[0].finish_reason == "length": + raise Exception("Finish reason was length (maximum context length)") # catches for soft errors - if response.choices[0].finish_reason not in ['stop', 'function_call']: + if response.choices[0].finish_reason not in ["stop", "function_call"]: raise Exception(f"API call finish with bad finish reason: {response}") # unpack with response.choices[0].message.content @@ -118,7 +128,19 @@ async def get_ai_reply_async( class AgentAsync(object): """Core logic for a MemGPT agent""" - def __init__(self, model, system, functions, interface, persistence_manager, persona_notes, human_notes, messages_total=None, persistence_manager_init=True, first_message_verify_mono=True): + def __init__( + self, + model, + system, + functions, + interface, + persistence_manager, + persona_notes, + human_notes, + messages_total=None, + persistence_manager_init=True, + first_message_verify_mono=True, + ): # gpt-4, gpt-3.5-turbo self.model = model # Store the system instructions (used to rebuild memory) @@ -173,7 +195,7 @@ class AgentAsync(object): @messages.setter def messages(self, value): - raise Exception('Modifying message list directly not allowed') + raise Exception("Modifying message list directly not allowed") def trim_messages(self, num): """Trim messages from the front, not including the system message""" @@ -196,16 +218,16 @@ class AgentAsync(object): # strip extra metadata if it exists for msg in added_messages: - msg.pop('api_response', None) - msg.pop('api_args', None) + msg.pop("api_response", None) + msg.pop("api_args", None) new_messages = self.messages + added_messages # append self._messages = new_messages self.messages_total += len(added_messages) def swap_system_message(self, new_system_message): - assert new_system_message['role'] == 'system', new_system_message - assert self.messages[0]['role'] == 'system', self.messages + assert new_system_message["role"] == "system", new_system_message + assert self.messages[0]["role"] == "system", self.messages self.persistence_manager.swap_system_message(new_system_message) @@ -223,7 +245,7 @@ class AgentAsync(object): recall_memory=self.persistence_manager.recall_memory, )[0] - diff = united_diff(curr_system_message['content'], new_system_message['content']) + diff = united_diff(curr_system_message["content"], new_system_message["content"]) printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") # Store the memory change (if stateful) @@ -235,32 +257,32 @@ class AgentAsync(object): ### Local state management def to_dict(self): return { - 'model': self.model, - 'system': self.system, - 'functions': self.functions, - 'messages': self.messages, - 'messages_total': self.messages_total, - 'memory': self.memory.to_dict(), + "model": self.model, + "system": self.system, + "functions": self.functions, + "messages": self.messages, + "messages_total": self.messages_total, + "memory": self.memory.to_dict(), } def save_to_json_file(self, filename): - with open(filename, 'w') as file: + with open(filename, "w") as file: json.dump(self.to_dict(), file) @classmethod def load(cls, state, interface, persistence_manager): - model = state['model'] - system = state['system'] - functions = state['functions'] - messages = state['messages'] + model = state["model"] + system = state["system"] + functions = state["functions"] + messages = state["messages"] try: - messages_total = state['messages_total'] + messages_total = state["messages_total"] except KeyError: messages_total = len(messages) - 1 # memory requires a nested load - memory_dict = state['memory'] - persona_notes = memory_dict['persona'] - human_notes = memory_dict['human'] + memory_dict = state["memory"] + persona_notes = memory_dict["persona"] + human_notes = memory_dict["human"] # Two-part load new_agent = cls( @@ -278,18 +300,18 @@ class AgentAsync(object): return new_agent def load_inplace(self, state): - self.model = state['model'] - self.system = state['system'] - self.functions = state['functions'] + self.model = state["model"] + self.system = state["system"] + self.functions = state["functions"] # memory requires a nested load - memory_dict = state['memory'] - persona_notes = memory_dict['persona'] - human_notes = memory_dict['human'] + memory_dict = state["memory"] + persona_notes = memory_dict["persona"] + human_notes = memory_dict["human"] self.memory = initialize_memory(persona_notes, human_notes) # messages also - self._messages = state['messages'] + self._messages = state["messages"] try: - self.messages_total = state['messages_total'] + self.messages_total = state["messages_total"] except KeyError: self.messages_total = len(self.messages) - 1 # -system @@ -300,14 +322,14 @@ class AgentAsync(object): @classmethod def load_from_json_file(cls, json_file, interface, persistence_manager): - with open(json_file, 'r') as file: + with open(json_file, "r") as file: state = json.load(file) return cls.load(state, interface, persistence_manager) def load_from_json_file_inplace(self, json_file): # Load in-place # No interface arg needed, we can use the current one - with open(json_file, 'r') as file: + with open(json_file, "r") as file: state = json.load(file) self.load_inplace(state) @@ -317,7 +339,6 @@ class AgentAsync(object): # Step 2: check if LLM wanted to call a function if response_message.get("function_call"): - # The content if then internal monologue, not chat await self.interface.internal_monologue(response_message.content) messages.append(response_message) # extend conversation with assistant's reply @@ -348,7 +369,7 @@ class AgentAsync(object): try: function_to_call = available_functions[function_name] except KeyError as e: - error_msg = f'No function named {function_name}' + error_msg = f"No function named {function_name}" function_response = package_function_response(False, error_msg) messages.append( { @@ -357,7 +378,7 @@ class AgentAsync(object): "content": function_response, } ) # extend conversation with function response - await self.interface.function_message(f'Error: {error_msg}') + await self.interface.function_message(f"Error: {error_msg}") return messages, None, True # force a heartbeat to allow agent to handle error # Failure case 2: function name is OK, but function args are bad JSON @@ -374,18 +395,20 @@ class AgentAsync(object): "content": function_response, } ) # extend conversation with function response - await self.interface.function_message(f'Error: {error_msg}') + await self.interface.function_message(f"Error: {error_msg}") return messages, None, True # force a heartbeat to allow agent to handle error # (Still parsing function args) # Handle requests for immediate heartbeat - heartbeat_request = function_args.pop('request_heartbeat', None) + heartbeat_request = function_args.pop("request_heartbeat", None) if not (isinstance(heartbeat_request, bool) or heartbeat_request is None): - printd(f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}") + printd( + f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}" + ) heartbeat_request = None # Failure case 3: function failed during execution - await self.interface.function_message(f'Running {function_name}({function_args})') + await self.interface.function_message(f"Running {function_name}({function_args})") try: function_response_string = await function_to_call(**function_args) function_response = package_function_response(True, function_response_string) @@ -401,12 +424,12 @@ class AgentAsync(object): "content": function_response, } ) # extend conversation with function response - await self.interface.function_message(f'Error: {error_msg}') + await self.interface.function_message(f"Error: {error_msg}") return messages, None, True # force a heartbeat to allow agent to handle error # If no failures happened along the way: ... # Step 4: send the info on the function call and function response to GPT - await self.interface.function_message(f'Success: {function_response_string}') + await self.interface.function_message(f"Success: {function_response_string}") messages.append( { "role": "function", @@ -434,25 +457,29 @@ class AgentAsync(object): return False function_name = response_message["function_call"]["name"] - if require_send_message and function_name != 'send_message': + if require_send_message and function_name != "send_message": printd(f"First message function call wasn't send_message: {response_message}") return False - if require_monologue and (not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""): + if require_monologue and ( + not response_message.get("content") or response_message["content"] is None or response_message["content"] == "" + ): printd(f"First message missing internal monologue: {response_message}") return False if response_message.get("content"): ### Extras monologue = response_message.get("content") + def contains_special_characters(s): special_characters = '(){}[]"' return any(char in s for char in special_characters) + if contains_special_characters(monologue): printd(f"First message internal monologue contained special characters: {response_message}") return False # if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower(): - if 'functions' in monologue or 'send_message' in monologue: + if "functions" in monologue or "send_message" in monologue: # Sometimes the syntax won't be correct and internal syntax will leak into message.context printd(f"First message internal monologue contained reserved words: {response_message}") return False @@ -466,12 +493,12 @@ class AgentAsync(object): # Step 0: add user message if user_message is not None: await self.interface.user_message(user_message) - packed_user_message = {'role': 'user', 'content': user_message} + packed_user_message = {"role": "user", "content": user_message} input_message_sequence = self.messages + [packed_user_message] else: input_message_sequence = self.messages - if len(input_message_sequence) > 1 and input_message_sequence[-1]['role'] != 'user': + if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user": printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue") # Step 1: send the conversation and available functions to GPT @@ -479,14 +506,13 @@ class AgentAsync(object): printd(f"This is the first message. Running extra verifier on AI response.") counter = 0 while True: - response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions) if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break counter += 1 if counter > first_message_retry_limit: - raise Exception(f'Hit first message retry limit ({first_message_retry_limit})') + raise Exception(f"Hit first message retry limit ({first_message_retry_limit})") else: response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions) @@ -500,13 +526,13 @@ class AgentAsync(object): # Add the extra metadata to the assistant response # (e.g. enough metadata to enable recreating the API call) - assert 'api_response' not in all_response_messages[0] - all_response_messages[0]['api_response'] = response_message_copy - assert 'api_args' not in all_response_messages[0] - all_response_messages[0]['api_args'] = { - 'model': self.model, - 'messages': input_message_sequence, - 'functions': self.functions, + assert "api_response" not in all_response_messages[0] + all_response_messages[0]["api_response"] = response_message_copy + assert "api_args" not in all_response_messages[0] + all_response_messages[0]["api_args"] = { + "model": self.model, + "messages": input_message_sequence, + "functions": self.functions, } # Step 4: extend the message history @@ -516,7 +542,7 @@ class AgentAsync(object): all_new_messages = all_response_messages # Check the memory pressure and potentially issue a memory pressure warning - current_total_tokens = response['usage']['total_tokens'] + current_total_tokens = response["usage"]["total_tokens"] active_memory_warning = False if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS: printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}") @@ -534,7 +560,7 @@ class AgentAsync(object): printd(f"step() failed\nuser_message = {user_message}\nerror = {e}") # If we got a context alert, try trimming the messages length, then try again - if 'maximum context length' in str(e): + if "maximum context length" in str(e): # A separate API call to run a summarizer await self.summarize_messages_inplace() @@ -546,21 +572,21 @@ class AgentAsync(object): async def summarize_messages_inplace(self, cutoff=None): if cutoff is None: - tokens_so_far = 0 # Smart cutoff -- just below the max. + tokens_so_far = 0 # Smart cutoff -- just below the max. cutoff = len(self.messages) - 1 for m in reversed(self.messages): tokens_so_far += count_tokens(str(m), self.model) - if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS*0.2: + if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS * 0.2: break cutoff -= 1 - cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too + cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too # Try to make an assistant message come after the cutoff try: printd(f"Selected cutoff {cutoff} was a 'user', shifting one...") - if self.messages[cutoff]['role'] == 'user': + if self.messages[cutoff]["role"] == "user": new_cutoff = cutoff + 1 - if self.messages[new_cutoff]['role'] == 'user': + if self.messages[new_cutoff]["role"] == "user": printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...") cutoff = new_cutoff except IndexError: @@ -600,11 +626,11 @@ class AgentAsync(object): while limit is None or step_count < limit: if function_failed: - user_message = get_heartbeat('Function call failed') + user_message = get_heartbeat("Function call failed") new_messages, heartbeat_request, function_failed = await self.step(user_message) step_count += 1 elif heartbeat_request: - user_message = get_heartbeat('AI requested') + user_message = get_heartbeat("AI requested") new_messages, heartbeat_request, function_failed = await self.step(user_message) step_count += 1 else: @@ -638,7 +664,7 @@ class AgentAsync(object): return None async def recall_memory_search(self, query, count=5, page=0): - results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count) + results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count) num_pages = math.ceil(total / count) - 1 # 0 index if len(results) == 0: results_str = f"No results found." @@ -649,7 +675,7 @@ class AgentAsync(object): return results_str async def recall_memory_search_date(self, start_date, end_date, count=5, page=0): - results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count) + results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count) num_pages = math.ceil(total / count) - 1 # 0 index if len(results) == 0: results_str = f"No results found." @@ -664,7 +690,7 @@ class AgentAsync(object): return None async def archival_memory_search(self, query, count=5, page=0): - results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count) + results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page * count) num_pages = math.ceil(total / count) - 1 # 0 index if len(results) == 0: results_str = f"No results found." @@ -683,7 +709,7 @@ class AgentAsync(object): # And record how long the pause should go for self.pause_heartbeats_minutes = int(minutes) - return f'Pausing timed heartbeats for {minutes} min' + return f"Pausing timed heartbeats for {minutes} min" def heartbeat_is_paused(self): """Check if there's a requested pause on timed heartbeats""" @@ -700,8 +726,8 @@ class AgentAsync(object): """Base call to GPT API w/ functions""" message_sequence = [ - {'role': 'system', 'content': MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE}, - {'role': 'user', 'content': str(message)}, + {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE}, + {"role": "user", "content": str(message)}, ] response = await acreate( model=MESSAGE_CHATGPT_FUNCTION_MODEL, diff --git a/memgpt/agent_base.py b/memgpt/agent_base.py index 06442c92..7f132e49 100644 --- a/memgpt/agent_base.py +++ b/memgpt/agent_base.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod class AgentAsyncBase(ABC): - @abstractmethod async def step(self, user_message): - pass \ No newline at end of file + pass diff --git a/memgpt/autogen/interface.py b/memgpt/autogen/interface.py index 4f01fd7a..f3776790 100644 --- a/memgpt/autogen/interface.py +++ b/memgpt/autogen/interface.py @@ -68,41 +68,25 @@ class AutoGenInterface(object): print(f"inner thoughts :: {msg}") if not self.show_inner_thoughts: return - message = ( - f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" - if self.fancy - else f"[inner thoughts] {msg}" - ) + message = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" if self.fancy else f"[inner thoughts] {msg}" self.message_list.append(message) async def assistant_message(self, msg): if self.debug: print(f"assistant :: {msg}") - message = ( - f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" - if self.fancy - else msg - ) + message = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" if self.fancy else msg self.message_list.append(message) async def memory_message(self, msg): if self.debug: print(f"memory :: {msg}") - message = ( - f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[memory] {msg}" - ) + message = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}" if self.fancy else f"[memory] {msg}" self.message_list.append(message) async def system_message(self, msg): if self.debug: print(f"system :: {msg}") - message = ( - f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[system] {msg}" - ) + message = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}" self.message_list.append(message) async def user_message(self, msg, raw=False): @@ -113,11 +97,7 @@ class AutoGenInterface(object): if isinstance(msg, str): if raw: - message = ( - f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[user] {msg}" - ) + message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}" self.message_list.append(message) return else: @@ -125,42 +105,24 @@ class AutoGenInterface(object): msg_json = json.loads(msg) except: print(f"Warning: failed to parse user message into json") - message = ( - f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[user] {msg}" - ) + message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}" self.message_list.append(message) return if msg_json["type"] == "user_message": msg_json.pop("type") - message = ( - f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" - if self.fancy - else f"[user] {msg}" - ) + message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}" elif msg_json["type"] == "heartbeat": if True or DEBUG: msg_json.pop("type") message = ( - f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" - if self.fancy - else f"[system heartbeat] {msg}" + f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system heartbeat] {msg}" ) elif msg_json["type"] == "system_message": msg_json.pop("type") - message = ( - f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}" - if self.fancy - else f"[system] {msg}" - ) + message = f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}" else: - message = ( - f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" - if self.fancy - else f"[user] {msg}" - ) + message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}" self.message_list.append(message) @@ -171,31 +133,19 @@ class AutoGenInterface(object): return if isinstance(msg, dict): - message = ( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" self.message_list.append(message) return if msg.startswith("Success: "): - message = ( - f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function - OK] {msg}" - ) + message = f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - OK] {msg}" elif msg.startswith("Error: "): message = ( - f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function - error] {msg}" + f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - error] {msg}" ) elif msg.startswith("Running "): if DEBUG: - message = ( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function] {msg}" - ) + message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}" else: if "memory" in msg: match = re.search(r"Running (\w+)\((.*)\)", msg) @@ -227,35 +177,25 @@ class AutoGenInterface(object): else: print(f"Warning: did not recognize function message") message = ( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function] {msg}" + f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}" ) elif "send_message" in msg: # ignore in debug mode message = None else: message = ( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function] {msg}" + f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}" ) else: try: msg_dict = json.loads(msg) if "status" in msg_dict and msg_dict["status"] == "OK": message = ( - f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function] {msg}" + f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}" ) except Exception: print(f"Warning: did not recognize function message {type(msg)} {msg}") - message = ( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - if self.fancy - else f"[function] {msg}" - ) + message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}" if message: self.message_list.append(message) diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index 91adf5d8..6e9db5a2 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -55,11 +55,7 @@ def create_autogen_memgpt_agent( ``` """ interface = AutoGenInterface(**interface_kwargs) if interface is None else interface - persistence_manager = ( - InMemoryStateManager(**persistence_manager_kwargs) - if persistence_manager is None - else persistence_manager - ) + persistence_manager = InMemoryStateManager(**persistence_manager_kwargs) if persistence_manager is None else persistence_manager memgpt_agent = presets.use_preset( preset, @@ -89,9 +85,7 @@ class MemGPTAgent(ConversableAgent): self.agent = agent self.skip_verify = skip_verify self.concat_other_agent_messages = concat_other_agent_messages - self.register_reply( - [Agent, None], MemGPTAgent._a_generate_reply_for_user_message - ) + self.register_reply([Agent, None], MemGPTAgent._a_generate_reply_for_user_message) self.register_reply([Agent, None], MemGPTAgent._generate_reply_for_user_message) self.messages_processed_up_to_idx = 0 @@ -119,11 +113,7 @@ class MemGPTAgent(ConversableAgent): sender: Optional[Agent] = None, config: Optional[Any] = None, ) -> Tuple[bool, Union[str, Dict, None]]: - return asyncio.run( - self._a_generate_reply_for_user_message( - messages=messages, sender=sender, config=config - ) - ) + return asyncio.run(self._a_generate_reply_for_user_message(messages=messages, sender=sender, config=config)) async def _a_generate_reply_for_user_message( self, @@ -137,9 +127,7 @@ class MemGPTAgent(ConversableAgent): if len(new_messages) > 1: if self.concat_other_agent_messages: # Combine all the other messages into one message - user_message = "\n".join( - [self.format_other_agent_message(m) for m in new_messages] - ) + user_message = "\n".join([self.format_other_agent_message(m) for m in new_messages]) else: # Extend the MemGPT message list with multiple 'user' messages, then push the last one with agent.step() self.agent.messages.extend(new_messages[:-1]) @@ -157,16 +145,12 @@ class MemGPTAgent(ConversableAgent): heartbeat_request, function_failed, token_warning, - ) = await self.agent.step( - user_message, first_message=False, skip_verify=self.skip_verify - ) + ) = await self.agent.step(user_message, first_message=False, skip_verify=self.skip_verify) # Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control if token_warning: user_message = system.get_token_limit_warning() elif function_failed: - user_message = system.get_heartbeat( - constants.FUNC_FAILED_HEARTBEAT_MESSAGE - ) + user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) elif heartbeat_request: user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) else: diff --git a/memgpt/config.py b/memgpt/config.py index d22ee281..d9a8aa93 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -24,6 +24,7 @@ model_choices = [ ), ] + class Config: personas_dir = os.path.join("memgpt", "personas", "examples") custom_personas_dir = os.path.join(MEMGPT_DIR, "personas") @@ -78,12 +79,8 @@ class Config: cfg = Config.get_most_recent_config() use_cfg = False if cfg: - print( - f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}" - ) - use_cfg = await questionary.confirm( - f"Use most recent config file '{cfg}'?" - ).ask_async() + print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}") + use_cfg = await questionary.confirm(f"Use most recent config file '{cfg}'?").ask_async() if use_cfg: self.config_file = cfg @@ -104,9 +101,7 @@ class Config: return self # print("No settings file found, configuring MemGPT...") - print( - f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}" - ) + print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}") self.model = await questionary.select( "Which model would you like to use?", @@ -126,9 +121,7 @@ class Config: ).ask_async() self.archival_storage_index = None - self.preload_archival = await questionary.confirm( - "Would you like to preload anything into MemGPT's archival memory?" - ).ask_async() + self.preload_archival = await questionary.confirm("Would you like to preload anything into MemGPT's archival memory?").ask_async() if self.preload_archival: self.load_type = await questionary.select( "What would you like to load?", @@ -139,19 +132,13 @@ class Config: ], ).ask_async() if self.load_type == "folder" or self.load_type == "sql": - archival_storage_path = await questionary.path( - "Please enter the folder or file (tab for autocomplete):" - ).ask_async() + archival_storage_path = await questionary.path("Please enter the folder or file (tab for autocomplete):").ask_async() if os.path.isdir(archival_storage_path): - self.archival_storage_files = os.path.join( - archival_storage_path, "*" - ) + self.archival_storage_files = os.path.join(archival_storage_path, "*") else: self.archival_storage_files = archival_storage_path else: - self.archival_storage_files = await questionary.path( - "Please enter the glob pattern (tab for autocomplete):" - ).ask_async() + self.archival_storage_files = await questionary.path("Please enter the glob pattern (tab for autocomplete):").ask_async() self.compute_embeddings = await questionary.confirm( "Would you like to compute embeddings over these files to enable embeddings search?" ).ask_async() @@ -167,19 +154,11 @@ class Config: "⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search." ) else: - self.archival_storage_index = ( - await utils.prepare_archival_index_from_files_compute_embeddings( - self.archival_storage_files - ) - ) + self.archival_storage_index = await utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files) if self.compute_embeddings and self.archival_storage_index: - self.index, self.archival_database = utils.prepare_archival_index( - self.archival_storage_index - ) + self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index) else: - self.archival_database = utils.prepare_archival_index_from_files( - self.archival_storage_files - ) + self.archival_database = utils.prepare_archival_index_from_files(self.archival_storage_files) def to_dict(self): return { @@ -216,15 +195,11 @@ class Config: configs_dir = Config.configs_dir os.makedirs(configs_dir, exist_ok=True) if self.config_file is None: - filename = os.path.join( - configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_") - ) + filename = os.path.join(configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")) self.config_file = f"{filename}.json" with open(self.config_file, "wt") as f: json.dump(self.to_dict(), f, indent=4) - print( - f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}" - ) + print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}") @staticmethod def is_valid_config_file(file: str): @@ -233,9 +208,7 @@ class Config: cfg.load_config(file) except Exception: return False - return ( - cfg.memgpt_persona is not None and cfg.human_persona is not None - ) # TODO: more validation for configs + return cfg.memgpt_persona is not None and cfg.human_persona is not None # TODO: more validation for configs @staticmethod def get_memgpt_personas(): @@ -330,8 +303,7 @@ class Config: files = [ os.path.join(configs_dir, f) for f in os.listdir(configs_dir) - if os.path.isfile(os.path.join(configs_dir, f)) - and Config.is_valid_config_file(os.path.join(configs_dir, f)) + if os.path.isfile(os.path.join(configs_dir, f)) and Config.is_valid_config_file(os.path.join(configs_dir, f)) ] # Return the file with the most recent modification time if len(files) == 0: diff --git a/memgpt/connectors/connector.py b/memgpt/connectors/connector.py index 549c2d7d..4b4c399a 100644 --- a/memgpt/connectors/connector.py +++ b/memgpt/connectors/connector.py @@ -1,7 +1,7 @@ -""" +""" This file contains functions for loading data into MemGPT's archival storage. -Data can be loaded with the following command, once a load function is defined: +Data can be loaded with the following command, once a load function is defined: ``` memgpt load --name [ADDITIONAL ARGS] ``` @@ -18,14 +18,13 @@ from memgpt.utils import estimate_openai_cost, get_index, save_index app = typer.Typer() - @app.command("directory") def load_directory( name: str = typer.Option(help="Name of dataset to load."), input_dir: str = typer.Option(None, help="Path to directory containing dataset."), input_files: List[str] = typer.Option(None, help="List of paths to files containing dataset."), recursive: bool = typer.Option(False, help="Recursively search for files in directory."), -): +): from llama_index import SimpleDirectoryReader if recursive: @@ -35,34 +34,35 @@ def load_directory( recursive=True, ) else: - reader = SimpleDirectoryReader( - input_files=input_files - ) + reader = SimpleDirectoryReader(input_files=input_files) # load docs print("Loading data...") docs = reader.load_data() - # embed docs + # embed docs print("Indexing documents...") index = get_index(name, docs) # save connector information into .memgpt metadata file save_index(index, name) + @app.command("webpage") def load_webpage( name: str = typer.Option(help="Name of dataset to load."), urls: List[str] = typer.Option(None, help="List of urls to load."), -): +): from llama_index import SimpleWebPageReader + docs = SimpleWebPageReader(html_to_text=True).load_data(urls) - # embed docs + # embed docs print("Indexing documents...") index = get_index(docs) # save connector information into .memgpt metadata file save_index(index, name) + @app.command("database") def load_database( name: str = typer.Option(help="Name of dataset to load."), @@ -76,12 +76,14 @@ def load_database( dbname: str = typer.Option(None, help="Database name."), ): from llama_index.readers.database import DatabaseReader + print(dump_path, scheme) - if dump_path is not None: + if dump_path is not None: # read from database dump file from sqlalchemy import create_engine, MetaData - engine = create_engine(f'sqlite:///{dump_path}') + + engine = create_engine(f"sqlite:///{dump_path}") db = DatabaseReader(engine=engine) else: @@ -104,8 +106,6 @@ def load_database( # load data docs = db.load_data(query=query) - + index = get_index(name, docs) save_index(index, name) - - diff --git a/memgpt/constants.py b/memgpt/constants.py index bd83f7fc..aae904c4 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -7,9 +7,7 @@ DEFAULT_MEMGPT_MODEL = "gpt-4" FIRST_MESSAGE_ATTEMPTS = 10 INITIAL_BOOT_MESSAGE = "Boot sequence complete. Persona activated." -INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = ( - "Bootup sequence complete. Persona activated. Testing messaging functionality." -) +INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = "Bootup sequence complete. Persona activated. Testing messaging functionality." STARTUP_QUOTES = [ "I think, therefore I am.", "All those moments will be lost in time, like tears in rain.", @@ -28,9 +26,7 @@ CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000 MAX_PAUSE_HEARTBEATS = 360 # in min MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo" -MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = ( - "You are a helpful assistant. Keep your responses short and concise." -) +MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep your responses short and concise." #### Functions related diff --git a/memgpt/interface.py b/memgpt/interface.py index 0e66af08..b9b95be6 100644 --- a/memgpt/interface.py +++ b/memgpt/interface.py @@ -29,15 +29,11 @@ async def assistant_message(msg): async def memory_message(msg): - print( - f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}" - ) + print(f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}") async def system_message(msg): - printd( - f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}") async def user_message(msg, raw=False): @@ -50,9 +46,7 @@ async def user_message(msg, raw=False): msg_json = json.loads(msg) except: printd(f"Warning: failed to parse user message into json") - printd( - f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}") return if msg_json["type"] == "user_message": @@ -61,9 +55,7 @@ async def user_message(msg, raw=False): elif msg_json["type"] == "heartbeat": if DEBUG: msg_json.pop("type") - printd( - f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" - ) + printd(f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}") elif msg_json["type"] == "system_message": msg_json.pop("type") printd(f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}") @@ -77,33 +69,23 @@ async def function_message(msg): return if msg.startswith("Success: "): - printd( - f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}") elif msg.startswith("Error: "): - printd( - f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}") elif msg.startswith("Running "): if DEBUG: - printd( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}") else: if "memory" in msg: match = re.search(r"Running (\w+)\((.*)\)", msg) if match: function_name = match.group(1) function_args = match.group(2) - print( - f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:" - ) + print(f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:") try: msg_dict = eval(function_args) if function_name == "archival_memory_search": - print( - f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}' - ) + print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}') else: print( f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}' @@ -114,28 +96,20 @@ async def function_message(msg): pass else: printd(f"Warning: did not recognize function message") - printd( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}") elif "send_message" in msg: # ignore in debug mode pass else: - printd( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}") else: try: msg_dict = json.loads(msg) if "status" in msg_dict and msg_dict["status"] == "OK": - printd( - f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}") except Exception: printd(f"Warning: did not recognize function message {type(msg)} {msg}") - printd( - f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" - ) + printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}") async def print_messages(message_sequence): diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py index 60f8ee6b..0b2100fd 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -190,9 +190,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): function_parameters = function_json_output["params"] if self.clean_func_args: - function_name, function_parameters = self.clean_function_args( - function_name, function_parameters - ) + function_name, function_parameters = self.clean_function_args(function_name, function_parameters) message = { "role": "assistant", @@ -275,9 +273,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): func_str += f"\n description: {schema['description']}" func_str += f"\n params:" if add_inner_thoughts: - func_str += ( - f"\n inner_thoughts: Deep inner monologue private to you only." - ) + func_str += f"\n inner_thoughts: Deep inner monologue private to you only." for param_k, param_v in schema["parameters"]["properties"].items(): # TODO we're ignoring type func_str += f"\n {param_k}: {param_v['description']}" diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py index 0ce5d4b1..40d98579 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py @@ -152,9 +152,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): try: content_json = json.loads(message["content"]) content_simple = content_json["message"] - prompt += ( - f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}" - ) + prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}" # prompt += f"\nUSER: {content_simple}" except: prompt += f"\n{IM_START_TOKEN}user\n{message['content']}{IM_END_TOKEN}" @@ -227,9 +225,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): function_parameters = function_json_output["params"] if self.clean_func_args: - function_name, function_parameters = self.clean_function_args( - function_name, function_parameters - ) + function_name, function_parameters = self.clean_function_args(function_name, function_parameters) message = { "role": "assistant", diff --git a/memgpt/main.py b/memgpt/main.py index 93df2329..3df4ee0e 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -84,12 +84,8 @@ def load(memgpt_agent, filename): print(f"Loading {filename} failed with: {e}") else: # Load the latest file - print( - f"/load warning: no checkpoint specified, loading most recent checkpoint instead" - ) - json_files = glob.glob( - "saved_state/*.json" - ) # This will list all .json files in the current directory. + print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead") + json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory. # Check if there are any json files. if not json_files: @@ -111,27 +107,17 @@ def load(memgpt_agent, filename): ) # TODO(fixme):for different types of persistence managers that require different load/save methods print(f"Loaded persistence manager from {filename}") except Exception as e: - print( - f"/load warning: loading persistence manager from {filename} failed with: {e}" - ) + print(f"/load warning: loading persistence manager from {filename} failed with: {e}") @app.command() def run( persona: str = typer.Option(None, help="Specify persona"), human: str = typer.Option(None, help="Specify human"), - model: str = typer.Option( - constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model" - ), - first: bool = typer.Option( - False, "--first", help="Use --first to send the first message in the sequence" - ), - debug: bool = typer.Option( - False, "--debug", help="Use --debug to enable debugging output" - ), - no_verify: bool = typer.Option( - False, "--no_verify", help="Bypass message verification" - ), + model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"), + first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"), + debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"), + no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"), archival_storage_faiss_path: str = typer.Option( "", "--archival_storage_faiss_path", @@ -201,9 +187,7 @@ async def main( else: azure_vars = get_set_azure_env_vars() if len(azure_vars) > 0: - print( - f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False" - ) + print(f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False") return if any( @@ -296,23 +280,17 @@ async def main( else: cfg = await Config.config_init() - memgpt.interface.important_message( - "Running... [exit by typing '/exit', list available commands with '/help']" - ) + memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']") if cfg.model != constants.DEFAULT_MEMGPT_MODEL: memgpt.interface.warning_message( f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!" ) if cfg.index: - persistence_manager = InMemoryStateManagerWithFaiss( - cfg.index, cfg.archival_database - ) + persistence_manager = InMemoryStateManagerWithFaiss(cfg.index, cfg.archival_database) elif cfg.archival_storage_files: print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.") - persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory( - cfg.archival_database - ) + persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(cfg.archival_database) else: persistence_manager = InMemoryStateManager() @@ -356,9 +334,7 @@ async def main( print(f"Database loaded into archival memory.") if cfg.agent_save_file: - load_save_file = await questionary.confirm( - f"Load in saved agent '{cfg.agent_save_file}'?" - ).ask_async() + load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async() if load_save_file: load(memgpt_agent, cfg.agent_save_file) @@ -367,9 +343,7 @@ async def main( return if not USER_GOES_FIRST: - console.input( - "[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]" - ) + console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]") clear_line() print() @@ -405,9 +379,7 @@ async def main( break elif user_input.lower() == "/savechat": - filename = ( - utils.get_local_time().replace(" ", "_").replace(":", "_") - ) + filename = utils.get_local_time().replace(" ", "_").replace(":", "_") filename = f"{filename}.pkl" directory = os.path.join(MEMGPT_DIR, "saved_chats") try: @@ -424,9 +396,7 @@ async def main( save(memgpt_agent=memgpt_agent, cfg=cfg) continue - elif user_input.lower() == "/load" or user_input.lower().startswith( - "/load " - ): + elif user_input.lower() == "/load" or user_input.lower().startswith("/load "): command = user_input.strip().split() filename = command[1] if len(command) > 1 else None load(memgpt_agent=memgpt_agent, filename=filename) @@ -459,16 +429,10 @@ async def main( print(f"Updated model to:\n{str(memgpt_agent.model)}") continue - elif user_input.lower() == "/pop" or user_input.lower().startswith( - "/pop " - ): + elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "): # Check if there's an additional argument that's an integer command = user_input.strip().split() - amount = ( - int(command[1]) - if len(command) > 1 and command[1].isdigit() - else 2 - ) + amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 2 print(f"Popping last {amount} messages from stack") for _ in range(min(amount, len(memgpt_agent.messages))): memgpt_agent.messages.pop() @@ -513,18 +477,14 @@ async def main( heartbeat_request, function_failed, token_warning, - ) = await memgpt_agent.step( - user_message, first_message=False, skip_verify=no_verify - ) + ) = await memgpt_agent.step(user_message, first_message=False, skip_verify=no_verify) # Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control if token_warning: user_message = system.get_token_limit_warning() skip_next_user_input = True elif function_failed: - user_message = system.get_heartbeat( - constants.FUNC_FAILED_HEARTBEAT_MESSAGE - ) + user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) skip_next_user_input = True elif heartbeat_request: user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) diff --git a/memgpt/memory.py b/memgpt/memory.py index f4fa09ec..157255b7 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -40,20 +40,17 @@ class CoreMemory(object): self.archival_memory_exists = archival_memory_exists def __repr__(self) -> str: - return \ - f"\n### CORE MEMORY ###" + \ - f"\n=== Persona ===\n{self.persona}" + \ - f"\n\n=== Human ===\n{self.human}" + return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}" def to_dict(self): return { - 'persona': self.persona, - 'human': self.human, + "persona": self.persona, + "human": self.human, } @classmethod def load(cls, state): - return cls(state['persona'], state['human']) + return cls(state["persona"], state["human"]) def edit_persona(self, new_persona): if self.persona_char_limit and len(new_persona) > self.persona_char_limit: @@ -76,53 +73,55 @@ class CoreMemory(object): return len(self.human) def edit(self, field, content): - if field == 'persona': + if field == "persona": return self.edit_persona(content) - elif field == 'human': + elif field == "human": return self.edit_human(content) else: raise KeyError - def edit_append(self, field, content, sep='\n'): - if field == 'persona': + def edit_append(self, field, content, sep="\n"): + if field == "persona": new_content = self.persona + sep + content return self.edit_persona(new_content) - elif field == 'human': + elif field == "human": new_content = self.human + sep + content return self.edit_human(new_content) else: raise KeyError def edit_replace(self, field, old_content, new_content): - if field == 'persona': + if field == "persona": if old_content in self.persona: new_persona = self.persona.replace(old_content, new_content) return self.edit_persona(new_persona) else: - raise ValueError('Content not found in persona (make sure to use exact string)') - elif field == 'human': + raise ValueError("Content not found in persona (make sure to use exact string)") + elif field == "human": if old_content in self.human: new_human = self.human.replace(old_content, new_content) return self.edit_human(new_human) else: - raise ValueError('Content not found in human (make sure to use exact string)') + raise ValueError("Content not found in human (make sure to use exact string)") else: raise KeyError async def summarize_messages( - model, - message_sequence_to_summarize, - ): + model, + message_sequence_to_summarize, +): """Summarize a message sequence using GPT""" summary_prompt = SUMMARY_PROMPT_SYSTEM summary_input = str(message_sequence_to_summarize) summary_input_tkns = count_tokens(summary_input, model) if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS: - trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure... + trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure... cutoff = int(len(message_sequence_to_summarize) * trunc_ratio) - summary_input = str([await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]) + summary_input = str( + [await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:] + ) message_sequence = [ {"role": "system", "content": summary_prompt}, {"role": "user", "content": summary_input}, @@ -139,10 +138,9 @@ async def summarize_messages( class ArchivalMemory(ABC): - @abstractmethod def insert(self, memory_string): - """ Insert new archival memory + """Insert new archival memory :param memory_string: Memory string to insert :type memory_string: str @@ -151,7 +149,7 @@ class ArchivalMemory(ABC): @abstractmethod def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]: - """ Search archival memory + """Search archival memory :param query_string: Query string :type query_string: str @@ -159,7 +157,7 @@ class ArchivalMemory(ABC): :type count: Optional[int] :param start: Offset to start returning results from (None if 0) :type start: Optional[int] - + :return: Tuple of (list of results, total number of results) """ pass @@ -178,7 +176,7 @@ class DummyArchivalMemory(ArchivalMemory): """ def __init__(self, archival_memory_database=None): - self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts + self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts def __len__(self): return len(self._archive) @@ -187,31 +185,33 @@ class DummyArchivalMemory(ArchivalMemory): if len(self._archive) == 0: memory_str = "" else: - memory_str = "\n".join([d['content'] for d in self._archive]) - return \ - f"\n### ARCHIVAL MEMORY ###" + \ - f"\n{memory_str}" + memory_str = "\n".join([d["content"] for d in self._archive]) + return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" async def insert(self, memory_string, embedding=None): if embedding is not None: - raise ValueError('Basic text-based archival memory does not support embeddings') - self._archive.append({ - # can eventually upgrade to adding semantic tags, etc - 'timestamp': get_local_time(), - 'content': memory_string, - }) + raise ValueError("Basic text-based archival memory does not support embeddings") + self._archive.append( + { + # can eventually upgrade to adding semantic tags, etc + "timestamp": get_local_time(), + "content": memory_string, + } + ) async def search(self, query_string, count=None, start=None): """Simple text-based search""" # in the dummy version, run an (inefficient) case-insensitive match search # printd(f"query_string: {query_string}") - matches = [s for s in self._archive if query_string.lower() in s['content'].lower()] + matches = [s for s in self._archive if query_string.lower() in s["content"].lower()] # printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[str(d['content']) d in matches[:5]]}") - printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}") + printd( + f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}" + ) # start/count support paging through results if start is not None and count is not None: - return matches[start:start+count], len(matches) + return matches[start : start + count], len(matches) elif start is None and count is not None: return matches[:count], len(matches) elif start is not None and count is None: @@ -223,8 +223,8 @@ class DummyArchivalMemory(ArchivalMemory): class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory): """Same as dummy in-memory archival memory, but with bare-bones embedding support""" - def __init__(self, archival_memory_database=None, embedding_model='text-embedding-ada-002'): - self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts + def __init__(self, archival_memory_database=None, embedding_model="text-embedding-ada-002"): + self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts self.embedding_model = embedding_model def __len__(self): @@ -234,15 +234,17 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory): # Get the embedding if embedding is None: embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) - embedding_meta = {'model': self.embedding_model} + embedding_meta = {"model": self.embedding_model} printd(f"Got an embedding, type {type(embedding)}, len {len(embedding)}") - self._archive.append({ - 'timestamp': get_local_time(), - 'content': memory_string, - 'embedding': embedding, - 'embedding_metadata': embedding_meta, - }) + self._archive.append( + { + "timestamp": get_local_time(), + "content": memory_string, + "embedding": embedding, + "embedding_metadata": embedding_meta, + } + ) async def search(self, query_string, count=None, start=None): """Simple embedding-based search (inefficient, no caching)""" @@ -251,22 +253,24 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory): # query_embedding = get_embedding(query_string, model=self.embedding_model) # our wrapped version supports backoff/rate-limits query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model) - similarity_scores = [cosine_similarity(memory['embedding'], query_embedding) for memory in self._archive] + similarity_scores = [cosine_similarity(memory["embedding"], query_embedding) for memory in self._archive] # Sort the archive based on similarity scores sorted_archive_with_scores = sorted( zip(self._archive, similarity_scores), key=lambda pair: pair[1], # Sort by the similarity score - reverse=True # We want the highest similarity first + reverse=True, # We want the highest similarity first + ) + printd( + f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}" ) - printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}") # Extract the sorted archive without the scores matches = [item[0] for item in sorted_archive_with_scores] # start/count support paging through results if start is not None and count is not None: - return matches[start:start+count], len(matches) + return matches[start : start + count], len(matches) elif start is None and count is not None: return matches[:count], len(matches) elif start is not None and count is None: @@ -287,13 +291,13 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): is essential enough not to be left only to the recall memory. """ - def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100): + def __init__(self, index=None, archival_memory_database=None, embedding_model="text-embedding-ada-002", k=100): if index is None: - self.index = faiss.IndexFlatL2(1536) # openai embedding vector size. + self.index = faiss.IndexFlatL2(1536) # openai embedding vector size. else: self.index = index self.k = k - self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts + self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts self.embedding_model = embedding_model self.embeddings_dict = {} self.search_results = {} @@ -307,12 +311,14 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}") - self._archive.append({ - # can eventually upgrade to adding semantic tags, etc - 'timestamp': get_local_time(), - 'content': memory_string, - }) - embedding = np.array([embedding]).astype('float32') + self._archive.append( + { + # can eventually upgrade to adding semantic tags, etc + "timestamp": get_local_time(), + "content": memory_string, + } + ) + embedding = np.array([embedding]).astype("float32") self.index.add(embedding) async def search(self, query_string, count=None, start=None): @@ -332,20 +338,22 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): self.search_results[query_string] = search_result if start is not None and count is not None: - toprint = search_result[start:start+count] + toprint = search_result[start : start + count] else: if len(search_result) >= 5: toprint = search_result[:5] else: toprint = search_result - printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}") + printd( + f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}" + ) # Extract the sorted archive without the scores matches = search_result # start/count support paging through results if start is not None and count is not None: - return matches[start:start+count], len(matches) + return matches[start : start + count], len(matches) elif start is None and count is not None: return matches[:count], len(matches) elif start is not None and count is None: @@ -355,7 +363,6 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): class RecallMemory(ABC): - @abstractmethod def text_search(self, query_string, count=None, start=None): pass @@ -393,42 +400,46 @@ class DummyRecallMemory(RecallMemory): # don't dump all the conversations, just statistics system_count = user_count = assistant_count = function_count = other_count = 0 for msg in self._message_logs: - role = msg['message']['role'] - if role == 'system': + role = msg["message"]["role"] + if role == "system": system_count += 1 - elif role == 'user': + elif role == "user": user_count += 1 - elif role == 'assistant': + elif role == "assistant": assistant_count += 1 - elif role == 'function': + elif role == "function": function_count += 1 else: other_count += 1 - memory_str = f"Statistics:" + \ - f"\n{len(self._message_logs)} total messages" + \ - f"\n{system_count} system" + \ - f"\n{user_count} user" + \ - f"\n{assistant_count} assistant" + \ - f"\n{function_count} function" + \ - f"\n{other_count} other" - return \ - f"\n### RECALL MEMORY ###" + \ - f"\n{memory_str}" + memory_str = ( + f"Statistics:" + + f"\n{len(self._message_logs)} total messages" + + f"\n{system_count} system" + + f"\n{user_count} user" + + f"\n{assistant_count} assistant" + + f"\n{function_count} function" + + f"\n{other_count} other" + ) + return f"\n### RECALL MEMORY ###" + f"\n{memory_str}" async def insert(self, message): - raise NotImplementedError('This should be handled by the PersistenceManager, recall memory is just a search layer on top') + raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top") async def text_search(self, query_string, count=None, start=None): # in the dummy version, run an (inefficient) case-insensitive match search - message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']] + message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]] - printd(f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages") - matches = [d for d in message_pool if d['message']['content'] is not None and query_string.lower() in d['message']['content'].lower()] + printd( + f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages" + ) + matches = [ + d for d in message_pool if d["message"]["content"] is not None and query_string.lower() in d["message"]["content"].lower() + ] printd(f"recall_memory - matches:\n{matches[start:start+count]}") # start/count support paging through results if start is not None and count is not None: - return matches[start:start+count], len(matches) + return matches[start : start + count], len(matches) elif start is None and count is not None: return matches[:count], len(matches) elif start is not None and count is None: @@ -439,7 +450,7 @@ class DummyRecallMemory(RecallMemory): def _validate_date_format(self, date_str): """Validate the given date string in the format 'YYYY-MM-DD'.""" try: - datetime.datetime.strptime(date_str, '%Y-%m-%d') + datetime.datetime.strptime(date_str, "%Y-%m-%d") return True except ValueError: return False @@ -451,25 +462,26 @@ class DummyRecallMemory(RecallMemory): return match.group(1) if match else None async def date_search(self, start_date, end_date, count=None, start=None): - message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']] + message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]] # First, validate the start_date and end_date format if not self._validate_date_format(start_date) or not self._validate_date_format(end_date): raise ValueError("Invalid date format. Expected format: YYYY-MM-DD") # Convert dates to datetime objects for comparison - start_date_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d') - end_date_dt = datetime.datetime.strptime(end_date, '%Y-%m-%d') + start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d") + end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d") # Next, match items inside self._message_logs matches = [ - d for d in message_pool - if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d['timestamp']), '%Y-%m-%d') <= end_date_dt + d + for d in message_pool + if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt ] # start/count support paging through results if start is not None and count is not None: - return matches[start:start+count], len(matches) + return matches[start : start + count], len(matches) elif start is None and count is not None: return matches[:count], len(matches) elif start is not None and count is None: @@ -484,17 +496,17 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.embeddings = dict() - self.embedding_model = 'text-embedding-ada-002' + self.embedding_model = "text-embedding-ada-002" self.only_use_preloaded_embeddings = False async def text_search(self, query_string, count=None, start=None): # in the dummy version, run an (inefficient) case-insensitive match search - message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']] + message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]] # first, go through and make sure we have all the embeddings we need message_pool_filtered = [] for d in message_pool: - message_str = d['message']['content'] + message_str = d["message"]["content"] if self.only_use_preloaded_embeddings: if message_str not in self.embeddings: printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, skipping.") @@ -505,24 +517,26 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): self.embeddings[message_str] = await async_get_embedding_with_backoff(message_str, model=self.embedding_model) message_pool_filtered.append(d) - # our wrapped version supports backoff/rate-limits + # our wrapped version supports backoff/rate-limits query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model) - similarity_scores = [cosine_similarity(self.embeddings[d['message']['content']], query_embedding) for d in message_pool_filtered] + similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered] # Sort the archive based on similarity scores sorted_archive_with_scores = sorted( zip(message_pool_filtered, similarity_scores), key=lambda pair: pair[1], # Sort by the similarity score - reverse=True # We want the highest similarity first + reverse=True, # We want the highest similarity first + ) + printd( + f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}" ) - printd(f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}") # Extract the sorted archive without the scores matches = [item[0] for item in sorted_archive_with_scores] # start/count support paging through results if start is not None and count is not None: - return matches[start:start+count], len(matches) + return matches[start : start + count], len(matches) elif start is None and count is not None: return matches[:count], len(matches) elif start is not None and count is None: @@ -531,55 +545,49 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): return matches, len(matches) -class LocalArchivalMemory(ArchivalMemory): - """ Archival memory built on top of Llama Index """ +class LocalArchivalMemory(ArchivalMemory): + """Archival memory built on top of Llama Index""" - def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100): - """ Init function for archival memory + def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100): + """Init function for archival memory - :param archiva_memory_database: name of dataset to pre-fill archival with + :param archiva_memory_database: name of dataset to pre-fill archival with :type archival_memory_database: str """ - if archival_memory_database is not None: + if archival_memory_database is not None: # TODO: load form ~/.memgpt/archival directory = f"{MEMGPT_DIR}/archival/{archival_memory_database}" assert os.path.exists(directory), f"Archival memory database {archival_memory_database} does not exist" storage_context = StorageContext.from_defaults(persist_dir=directory) self.index = load_index_from_storage(storage_context) - else: + else: self.index = VectorIndex() self.top_k = top_k self.retriever = VectorIndexRetriever( - index=self.index, # does this get refreshed? + index=self.index, # does this get refreshed? similarity_top_k=self.top_k, ) # TODO: have some mechanism for cleanup otherwise will lead to OOM - self.cache = {} + self.cache = {} async def insert(self, memory_string): self.index.insert(memory_string) - - async def search(self, query_string, count=None, start=None): - start = start if start else 0 - count = count if count else self.top_k + async def search(self, query_string, count=None, start=None): + start = start if start else 0 + count = count if count else self.top_k count = min(count + start, self.top_k) if query_string not in self.cache: self.cache[query_string] = self.retriever.retrieve(query_string) - results = self.cache[query_string][start:start+count] - results = [ - {'timestamp': get_local_time(), 'content': node.node.text} - for node in results - ] - #from pprint import pprint - #pprint(results) + results = self.cache[query_string][start : start + count] + results = [{"timestamp": get_local_time(), "content": node.node.text} for node in results] + # from pprint import pprint + # pprint(results) return results, len(results) - + def __repr__(self) -> str: print(self.index.ref_doc_info) return "" - - diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py index 20ad5f9a..6b97b06a 100644 --- a/memgpt/openai_tools.py +++ b/memgpt/openai_tools.py @@ -41,9 +41,7 @@ def retry_with_exponential_backoff( # Check if max retries has been reached if num_retries > max_retries: - raise Exception( - f"Maximum number of retries ({max_retries}) exceeded." - ) + raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") # Increment the delay delay *= exponential_base * (1 + jitter * random.random()) @@ -91,9 +89,7 @@ def aretry_with_exponential_backoff( # Check if max retries has been reached if num_retries > max_retries: - raise Exception( - f"Maximum number of retries ({max_retries}) exceeded." - ) + raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") # Increment the delay delay *= exponential_base * (1 + jitter * random.random()) @@ -184,9 +180,7 @@ def configure_azure_support(): azure_openai_endpoint, azure_openai_version, ]: - print( - f"Error: missing Azure OpenAI environment variables. Please see README section on Azure." - ) + print(f"Error: missing Azure OpenAI environment variables. Please see README section on Azure.") return openai.api_type = "azure" @@ -199,10 +193,7 @@ def configure_azure_support(): def check_azure_embeddings(): azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") azure_openai_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") - if ( - azure_openai_deployment is not None - and azure_openai_embedding_deployment is None - ): + if azure_openai_deployment is not None and azure_openai_embedding_deployment is None: raise ValueError( f"Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure" ) diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index 74f8d1d9..2c38d8c1 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -1,12 +1,18 @@ from abc import ABC, abstractmethod import pickle -from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss, LocalArchivalMemory +from .memory import ( + DummyRecallMemory, + DummyRecallMemoryWithEmbeddings, + DummyArchivalMemory, + DummyArchivalMemoryWithEmbeddings, + DummyArchivalMemoryWithFaiss, + LocalArchivalMemory, +) from .utils import get_local_time, printd class PersistenceManager(ABC): - @abstractmethod def trim_messages(self, num): pass @@ -27,6 +33,7 @@ class PersistenceManager(ABC): def update_memory(self, new_memory): pass + class InMemoryStateManager(PersistenceManager): """In-memory state manager has nothing to manage, all agents are held in-memory""" @@ -41,17 +48,17 @@ class InMemoryStateManager(PersistenceManager): @staticmethod def load(filename): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return pickle.load(f) def save(self, filename): - with open(filename, 'wb') as fh: + with open(filename, "wb") as fh: pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL) def init(self, agent): printd(f"Initializing InMemoryStateManager with agent object") - self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] - self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] + self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] + self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] self.memory = agent.memory printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}") printd(f"InMemoryStateManager.messages.len = {len(self.messages)}") @@ -67,7 +74,7 @@ class InMemoryStateManager(PersistenceManager): def prepend_to_messages(self, added_messages): # first tag with timestamps - added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages] + added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages] printd(f"InMemoryStateManager.prepend_to_message") self.messages = [self.messages[0]] + added_messages + self.messages[1:] @@ -75,7 +82,7 @@ class InMemoryStateManager(PersistenceManager): def append_to_messages(self, added_messages): # first tag with timestamps - added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages] + added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages] printd(f"InMemoryStateManager.append_to_messages") self.messages = self.messages + added_messages @@ -83,7 +90,7 @@ class InMemoryStateManager(PersistenceManager): def swap_system_message(self, new_system_message): # first tag with timestamps - new_system_message = {'timestamp': get_local_time(), 'message': new_system_message} + new_system_message = {"timestamp": get_local_time(), "message": new_system_message} printd(f"InMemoryStateManager.swap_system_message") self.messages[0] = new_system_message @@ -93,11 +100,12 @@ class InMemoryStateManager(PersistenceManager): printd(f"InMemoryStateManager.update_memory") self.memory = new_memory + class LocalStateManager(PersistenceManager): """In-memory state manager has nothing to manage, all agents are held in-memory""" recall_memory_cls = DummyRecallMemory - archival_memory_cls = LocalArchivalMemory + archival_memory_cls = LocalArchivalMemory def __init__(self, archival_memory_db=None): # Memory held in-state useful for debugging stateful versions @@ -108,17 +116,17 @@ class LocalStateManager(PersistenceManager): @staticmethod def load(filename): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: return pickle.load(f) def save(self, filename): - with open(filename, 'wb') as fh: + with open(filename, "wb") as fh: pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL) def init(self, agent): printd(f"Initializing InMemoryStateManager with agent object") - self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] - self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] + self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] + self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] self.memory = agent.memory printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}") printd(f"InMemoryStateManager.messages.len = {len(self.messages)}") @@ -126,7 +134,7 @@ class LocalStateManager(PersistenceManager): # Persistence manager also handles DB-related state self.recall_memory = self.recall_memory_cls(message_database=self.all_messages) - # TODO: init archival memory here? + # TODO: init archival memory here? def trim_messages(self, num): # printd(f"InMemoryStateManager.trim_messages") @@ -134,7 +142,7 @@ class LocalStateManager(PersistenceManager): def prepend_to_messages(self, added_messages): # first tag with timestamps - added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages] + added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages] printd(f"InMemoryStateManager.prepend_to_message") self.messages = [self.messages[0]] + added_messages + self.messages[1:] @@ -142,7 +150,7 @@ class LocalStateManager(PersistenceManager): def append_to_messages(self, added_messages): # first tag with timestamps - added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages] + added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages] printd(f"InMemoryStateManager.append_to_messages") self.messages = self.messages + added_messages @@ -150,7 +158,7 @@ class LocalStateManager(PersistenceManager): def swap_system_message(self, new_system_message): # first tag with timestamps - new_system_message = {'timestamp': get_local_time(), 'message': new_system_message} + new_system_message = {"timestamp": get_local_time(), "message": new_system_message} printd(f"InMemoryStateManager.swap_system_message") self.messages[0] = new_system_message @@ -170,8 +178,8 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager): def init(self, agent): print(f"Initializing InMemoryStateManager with agent object") - self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] - self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] + self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] + self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] self.memory = agent.memory print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}") print(f"InMemoryStateManager.messages.len = {len(self.messages)}") @@ -199,13 +207,14 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager): def init(self, agent): print(f"Initializing InMemoryStateManager with agent object") - self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] - self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] + self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] + self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] self.memory = agent.memory print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}") print(f"InMemoryStateManager.messages.len = {len(self.messages)}") # Persistence manager also handles DB-related state self.recall_memory = self.recall_memory_cls(message_database=self.all_messages) - self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k) - + self.archival_memory = self.archival_memory_cls( + index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k + ) diff --git a/memgpt/personas/examples/docqa/build_index.py b/memgpt/personas/examples/docqa/build_index.py index 2dd94708..94802395 100644 --- a/memgpt/personas/examples/docqa/build_index.py +++ b/memgpt/personas/examples/docqa/build_index.py @@ -5,15 +5,14 @@ import numpy as np import argparse import json -def build_index(embedding_files: str, - index_name: str): +def build_index(embedding_files: str, index_name: str): index = faiss.IndexFlatL2(1536) file_list = sorted(glob(embedding_files)) for embedding_file in file_list: print(embedding_file) - with open(embedding_file, 'rt', encoding='utf-8') as file: + with open(embedding_file, "rt", encoding="utf-8") as file: embeddings = [] l = 0 for line in tqdm(file): @@ -21,7 +20,7 @@ def build_index(embedding_files: str, data = json.loads(line) embeddings.append(data) l += 1 - data = np.array(embeddings).astype('float32') + data = np.array(embeddings).astype("float32") print(data.shape) try: index.add(data) @@ -32,14 +31,11 @@ def build_index(embedding_files: str, faiss.write_index(index, index_name) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--embedding_files', type=str, help='embedding_filepaths glob expression') - parser.add_argument('--output_index_file', type=str, help='output filepath') - args = parser.parse_args() + parser.add_argument("--embedding_files", type=str, help="embedding_filepaths glob expression") + parser.add_argument("--output_index_file", type=str, help="output filepath") + args = parser.parse_args() - build_index( - embedding_files=args.embedding_files, - index_name=args.output_index_file - ) \ No newline at end of file + build_index(embedding_files=args.embedding_files, index_name=args.output_index_file) diff --git a/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py b/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py index f377ce27..0c2d9479 100644 --- a/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py +++ b/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py @@ -7,12 +7,14 @@ import argparse from tqdm import tqdm import openai + try: from dotenv import load_dotenv + load_dotenv() except ModuleNotFoundError: pass -openai.api_key = os.getenv('OPENAI_API_KEY') +openai.api_key = os.getenv("OPENAI_API_KEY") sys.path.append("../../../") from openai_tools import async_get_embedding_with_backoff @@ -24,8 +26,8 @@ from openai_parallel_request_processor import process_api_requests_from_file TPM_LIMIT = 1000000 RPM_LIMIT = 3000 -DEFAULT_FILE = 'iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz' -EMBEDDING_MODEL = 'text-embedding-ada-002' +DEFAULT_FILE = "iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz" +EMBEDDING_MODEL = "text-embedding-ada-002" async def generate_requests_file(filename): @@ -33,36 +35,33 @@ async def generate_requests_file(filename): base_name = os.path.splitext(filename)[0] requests_filename = f"{base_name}_embedding_requests.jsonl" - with open(filename, 'r') as f: + with open(filename, "r") as f: all_data = [json.loads(line) for line in f] - with open(requests_filename, 'w') as f: + with open(requests_filename, "w") as f: for data in all_data: documents = data for idx, doc in enumerate(documents): title = doc["title"] text = doc["text"] document_string = f"Document [{idx+1}] (Title: {title}) {text}" - request = { - "model": EMBEDDING_MODEL, - "input": document_string - } + request = {"model": EMBEDDING_MODEL, "input": document_string} json_string = json.dumps(request) f.write(json_string + "\n") # Run your parallel processing function input(f"Generated requests file ({requests_filename}), continue with embedding batch requests? (hit enter)") await process_api_requests_from_file( - requests_filepath=requests_filename, - save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary - request_url="https://api.openai.com/v1/embeddings", - api_key=os.getenv('OPENAI_API_KEY'), - max_requests_per_minute=RPM_LIMIT, - max_tokens_per_minute=TPM_LIMIT, - token_encoding_name=EMBEDDING_MODEL, - max_attempts=5, - logging_level=logging.INFO, - ) + requests_filepath=requests_filename, + save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary + request_url="https://api.openai.com/v1/embeddings", + api_key=os.getenv("OPENAI_API_KEY"), + max_requests_per_minute=RPM_LIMIT, + max_tokens_per_minute=TPM_LIMIT, + token_encoding_name=EMBEDDING_MODEL, + max_attempts=5, + logging_level=logging.INFO, + ) async def generate_embedding_file(filename, parallel_mode=False): @@ -72,7 +71,7 @@ async def generate_embedding_file(filename, parallel_mode=False): # Derive the sister filename # base_name = os.path.splitext(filename)[0] - base_name = filename.rsplit('.jsonl', 1)[0] + base_name = filename.rsplit(".jsonl", 1)[0] sister_filename = f"{base_name}.embeddings.jsonl" # Check if the sister file already exists @@ -80,7 +79,7 @@ async def generate_embedding_file(filename, parallel_mode=False): print(f"{sister_filename} already exists. Skipping embedding generation.") return - with open(filename, 'rt') as f: + with open(filename, "rt") as f: all_data = [json.loads(line) for line in f] embedding_data = [] @@ -90,7 +89,9 @@ async def generate_embedding_file(filename, parallel_mode=False): for i, data in enumerate(tqdm(all_data, desc="Processing data", total=len(all_data))): documents = data # Inner loop progress bar - for idx, doc in enumerate(tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)): + for idx, doc in enumerate( + tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False) + ): title = doc["title"] text = doc["text"] document_string = f"[Title: {title}] {text}" @@ -103,10 +104,10 @@ async def generate_embedding_file(filename, parallel_mode=False): # Save the embeddings to the sister file # with gzip.open(sister_filename, 'wt') as f: - with open(sister_filename, 'wb') as f: + with open(sister_filename, "wb") as f: for embedding in embedding_data: # f.write(json.dumps(embedding) + '\n') - f.write((json.dumps(embedding) + '\n').encode('utf-8')) + f.write((json.dumps(embedding) + "\n").encode("utf-8")) print(f"Embeddings saved to {sister_filename}") @@ -118,6 +119,7 @@ async def main(): filename = DEFAULT_FILE await generate_embedding_file(filename) + async def main(): parser = argparse.ArgumentParser() parser.add_argument("filename", nargs="?", default=DEFAULT_FILE, help="Path to the input file") @@ -129,4 +131,4 @@ async def main(): if __name__ == "__main__": loop = asyncio.get_event_loop() - loop.run_until_complete(main()) \ No newline at end of file + loop.run_until_complete(main()) diff --git a/memgpt/personas/examples/docqa/openai_parallel_request_processor.py b/memgpt/personas/examples/docqa/openai_parallel_request_processor.py index 4b9a1aae..169bfd37 100644 --- a/memgpt/personas/examples/docqa/openai_parallel_request_processor.py +++ b/memgpt/personas/examples/docqa/openai_parallel_request_processor.py @@ -121,9 +121,7 @@ async def process_api_requests_from_file( """Processes API requests in parallel, throttling to stay under rate limits.""" # constants seconds_to_pause_after_rate_limit_error = 15 - seconds_to_sleep_each_loop = ( - 0.001 # 1 ms limits max throughput to 1,000 requests per second - ) + seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second # initialize logging logging.basicConfig(level=logging_level) @@ -135,12 +133,8 @@ async def process_api_requests_from_file( # initialize trackers queue_of_requests_to_retry = asyncio.Queue() - task_id_generator = ( - task_id_generator_function() - ) # generates integer IDs of 1, 2, 3, ... - status_tracker = ( - StatusTracker() - ) # single instance to track a collection of variables + task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ... + status_tracker = StatusTracker() # single instance to track a collection of variables next_request = None # variable to hold the next request to call # initialize available capacity counts @@ -163,9 +157,7 @@ async def process_api_requests_from_file( if next_request is None: if not queue_of_requests_to_retry.empty(): next_request = queue_of_requests_to_retry.get_nowait() - logging.debug( - f"Retrying request {next_request.task_id}: {next_request}" - ) + logging.debug(f"Retrying request {next_request.task_id}: {next_request}") elif file_not_finished: try: # get new request @@ -173,17 +165,13 @@ async def process_api_requests_from_file( next_request = APIRequest( task_id=next(task_id_generator), request_json=request_json, - token_consumption=num_tokens_consumed_from_request( - request_json, api_endpoint, token_encoding_name - ), + token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name), attempts_left=max_attempts, metadata=request_json.pop("metadata", None), ) status_tracker.num_tasks_started += 1 status_tracker.num_tasks_in_progress += 1 - logging.debug( - f"Reading request {next_request.task_id}: {next_request}" - ) + logging.debug(f"Reading request {next_request.task_id}: {next_request}") except StopIteration: # if file runs out, set flag to stop reading it logging.debug("Read file exhausted") @@ -193,13 +181,11 @@ async def process_api_requests_from_file( current_time = time.time() seconds_since_update = current_time - last_update_time available_request_capacity = min( - available_request_capacity - + max_requests_per_minute * seconds_since_update / 60.0, + available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0, max_requests_per_minute, ) available_token_capacity = min( - available_token_capacity - + max_tokens_per_minute * seconds_since_update / 60.0, + available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0, max_tokens_per_minute, ) last_update_time = current_time @@ -207,10 +193,7 @@ async def process_api_requests_from_file( # if enough capacity available, call API if next_request: next_request_tokens = next_request.token_consumption - if ( - available_request_capacity >= 1 - and available_token_capacity >= next_request_tokens - ): + if available_request_capacity >= 1 and available_token_capacity >= next_request_tokens: # update counters available_request_capacity -= 1 available_token_capacity -= next_request_tokens @@ -237,17 +220,9 @@ async def process_api_requests_from_file( await asyncio.sleep(seconds_to_sleep_each_loop) # if a rate limit error was hit recently, pause to cool down - seconds_since_rate_limit_error = ( - time.time() - status_tracker.time_of_last_rate_limit_error - ) - if ( - seconds_since_rate_limit_error - < seconds_to_pause_after_rate_limit_error - ): - remaining_seconds_to_pause = ( - seconds_to_pause_after_rate_limit_error - - seconds_since_rate_limit_error - ) + seconds_since_rate_limit_error = time.time() - status_tracker.time_of_last_rate_limit_error + if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: + remaining_seconds_to_pause = seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error await asyncio.sleep(remaining_seconds_to_pause) # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago logging.warn( @@ -255,17 +230,13 @@ async def process_api_requests_from_file( ) # after finishing, log final status - logging.info( - f"""Parallel processing complete. Results saved to {save_filepath}""" - ) + logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""") if status_tracker.num_tasks_failed > 0: logging.warning( f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}." ) if status_tracker.num_rate_limit_errors > 0: - logging.warning( - f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate." - ) + logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.") # dataclasses @@ -309,26 +280,18 @@ class APIRequest: logging.info(f"Starting request #{self.task_id}") error = None try: - async with session.post( - url=request_url, headers=request_header, json=self.request_json - ) as response: + async with session.post(url=request_url, headers=request_header, json=self.request_json) as response: response = await response.json() if "error" in response: - logging.warning( - f"Request {self.task_id} failed with error {response['error']}" - ) + logging.warning(f"Request {self.task_id} failed with error {response['error']}") status_tracker.num_api_errors += 1 error = response if "Rate limit" in response["error"].get("message", ""): status_tracker.time_of_last_rate_limit_error = time.time() status_tracker.num_rate_limit_errors += 1 - status_tracker.num_api_errors -= ( - 1 # rate limit errors are counted separately - ) + status_tracker.num_api_errors -= 1 # rate limit errors are counted separately - except ( - Exception - ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them + except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them logging.warning(f"Request {self.task_id} failed with Exception {e}") status_tracker.num_other_errors += 1 error = e @@ -337,9 +300,7 @@ class APIRequest: if self.attempts_left: retry_queue.put_nowait(self) else: - logging.error( - f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}" - ) + logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}") data = ( [self.request_json, [str(e) for e in self.result], self.metadata] if self.metadata @@ -349,11 +310,7 @@ class APIRequest: status_tracker.num_tasks_in_progress -= 1 status_tracker.num_tasks_failed += 1 else: - data = ( - [self.request_json, response, self.metadata] - if self.metadata - else [self.request_json, response] - ) + data = [self.request_json, response, self.metadata] if self.metadata else [self.request_json, response] append_to_jsonl(data, save_filepath) status_tracker.num_tasks_in_progress -= 1 status_tracker.num_tasks_succeeded += 1 @@ -382,8 +339,8 @@ def num_tokens_consumed_from_request( token_encoding_name: str, ): """Count the number of tokens in the request. Only supports completion and embedding requests.""" - if token_encoding_name == 'text-embedding-ada-002': - encoding = tiktoken.get_encoding('cl100k_base') + if token_encoding_name == "text-embedding-ada-002": + encoding = tiktoken.get_encoding("cl100k_base") else: encoding = tiktoken.get_encoding(token_encoding_name) # if completions request, tokens = prompt + n * max_tokens @@ -415,9 +372,7 @@ def num_tokens_consumed_from_request( num_tokens = prompt_tokens + completion_tokens * len(prompt) return num_tokens else: - raise TypeError( - 'Expecting either string or list of strings for "prompt" field in completion request' - ) + raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') # if embeddings request, tokens = input tokens elif api_endpoint == "embeddings": input = request_json["input"] @@ -428,14 +383,10 @@ def num_tokens_consumed_from_request( num_tokens = sum([len(encoding.encode(i)) for i in input]) return num_tokens else: - raise TypeError( - 'Expecting either string or list of strings for "inputs" field in embedding request' - ) + raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request') # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) else: - raise NotImplementedError( - f'API endpoint "{api_endpoint}" not implemented in this script' - ) + raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') def task_id_generator_function(): @@ -502,4 +453,4 @@ with open(filename, "w") as f: ``` As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically). -""" \ No newline at end of file +""" diff --git a/memgpt/personas/examples/docqa/scrape_docs.py b/memgpt/personas/examples/docqa/scrape_docs.py index 66682694..f02df414 100644 --- a/memgpt/personas/examples/docqa/scrape_docs.py +++ b/memgpt/personas/examples/docqa/scrape_docs.py @@ -4,69 +4,65 @@ import tiktoken import json # Define the directory where the documentation resides -docs_dir = 'text' +docs_dir = "text" encoding = tiktoken.encoding_for_model("gpt-4") PASSAGE_TOKEN_LEN = 800 + def extract_text_from_sphinx_txt(file_path): lines = [] title = "" - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: for line in file: if not title: title = line.strip() continue - if line and re.match(r'^.*\S.*$', line) and not re.match(r'^[-=*]+$', line): + if line and re.match(r"^.*\S.*$", line) and not re.match(r"^[-=*]+$", line): lines.append(line) - passages = [] + passages = [] curr_passage = [] curr_token_ct = 0 for line in lines: try: - line_token_ct = len(encoding.encode(line, allowed_special={'<|endoftext|>'})) + line_token_ct = len(encoding.encode(line, allowed_special={"<|endoftext|>"})) except Exception as e: print("line", line) raise e if line_token_ct > PASSAGE_TOKEN_LEN: - passages.append({ - 'title': title, - 'text': line[:3200], - 'num_tokens': curr_token_ct, - }) + passages.append( + { + "title": title, + "text": line[:3200], + "num_tokens": curr_token_ct, + } + ) continue curr_token_ct += line_token_ct curr_passage.append(line) if curr_token_ct > PASSAGE_TOKEN_LEN: - passages.append({ - 'title': title, - 'text': ''.join(curr_passage), - 'num_tokens': curr_token_ct - }) + passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct}) curr_passage = [] curr_token_ct = 0 if len(curr_passage) > 0: - passages.append({ - 'title': title, - 'text': ''.join(curr_passage), - 'num_tokens': curr_token_ct - }) + passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct}) return passages + # Iterate over all files in the directory and its subdirectories passages = [] total_files = 0 for subdir, _, files in os.walk(docs_dir): for file in files: - if file.endswith('.txt'): + if file.endswith(".txt"): file_path = os.path.join(subdir, file) passages.append(extract_text_from_sphinx_txt(file_path)) total_files += 1 print("total .txt files:", total_files) # Save to a new text file or process as needed -with open('all_docs.jsonl', 'w', encoding='utf-8') as file: +with open("all_docs.jsonl", "w", encoding="utf-8") as file: for p in passages: file.write(json.dumps(p)) - file.write('\n') + file.write("\n") diff --git a/memgpt/presets.py b/memgpt/presets.py index 4fad1ed8..76ff8fae 100644 --- a/memgpt/presets.py +++ b/memgpt/presets.py @@ -1,30 +1,33 @@ - from .prompts import gpt_functions from .prompts import gpt_system from .agent import AgentAsync from .utils import printd -DEFAULT = 'memgpt_chat' +DEFAULT = "memgpt_chat" + def use_preset(preset_name, model, persona, human, interface, persistence_manager): """Storing combinations of SYSTEM + FUNCTION prompts""" - if preset_name == 'memgpt_chat': - + if preset_name == "memgpt_chat": functions = [ - 'send_message', 'pause_heartbeats', - 'core_memory_append', 'core_memory_replace', - 'conversation_search', 'conversation_search_date', - 'archival_memory_insert', 'archival_memory_search', + "send_message", + "pause_heartbeats", + "core_memory_append", + "core_memory_replace", + "conversation_search", + "conversation_search_date", + "archival_memory_insert", + "archival_memory_search", ] - available_functions = [v for k,v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions] - printd(f"Available functions:\n", [x['name'] for x in available_functions]) + available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions] + printd(f"Available functions:\n", [x["name"] for x in available_functions]) assert len(functions) == len(available_functions) - if 'gpt-3.5' in model: + if "gpt-3.5" in model: # use a different system message for gpt-3.5 - preset_name = 'memgpt_gpt35_extralong' + preset_name = "memgpt_gpt35_extralong" return AgentAsync( model=model, @@ -35,8 +38,8 @@ def use_preset(preset_name, model, persona, human, interface, persistence_manage persona_notes=persona, human_notes=human, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if 'gpt-4' in model else False, + first_message_verify_mono=True if "gpt-4" in model else False, ) else: - raise ValueError(preset_name) \ No newline at end of file + raise ValueError(preset_name) diff --git a/memgpt/prompts/gpt_functions.py b/memgpt/prompts/gpt_functions.py index a32a545e..060b50c7 100644 --- a/memgpt/prompts/gpt_functions.py +++ b/memgpt/prompts/gpt_functions.py @@ -2,9 +2,7 @@ from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT # FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1] FUNCTIONS_CHAINING = { - - 'send_message': - { + "send_message": { "name": "send_message", "description": "Sends a message to the human user", "parameters": { @@ -17,11 +15,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["message"], - } + }, }, - - 'pause_heartbeats': - { + "pause_heartbeats": { "name": "pause_heartbeats", "description": "Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.", "parameters": { @@ -34,11 +30,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["minutes"], - } + }, }, - - 'message_chatgpt': - { + "message_chatgpt": { "name": "message_chatgpt", "description": "Send a message to a more basic AI, ChatGPT. A useful resource for asking questions. ChatGPT does not retain memory of previous interactions.", "parameters": { @@ -55,11 +49,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["message", "request_heartbeat"], - } + }, }, - - 'core_memory_append': - { + "core_memory_append": { "name": "core_memory_append", "description": "Append to the contents of core memory.", "parameters": { @@ -79,11 +71,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "content", "request_heartbeat"], - } + }, }, - - 'core_memory_replace': - { + "core_memory_replace": { "name": "core_memory_replace", "description": "Replace to the contents of core memory. To delete memories, use an empty string for new_content.", "parameters": { @@ -107,11 +97,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "old_content", "new_content", "request_heartbeat"], - } + }, }, - - 'recall_memory_search': - { + "recall_memory_search": { "name": "recall_memory_search", "description": "Search prior conversation history using a string.", "parameters": { @@ -131,11 +119,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "page", "request_heartbeat"], - } + }, }, - - 'conversation_search': - { + "conversation_search": { "name": "conversation_search", "description": "Search prior conversation history using case-insensitive string matching.", "parameters": { @@ -155,11 +141,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "page", "request_heartbeat"], - } + }, }, - - 'recall_memory_search_date': - { + "recall_memory_search_date": { "name": "recall_memory_search_date", "description": "Search prior conversation history using a date range.", "parameters": { @@ -183,11 +167,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "page", "request_heartbeat"], - } + }, }, - - 'conversation_search_date': - { + "conversation_search_date": { "name": "conversation_search_date", "description": "Search prior conversation history using a date range.", "parameters": { @@ -211,11 +193,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "page", "request_heartbeat"], - } + }, }, - - 'archival_memory_insert': - { + "archival_memory_insert": { "name": "archival_memory_insert", "description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.", "parameters": { @@ -231,11 +211,9 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "content", "request_heartbeat"], - } + }, }, - - 'archival_memory_search': - { + "archival_memory_search": { "name": "archival_memory_search", "description": "Search archival memory using semantic (embedding-based) search.", "parameters": { @@ -255,7 +233,6 @@ FUNCTIONS_CHAINING = { }, }, "required": ["name", "query", "page", "request_heartbeat"], - } + }, }, - -} \ No newline at end of file +} diff --git a/memgpt/prompts/gpt_summarize.py b/memgpt/prompts/gpt_summarize.py index 619dbf83..95c0e199 100644 --- a/memgpt/prompts/gpt_summarize.py +++ b/memgpt/prompts/gpt_summarize.py @@ -1,6 +1,5 @@ WORD_LIMIT = 100 -SYSTEM = \ -f""" +SYSTEM = f""" Your job is to summarize a history of previous messages in a conversation between an AI persona and a human. The conversation you are given is a from a fixed context window and may not be complete. Messages sent by the AI are marked with the 'assistant' role. @@ -12,4 +11,4 @@ The 'user' role is also used for important system events, such as login events a Summarize what happened in the conversation from the perspective of the AI (use the first person). Keep your summary less than {WORD_LIMIT} words, do NOT exceed this word limit. Only output the summary, do NOT include anything else in your output. -""" \ No newline at end of file +""" diff --git a/memgpt/prompts/gpt_system.py b/memgpt/prompts/gpt_system.py index 2ee8edec..8100b6ee 100644 --- a/memgpt/prompts/gpt_system.py +++ b/memgpt/prompts/gpt_system.py @@ -2,11 +2,11 @@ import os def get_system_text(key): - filename = f'{key}.txt' - file_path = os.path.join(os.path.dirname(__file__), 'system', filename) + filename = f"{key}.txt" + file_path = os.path.join(os.path.dirname(__file__), "system", filename) if os.path.exists(file_path): - with open(file_path, 'r') as file: + with open(file_path, "r") as file: return file.read().strip() else: raise FileNotFoundError(f"No file found for key {key}, path={file_path}") diff --git a/memgpt/system.py b/memgpt/system.py index 116090a5..5993a007 100644 --- a/memgpt/system.py +++ b/memgpt/system.py @@ -1,18 +1,22 @@ import json from .utils import get_local_time -from .constants import INITIAL_BOOT_MESSAGE, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG, MESSAGE_SUMMARY_WARNING_STR +from .constants import ( + INITIAL_BOOT_MESSAGE, + INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, + INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG, + MESSAGE_SUMMARY_WARNING_STR, +) -def get_initial_boot_messages(version='startup'): - - if version == 'startup': +def get_initial_boot_messages(version="startup"): + if version == "startup": initial_boot_message = INITIAL_BOOT_MESSAGE messages = [ {"role": "assistant", "content": initial_boot_message}, ] - elif version == 'startup_with_send_message': + elif version == "startup_with_send_message": messages = [ # first message includes both inner monologue and function call to send_message { @@ -20,34 +24,23 @@ def get_initial_boot_messages(version='startup'): "content": INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, "function_call": { "name": "send_message", - "arguments": "{\n \"message\": \"" + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + "\"\n}" - } + "arguments": '{\n "message": "' + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + '"\n}', + }, }, # obligatory function return message - { - "role": "function", - "name": "send_message", - "content": package_function_response(True, None) - } + {"role": "function", "name": "send_message", "content": package_function_response(True, None)}, ] - elif version == 'startup_with_send_message_gpt35': + elif version == "startup_with_send_message_gpt35": messages = [ # first message includes both inner monologue and function call to send_message { "role": "assistant", "content": "*inner thoughts* Still waiting on the user. Sending a message with function.", - "function_call": { - "name": "send_message", - "arguments": "{\n \"message\": \"" + f"Hi, is anyone there?" + "\"\n}" - } + "function_call": {"name": "send_message", "arguments": '{\n "message": "' + f"Hi, is anyone there?" + '"\n}'}, }, # obligatory function return message - { - "role": "function", - "name": "send_message", - "content": package_function_response(True, None) - } + {"role": "function", "name": "send_message", "content": package_function_response(True, None)}, ] else: @@ -56,12 +49,11 @@ def get_initial_boot_messages(version='startup'): return messages -def get_heartbeat(reason='Automated timer', include_location=False, location_name='San Francisco, CA, USA'): - +def get_heartbeat(reason="Automated timer", include_location=False, location_name="San Francisco, CA, USA"): # Package the message with time and location formatted_time = get_local_time() packaged_message = { - "type": 'heartbeat', + "type": "heartbeat", "reason": reason, "time": formatted_time, } @@ -72,12 +64,11 @@ def get_heartbeat(reason='Automated timer', include_location=False, location_nam return json.dumps(packaged_message) -def get_login_event(last_login='Never (first login)', include_location=False, location_name='San Francisco, CA, USA'): - +def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"): # Package the message with time and location formatted_time = get_local_time() packaged_message = { - "type": 'login', + "type": "login", "last_login": last_login, "time": formatted_time, } @@ -88,12 +79,11 @@ def get_login_event(last_login='Never (first login)', include_location=False, lo return json.dumps(packaged_message) -def package_user_message(user_message, time=None, include_location=False, location_name='San Francisco, CA, USA'): - +def package_user_message(user_message, time=None, include_location=False, location_name="San Francisco, CA, USA"): # Package the message with time and location formatted_time = time if time else get_local_time() packaged_message = { - "type": 'user_message', + "type": "user_message", "message": user_message, "time": formatted_time, } @@ -103,11 +93,11 @@ def package_user_message(user_message, time=None, include_location=False, locati return json.dumps(packaged_message) -def package_function_response(was_success, response_string, timestamp=None): +def package_function_response(was_success, response_string, timestamp=None): formatted_time = get_local_time() if timestamp is None else timestamp packaged_message = { - "status": 'OK' if was_success else 'Failed', + "status": "OK" if was_success else "Failed", "message": response_string, "time": formatted_time, } @@ -116,14 +106,14 @@ def package_function_response(was_success, response_string, timestamp=None): def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None): - - context_message = \ - f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n" \ + context_message = ( + f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n" + f"The following is a summary of the previous {summary_length} messages:\n {summary}" + ) formatted_time = get_local_time() if timestamp is None else timestamp packaged_message = { - "type": 'system_alert', + "type": "system_alert", "message": context_message, "time": formatted_time, } @@ -136,10 +126,13 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m # Package the message with time and location formatted_time = get_local_time() if timestamp is None else timestamp - context_message = message if message else \ - f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions." + context_message = ( + message + if message + else f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions." + ) packaged_message = { - "type": 'system_alert', + "type": "system_alert", "message": context_message, "time": formatted_time, } @@ -148,12 +141,11 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m def get_token_limit_warning(): - formatted_time = get_local_time() packaged_message = { - "type": 'system_alert', + "type": "system_alert", "message": MESSAGE_SUMMARY_WARNING_STR, "time": formatted_time, } - return json.dumps(packaged_message) \ No newline at end of file + return json.dumps(packaged_message) diff --git a/memgpt/utils.py b/memgpt/utils.py index 560f544f..f9fa614a 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -19,6 +19,7 @@ from memgpt.constants import MEMGPT_DIR from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext from llama_index.embeddings import OpenAIEmbedding + def count_tokens(s: str, model: str = "gpt-4") -> int: encoding = tiktoken.encoding_for_model(model) return len(encoding.encode(s)) @@ -169,9 +170,7 @@ def chunk_file(file, tkns_per_chunk=300, model="gpt-4"): line_token_ct = len(encoding.encode(line)) except Exception as e: line_token_ct = len(line.split(" ")) / 0.75 - print( - f"Could not encode line {i}, estimating it to be {line_token_ct} tokens" - ) + print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens") print(e) if line_token_ct > tkns_per_chunk: if len(curr_chunk) > 0: @@ -195,9 +194,7 @@ def chunk_files(files, tkns_per_chunk=300, model="gpt-4"): archival_database = [] for file in files: timestamp = os.path.getmtime(file) - formatted_time = datetime.fromtimestamp(timestamp).strftime( - "%Y-%m-%d %I:%M:%S %p %Z%z" - ) + formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z") file_stem = file.split("/")[-1] chunks = [c for c in chunk_file(file, tkns_per_chunk, model)] for i, chunk in enumerate(chunks): @@ -244,9 +241,7 @@ async def process_concurrently(archival_database, model, concurrency=10): # Create a list of tasks for chunks embedding_data = [0 for _ in archival_database] - tasks = [ - bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database) - ] + tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)] for future in tqdm( asyncio.as_completed(tasks), @@ -268,15 +263,12 @@ async def prepare_archival_index_from_files_compute_embeddings( files = sorted(glob.glob(glob_pattern)) save_dir = os.path.join( MEMGPT_DIR, - "archival_index_from_files_" - + get_local_time().replace(" ", "_").replace(":", "_"), + "archival_index_from_files_" + get_local_time().replace(" ", "_").replace(":", "_"), ) os.makedirs(save_dir, exist_ok=True) total_tokens = total_bytes(glob_pattern) / 3 price_estimate = total_tokens / 1000 * 0.0001 - confirm = input( - f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] " - ) + confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ") if confirm != "y": raise Exception("embeddings were not computed") @@ -292,9 +284,7 @@ async def prepare_archival_index_from_files_compute_embeddings( archival_storage_file = os.path.join(save_dir, "all_docs.jsonl") chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model) with open(archival_storage_file, "w") as f: - print( - f"Saving archival storage with preloaded files to {archival_storage_file}" - ) + print(f"Saving archival storage with preloaded files to {archival_storage_file}") for c in chunks_by_file: json.dump(c, f) f.write("\n") @@ -341,9 +331,8 @@ def read_database_as_list(database_name): return result_list - -def estimate_openai_cost(docs): - """ Estimate OpenAI embedding cost +def estimate_openai_cost(docs): + """Estimate OpenAI embedding cost :param docs: Documents to be embedded :type docs: List[Document] @@ -356,18 +345,11 @@ def estimate_openai_cost(docs): embed_model = MockEmbedding(embed_dim=1536) - token_counter = TokenCountingHandler( - tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode - ) + token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode) callback_manager = CallbackManager([token_counter]) - set_global_service_context( - ServiceContext.from_defaults( - embed_model=embed_model, - callback_manager=callback_manager - ) - ) + set_global_service_context(ServiceContext.from_defaults(embed_model=embed_model, callback_manager=callback_manager)) index = VectorStoreIndex.from_documents(docs) # estimate cost @@ -377,8 +359,7 @@ def estimate_openai_cost(docs): def get_index(name, docs): - - """ Index documents + """Index documents :param docs: Documents to be embedded :type docs: List[Document] @@ -398,38 +379,40 @@ def get_index(name, docs): estimated_cost = estimate_openai_cost(docs) # TODO: prettier cost formatting - confirm = typer.confirm(typer.style(f"Open AI embedding cost will be approximately ${estimated_cost} - continue?", fg="yellow"), default=True) + confirm = typer.confirm( + typer.style(f"Open AI embedding cost will be approximately ${estimated_cost} - continue?", fg="yellow"), default=True + ) if not confirm: typer.secho("Aborting.", fg="red") exit() - + embed_model = OpenAIEmbedding() - service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size = 300) + service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=300) set_global_service_context(service_context) # index documents index = VectorStoreIndex.from_documents(docs) return index -def save_index(index, name): - """ Save index to a specificed name in ~/.memgpt +def save_index(index, name): + """Save index to a specificed name in ~/.memgpt :param index: Index to save :type index: VectorStoreIndex :param name: Name of index :type name: str """ - # save - # TODO: load directory from config + # save + # TODO: load directory from config # TODO: save to vectordb/local depending on config dir = f"{MEMGPT_DIR}/archival/{name}" ## Avoid overwriting ## check if directory exists - #if os.path.exists(dir): + # if os.path.exists(dir): # confirm = typer.confirm(typer.style(f"Index with name {name} already exists -- overwrite?", fg="red"), default=False) # if not confirm: # typer.secho("Aborting.", fg="red") diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index dc857372..95bec5ce 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -9,10 +9,7 @@ import memgpt.presets as presets import memgpt.constants as constants import memgpt.personas.personas as personas import memgpt.humans.humans as humans -from memgpt.persistence_manager import ( - InMemoryStateManager, - LocalStateManager -) +from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager from memgpt.config import Config from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL from memgpt.connectors import connector @@ -20,6 +17,7 @@ import memgpt.interface # for printing to terminal import asyncio from datasets import load_dataset + def test_load_directory(): # downloading hugging face dataset (if does not exist) dataset = load_dataset("MemGPT/example_short_stories") @@ -30,12 +28,12 @@ def test_load_directory(): # Construct the default path if the environment variable is not set. cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") - # load directory + # load directory print("Loading dataset into index...") print(cache_dir) load_directory( name="tmp_hf_dataset", - input_dir=cache_dir, + input_dir=cache_dir, recursive=True, ) @@ -51,23 +49,25 @@ def test_load_directory(): memgpt.interface, persistence_manager, ) - def query(q): + + def query(q): res = asyncio.run(memgpt_agent.archival_memory_search(q)) return res results = query("cinderella be getting sick") assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}" -def test_load_webpage(): + +def test_load_webpage(): pass -def test_load_database(): +def test_load_database(): from sqlalchemy import create_engine, MetaData import pandas as pd db_path = "memgpt/personas/examples/sqldb/test.db" - engine = create_engine(f'sqlite:///{db_path}') + engine = create_engine(f"sqlite:///{db_path}") # Create a MetaData object and reflect the database to get table information. metadata = MetaData() @@ -87,7 +87,7 @@ def test_load_database(): load_database( name="tmp_db_dataset", - #engine=engine, + # engine=engine, dump_path=db_path, query=f"SELECT * FROM {list(table_names)[0]}", ) @@ -107,7 +107,5 @@ def test_load_database(): assert True - - -#test_load_directory() -test_load_database() \ No newline at end of file +# test_load_directory() +test_load_database()