diff --git a/memgpt/__main__.py b/memgpt/__main__.py
index 2310408d..89f11424 100644
--- a/memgpt/__main__.py
+++ b/memgpt/__main__.py
@@ -1,2 +1,3 @@
from .main import app
+
app()
diff --git a/memgpt/agent.py b/memgpt/agent.py
index 6e5de9f7..85ce0df3 100644
--- a/memgpt/agent.py
+++ b/memgpt/agent.py
@@ -11,10 +11,15 @@ from .system import get_heartbeat, get_login_event, package_function_response, p
from .memory import CoreMemory as Memory, summarize_messages
from .openai_tools import acompletions_with_backoff as acreate
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
-from .constants import \
- FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, \
- MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_TOKENS, \
- CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
+from .constants import (
+ FIRST_MESSAGE_ATTEMPTS,
+ MAX_PAUSE_HEARTBEATS,
+ MESSAGE_CHATGPT_FUNCTION_MODEL,
+ MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
+ MESSAGE_SUMMARY_WARNING_TOKENS,
+ CORE_MEMORY_HUMAN_CHAR_LIMIT,
+ CORE_MEMORY_PERSONA_CHAR_LIMIT,
+)
def initialize_memory(ai_notes, human_notes):
@@ -28,52 +33,57 @@ def initialize_memory(ai_notes, human_notes):
return memory
-def construct_system_with_memory(
- system, memory, memory_edit_timestamp,
- archival_memory=None, recall_memory=None
- ):
- full_system_message = "\n".join([
- system,
- "\n",
- f"### Memory [last modified: {memory_edit_timestamp}",
- f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
- f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
- "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
- "",
- memory.persona,
- "",
- "",
- memory.human,
- "",
- ])
+def construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=None, recall_memory=None):
+ full_system_message = "\n".join(
+ [
+ system,
+ "\n",
+ f"### Memory [last modified: {memory_edit_timestamp}",
+ f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
+ f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
+ "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
+ "",
+ memory.persona,
+ "",
+ "",
+ memory.human,
+ "",
+ ]
+ )
return full_system_message
def initialize_message_sequence(
- model,
- system,
- memory,
- archival_memory=None,
- recall_memory=None,
- memory_edit_timestamp=None,
- include_initial_boot_message=True,
- ):
+ model,
+ system,
+ memory,
+ archival_memory=None,
+ recall_memory=None,
+ memory_edit_timestamp=None,
+ include_initial_boot_message=True,
+):
if memory_edit_timestamp is None:
memory_edit_timestamp = get_local_time()
- full_system_message = construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory)
+ full_system_message = construct_system_with_memory(
+ system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
+ )
first_user_message = get_login_event() # event letting MemGPT know the user just logged in
if include_initial_boot_message:
- if 'gpt-3.5' in model:
- initial_boot_messages = get_initial_boot_messages('startup_with_send_message_gpt35')
+ if "gpt-3.5" in model:
+ initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
else:
- initial_boot_messages = get_initial_boot_messages('startup_with_send_message')
- messages = [
- {"role": "system", "content": full_system_message},
- ] + initial_boot_messages + [
- {"role": "user", "content": first_user_message},
- ]
+ initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
+ messages = (
+ [
+ {"role": "system", "content": full_system_message},
+ ]
+ + initial_boot_messages
+ + [
+ {"role": "user", "content": first_user_message},
+ ]
+ )
else:
messages = [
@@ -85,11 +95,11 @@ def initialize_message_sequence(
async def get_ai_reply_async(
- model,
- message_sequence,
- functions,
- function_call="auto",
- ):
+ model,
+ message_sequence,
+ functions,
+ function_call="auto",
+):
"""Base call to GPT API w/ functions"""
try:
@@ -101,11 +111,11 @@ async def get_ai_reply_async(
)
# special case for 'length'
- if response.choices[0].finish_reason == 'length':
- raise Exception('Finish reason was length (maximum context length)')
+ if response.choices[0].finish_reason == "length":
+ raise Exception("Finish reason was length (maximum context length)")
# catches for soft errors
- if response.choices[0].finish_reason not in ['stop', 'function_call']:
+ if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")
# unpack with response.choices[0].message.content
@@ -118,7 +128,19 @@ async def get_ai_reply_async(
class AgentAsync(object):
"""Core logic for a MemGPT agent"""
- def __init__(self, model, system, functions, interface, persistence_manager, persona_notes, human_notes, messages_total=None, persistence_manager_init=True, first_message_verify_mono=True):
+ def __init__(
+ self,
+ model,
+ system,
+ functions,
+ interface,
+ persistence_manager,
+ persona_notes,
+ human_notes,
+ messages_total=None,
+ persistence_manager_init=True,
+ first_message_verify_mono=True,
+ ):
# gpt-4, gpt-3.5-turbo
self.model = model
# Store the system instructions (used to rebuild memory)
@@ -173,7 +195,7 @@ class AgentAsync(object):
@messages.setter
def messages(self, value):
- raise Exception('Modifying message list directly not allowed')
+ raise Exception("Modifying message list directly not allowed")
def trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
@@ -196,16 +218,16 @@ class AgentAsync(object):
# strip extra metadata if it exists
for msg in added_messages:
- msg.pop('api_response', None)
- msg.pop('api_args', None)
+ msg.pop("api_response", None)
+ msg.pop("api_args", None)
new_messages = self.messages + added_messages # append
self._messages = new_messages
self.messages_total += len(added_messages)
def swap_system_message(self, new_system_message):
- assert new_system_message['role'] == 'system', new_system_message
- assert self.messages[0]['role'] == 'system', self.messages
+ assert new_system_message["role"] == "system", new_system_message
+ assert self.messages[0]["role"] == "system", self.messages
self.persistence_manager.swap_system_message(new_system_message)
@@ -223,7 +245,7 @@ class AgentAsync(object):
recall_memory=self.persistence_manager.recall_memory,
)[0]
- diff = united_diff(curr_system_message['content'], new_system_message['content'])
+ diff = united_diff(curr_system_message["content"], new_system_message["content"])
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Store the memory change (if stateful)
@@ -235,32 +257,32 @@ class AgentAsync(object):
### Local state management
def to_dict(self):
return {
- 'model': self.model,
- 'system': self.system,
- 'functions': self.functions,
- 'messages': self.messages,
- 'messages_total': self.messages_total,
- 'memory': self.memory.to_dict(),
+ "model": self.model,
+ "system": self.system,
+ "functions": self.functions,
+ "messages": self.messages,
+ "messages_total": self.messages_total,
+ "memory": self.memory.to_dict(),
}
def save_to_json_file(self, filename):
- with open(filename, 'w') as file:
+ with open(filename, "w") as file:
json.dump(self.to_dict(), file)
@classmethod
def load(cls, state, interface, persistence_manager):
- model = state['model']
- system = state['system']
- functions = state['functions']
- messages = state['messages']
+ model = state["model"]
+ system = state["system"]
+ functions = state["functions"]
+ messages = state["messages"]
try:
- messages_total = state['messages_total']
+ messages_total = state["messages_total"]
except KeyError:
messages_total = len(messages) - 1
# memory requires a nested load
- memory_dict = state['memory']
- persona_notes = memory_dict['persona']
- human_notes = memory_dict['human']
+ memory_dict = state["memory"]
+ persona_notes = memory_dict["persona"]
+ human_notes = memory_dict["human"]
# Two-part load
new_agent = cls(
@@ -278,18 +300,18 @@ class AgentAsync(object):
return new_agent
def load_inplace(self, state):
- self.model = state['model']
- self.system = state['system']
- self.functions = state['functions']
+ self.model = state["model"]
+ self.system = state["system"]
+ self.functions = state["functions"]
# memory requires a nested load
- memory_dict = state['memory']
- persona_notes = memory_dict['persona']
- human_notes = memory_dict['human']
+ memory_dict = state["memory"]
+ persona_notes = memory_dict["persona"]
+ human_notes = memory_dict["human"]
self.memory = initialize_memory(persona_notes, human_notes)
# messages also
- self._messages = state['messages']
+ self._messages = state["messages"]
try:
- self.messages_total = state['messages_total']
+ self.messages_total = state["messages_total"]
except KeyError:
self.messages_total = len(self.messages) - 1 # -system
@@ -300,14 +322,14 @@ class AgentAsync(object):
@classmethod
def load_from_json_file(cls, json_file, interface, persistence_manager):
- with open(json_file, 'r') as file:
+ with open(json_file, "r") as file:
state = json.load(file)
return cls.load(state, interface, persistence_manager)
def load_from_json_file_inplace(self, json_file):
# Load in-place
# No interface arg needed, we can use the current one
- with open(json_file, 'r') as file:
+ with open(json_file, "r") as file:
state = json.load(file)
self.load_inplace(state)
@@ -317,7 +339,6 @@ class AgentAsync(object):
# Step 2: check if LLM wanted to call a function
if response_message.get("function_call"):
-
# The content if then internal monologue, not chat
await self.interface.internal_monologue(response_message.content)
messages.append(response_message) # extend conversation with assistant's reply
@@ -348,7 +369,7 @@ class AgentAsync(object):
try:
function_to_call = available_functions[function_name]
except KeyError as e:
- error_msg = f'No function named {function_name}'
+ error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
{
@@ -357,7 +378,7 @@ class AgentAsync(object):
"content": function_response,
}
) # extend conversation with function response
- await self.interface.function_message(f'Error: {error_msg}')
+ await self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# Failure case 2: function name is OK, but function args are bad JSON
@@ -374,18 +395,20 @@ class AgentAsync(object):
"content": function_response,
}
) # extend conversation with function response
- await self.interface.function_message(f'Error: {error_msg}')
+ await self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# (Still parsing function args)
# Handle requests for immediate heartbeat
- heartbeat_request = function_args.pop('request_heartbeat', None)
+ heartbeat_request = function_args.pop("request_heartbeat", None)
if not (isinstance(heartbeat_request, bool) or heartbeat_request is None):
- printd(f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}")
+ printd(
+ f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
+ )
heartbeat_request = None
# Failure case 3: function failed during execution
- await self.interface.function_message(f'Running {function_name}({function_args})')
+ await self.interface.function_message(f"Running {function_name}({function_args})")
try:
function_response_string = await function_to_call(**function_args)
function_response = package_function_response(True, function_response_string)
@@ -401,12 +424,12 @@ class AgentAsync(object):
"content": function_response,
}
) # extend conversation with function response
- await self.interface.function_message(f'Error: {error_msg}')
+ await self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# If no failures happened along the way: ...
# Step 4: send the info on the function call and function response to GPT
- await self.interface.function_message(f'Success: {function_response_string}')
+ await self.interface.function_message(f"Success: {function_response_string}")
messages.append(
{
"role": "function",
@@ -434,25 +457,29 @@ class AgentAsync(object):
return False
function_name = response_message["function_call"]["name"]
- if require_send_message and function_name != 'send_message':
+ if require_send_message and function_name != "send_message":
printd(f"First message function call wasn't send_message: {response_message}")
return False
- if require_monologue and (not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""):
+ if require_monologue and (
+ not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""
+ ):
printd(f"First message missing internal monologue: {response_message}")
return False
if response_message.get("content"):
### Extras
monologue = response_message.get("content")
+
def contains_special_characters(s):
special_characters = '(){}[]"'
return any(char in s for char in special_characters)
+
if contains_special_characters(monologue):
printd(f"First message internal monologue contained special characters: {response_message}")
return False
# if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower():
- if 'functions' in monologue or 'send_message' in monologue:
+ if "functions" in monologue or "send_message" in monologue:
# Sometimes the syntax won't be correct and internal syntax will leak into message.context
printd(f"First message internal monologue contained reserved words: {response_message}")
return False
@@ -466,12 +493,12 @@ class AgentAsync(object):
# Step 0: add user message
if user_message is not None:
await self.interface.user_message(user_message)
- packed_user_message = {'role': 'user', 'content': user_message}
+ packed_user_message = {"role": "user", "content": user_message}
input_message_sequence = self.messages + [packed_user_message]
else:
input_message_sequence = self.messages
- if len(input_message_sequence) > 1 and input_message_sequence[-1]['role'] != 'user':
+ if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
# Step 1: send the conversation and available functions to GPT
@@ -479,14 +506,13 @@ class AgentAsync(object):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
-
response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
counter += 1
if counter > first_message_retry_limit:
- raise Exception(f'Hit first message retry limit ({first_message_retry_limit})')
+ raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
else:
response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
@@ -500,13 +526,13 @@ class AgentAsync(object):
# Add the extra metadata to the assistant response
# (e.g. enough metadata to enable recreating the API call)
- assert 'api_response' not in all_response_messages[0]
- all_response_messages[0]['api_response'] = response_message_copy
- assert 'api_args' not in all_response_messages[0]
- all_response_messages[0]['api_args'] = {
- 'model': self.model,
- 'messages': input_message_sequence,
- 'functions': self.functions,
+ assert "api_response" not in all_response_messages[0]
+ all_response_messages[0]["api_response"] = response_message_copy
+ assert "api_args" not in all_response_messages[0]
+ all_response_messages[0]["api_args"] = {
+ "model": self.model,
+ "messages": input_message_sequence,
+ "functions": self.functions,
}
# Step 4: extend the message history
@@ -516,7 +542,7 @@ class AgentAsync(object):
all_new_messages = all_response_messages
# Check the memory pressure and potentially issue a memory pressure warning
- current_total_tokens = response['usage']['total_tokens']
+ current_total_tokens = response["usage"]["total_tokens"]
active_memory_warning = False
if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
@@ -534,7 +560,7 @@ class AgentAsync(object):
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
# If we got a context alert, try trimming the messages length, then try again
- if 'maximum context length' in str(e):
+ if "maximum context length" in str(e):
# A separate API call to run a summarizer
await self.summarize_messages_inplace()
@@ -546,21 +572,21 @@ class AgentAsync(object):
async def summarize_messages_inplace(self, cutoff=None):
if cutoff is None:
- tokens_so_far = 0 # Smart cutoff -- just below the max.
+ tokens_so_far = 0 # Smart cutoff -- just below the max.
cutoff = len(self.messages) - 1
for m in reversed(self.messages):
tokens_so_far += count_tokens(str(m), self.model)
- if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS*0.2:
+ if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS * 0.2:
break
cutoff -= 1
- cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
+ cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
# Try to make an assistant message come after the cutoff
try:
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
- if self.messages[cutoff]['role'] == 'user':
+ if self.messages[cutoff]["role"] == "user":
new_cutoff = cutoff + 1
- if self.messages[new_cutoff]['role'] == 'user':
+ if self.messages[new_cutoff]["role"] == "user":
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
cutoff = new_cutoff
except IndexError:
@@ -600,11 +626,11 @@ class AgentAsync(object):
while limit is None or step_count < limit:
if function_failed:
- user_message = get_heartbeat('Function call failed')
+ user_message = get_heartbeat("Function call failed")
new_messages, heartbeat_request, function_failed = await self.step(user_message)
step_count += 1
elif heartbeat_request:
- user_message = get_heartbeat('AI requested')
+ user_message = get_heartbeat("AI requested")
new_messages, heartbeat_request, function_failed = await self.step(user_message)
step_count += 1
else:
@@ -638,7 +664,7 @@ class AgentAsync(object):
return None
async def recall_memory_search(self, query, count=5, page=0):
- results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count)
+ results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -649,7 +675,7 @@ class AgentAsync(object):
return results_str
async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
- results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count)
+ results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -664,7 +690,7 @@ class AgentAsync(object):
return None
async def archival_memory_search(self, query, count=5, page=0):
- results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count)
+ results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -683,7 +709,7 @@ class AgentAsync(object):
# And record how long the pause should go for
self.pause_heartbeats_minutes = int(minutes)
- return f'Pausing timed heartbeats for {minutes} min'
+ return f"Pausing timed heartbeats for {minutes} min"
def heartbeat_is_paused(self):
"""Check if there's a requested pause on timed heartbeats"""
@@ -700,8 +726,8 @@ class AgentAsync(object):
"""Base call to GPT API w/ functions"""
message_sequence = [
- {'role': 'system', 'content': MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
- {'role': 'user', 'content': str(message)},
+ {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
+ {"role": "user", "content": str(message)},
]
response = await acreate(
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
diff --git a/memgpt/agent_base.py b/memgpt/agent_base.py
index 06442c92..7f132e49 100644
--- a/memgpt/agent_base.py
+++ b/memgpt/agent_base.py
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
class AgentAsyncBase(ABC):
-
@abstractmethod
async def step(self, user_message):
- pass
\ No newline at end of file
+ pass
diff --git a/memgpt/autogen/interface.py b/memgpt/autogen/interface.py
index 4f01fd7a..f3776790 100644
--- a/memgpt/autogen/interface.py
+++ b/memgpt/autogen/interface.py
@@ -68,41 +68,25 @@ class AutoGenInterface(object):
print(f"inner thoughts :: {msg}")
if not self.show_inner_thoughts:
return
- message = (
- f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[inner thoughts] {msg}"
- )
+ message = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" if self.fancy else f"[inner thoughts] {msg}"
self.message_list.append(message)
async def assistant_message(self, msg):
if self.debug:
print(f"assistant :: {msg}")
- message = (
- f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}"
- if self.fancy
- else msg
- )
+ message = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" if self.fancy else msg
self.message_list.append(message)
async def memory_message(self, msg):
if self.debug:
print(f"memory :: {msg}")
- message = (
- f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[memory] {msg}"
- )
+ message = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}" if self.fancy else f"[memory] {msg}"
self.message_list.append(message)
async def system_message(self, msg):
if self.debug:
print(f"system :: {msg}")
- message = (
- f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[system] {msg}"
- )
+ message = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}"
self.message_list.append(message)
async def user_message(self, msg, raw=False):
@@ -113,11 +97,7 @@ class AutoGenInterface(object):
if isinstance(msg, str):
if raw:
- message = (
- f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[user] {msg}"
- )
+ message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
self.message_list.append(message)
return
else:
@@ -125,42 +105,24 @@ class AutoGenInterface(object):
msg_json = json.loads(msg)
except:
print(f"Warning: failed to parse user message into json")
- message = (
- f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[user] {msg}"
- )
+ message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
self.message_list.append(message)
return
if msg_json["type"] == "user_message":
msg_json.pop("type")
- message = (
- f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
- if self.fancy
- else f"[user] {msg}"
- )
+ message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
elif msg_json["type"] == "heartbeat":
if True or DEBUG:
msg_json.pop("type")
message = (
- f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
- if self.fancy
- else f"[system heartbeat] {msg}"
+ f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system heartbeat] {msg}"
)
elif msg_json["type"] == "system_message":
msg_json.pop("type")
- message = (
- f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
- if self.fancy
- else f"[system] {msg}"
- )
+ message = f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}"
else:
- message = (
- f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
- if self.fancy
- else f"[user] {msg}"
- )
+ message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
self.message_list.append(message)
@@ -171,31 +133,19 @@ class AutoGenInterface(object):
return
if isinstance(msg, dict):
- message = (
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
self.message_list.append(message)
return
if msg.startswith("Success: "):
- message = (
- f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function - OK] {msg}"
- )
+ message = f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - OK] {msg}"
elif msg.startswith("Error: "):
message = (
- f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function - error] {msg}"
+ f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - error] {msg}"
)
elif msg.startswith("Running "):
if DEBUG:
- message = (
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function] {msg}"
- )
+ message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
else:
if "memory" in msg:
match = re.search(r"Running (\w+)\((.*)\)", msg)
@@ -227,35 +177,25 @@ class AutoGenInterface(object):
else:
print(f"Warning: did not recognize function message")
message = (
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function] {msg}"
+ f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
)
elif "send_message" in msg:
# ignore in debug mode
message = None
else:
message = (
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function] {msg}"
+ f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
)
else:
try:
msg_dict = json.loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
message = (
- f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function] {msg}"
+ f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
)
except Exception:
print(f"Warning: did not recognize function message {type(msg)} {msg}")
- message = (
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- if self.fancy
- else f"[function] {msg}"
- )
+ message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
if message:
self.message_list.append(message)
diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py
index 91adf5d8..6e9db5a2 100644
--- a/memgpt/autogen/memgpt_agent.py
+++ b/memgpt/autogen/memgpt_agent.py
@@ -55,11 +55,7 @@ def create_autogen_memgpt_agent(
```
"""
interface = AutoGenInterface(**interface_kwargs) if interface is None else interface
- persistence_manager = (
- InMemoryStateManager(**persistence_manager_kwargs)
- if persistence_manager is None
- else persistence_manager
- )
+ persistence_manager = InMemoryStateManager(**persistence_manager_kwargs) if persistence_manager is None else persistence_manager
memgpt_agent = presets.use_preset(
preset,
@@ -89,9 +85,7 @@ class MemGPTAgent(ConversableAgent):
self.agent = agent
self.skip_verify = skip_verify
self.concat_other_agent_messages = concat_other_agent_messages
- self.register_reply(
- [Agent, None], MemGPTAgent._a_generate_reply_for_user_message
- )
+ self.register_reply([Agent, None], MemGPTAgent._a_generate_reply_for_user_message)
self.register_reply([Agent, None], MemGPTAgent._generate_reply_for_user_message)
self.messages_processed_up_to_idx = 0
@@ -119,11 +113,7 @@ class MemGPTAgent(ConversableAgent):
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
- return asyncio.run(
- self._a_generate_reply_for_user_message(
- messages=messages, sender=sender, config=config
- )
- )
+ return asyncio.run(self._a_generate_reply_for_user_message(messages=messages, sender=sender, config=config))
async def _a_generate_reply_for_user_message(
self,
@@ -137,9 +127,7 @@ class MemGPTAgent(ConversableAgent):
if len(new_messages) > 1:
if self.concat_other_agent_messages:
# Combine all the other messages into one message
- user_message = "\n".join(
- [self.format_other_agent_message(m) for m in new_messages]
- )
+ user_message = "\n".join([self.format_other_agent_message(m) for m in new_messages])
else:
# Extend the MemGPT message list with multiple 'user' messages, then push the last one with agent.step()
self.agent.messages.extend(new_messages[:-1])
@@ -157,16 +145,12 @@ class MemGPTAgent(ConversableAgent):
heartbeat_request,
function_failed,
token_warning,
- ) = await self.agent.step(
- user_message, first_message=False, skip_verify=self.skip_verify
- )
+ ) = await self.agent.step(user_message, first_message=False, skip_verify=self.skip_verify)
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
if token_warning:
user_message = system.get_token_limit_warning()
elif function_failed:
- user_message = system.get_heartbeat(
- constants.FUNC_FAILED_HEARTBEAT_MESSAGE
- )
+ user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
elif heartbeat_request:
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
else:
diff --git a/memgpt/config.py b/memgpt/config.py
index d22ee281..d9a8aa93 100644
--- a/memgpt/config.py
+++ b/memgpt/config.py
@@ -24,6 +24,7 @@ model_choices = [
),
]
+
class Config:
personas_dir = os.path.join("memgpt", "personas", "examples")
custom_personas_dir = os.path.join(MEMGPT_DIR, "personas")
@@ -78,12 +79,8 @@ class Config:
cfg = Config.get_most_recent_config()
use_cfg = False
if cfg:
- print(
- f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
- )
- use_cfg = await questionary.confirm(
- f"Use most recent config file '{cfg}'?"
- ).ask_async()
+ print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}")
+ use_cfg = await questionary.confirm(f"Use most recent config file '{cfg}'?").ask_async()
if use_cfg:
self.config_file = cfg
@@ -104,9 +101,7 @@ class Config:
return self
# print("No settings file found, configuring MemGPT...")
- print(
- f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
- )
+ print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}")
self.model = await questionary.select(
"Which model would you like to use?",
@@ -126,9 +121,7 @@ class Config:
).ask_async()
self.archival_storage_index = None
- self.preload_archival = await questionary.confirm(
- "Would you like to preload anything into MemGPT's archival memory?"
- ).ask_async()
+ self.preload_archival = await questionary.confirm("Would you like to preload anything into MemGPT's archival memory?").ask_async()
if self.preload_archival:
self.load_type = await questionary.select(
"What would you like to load?",
@@ -139,19 +132,13 @@ class Config:
],
).ask_async()
if self.load_type == "folder" or self.load_type == "sql":
- archival_storage_path = await questionary.path(
- "Please enter the folder or file (tab for autocomplete):"
- ).ask_async()
+ archival_storage_path = await questionary.path("Please enter the folder or file (tab for autocomplete):").ask_async()
if os.path.isdir(archival_storage_path):
- self.archival_storage_files = os.path.join(
- archival_storage_path, "*"
- )
+ self.archival_storage_files = os.path.join(archival_storage_path, "*")
else:
self.archival_storage_files = archival_storage_path
else:
- self.archival_storage_files = await questionary.path(
- "Please enter the glob pattern (tab for autocomplete):"
- ).ask_async()
+ self.archival_storage_files = await questionary.path("Please enter the glob pattern (tab for autocomplete):").ask_async()
self.compute_embeddings = await questionary.confirm(
"Would you like to compute embeddings over these files to enable embeddings search?"
).ask_async()
@@ -167,19 +154,11 @@ class Config:
"⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search."
)
else:
- self.archival_storage_index = (
- await utils.prepare_archival_index_from_files_compute_embeddings(
- self.archival_storage_files
- )
- )
+ self.archival_storage_index = await utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files)
if self.compute_embeddings and self.archival_storage_index:
- self.index, self.archival_database = utils.prepare_archival_index(
- self.archival_storage_index
- )
+ self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index)
else:
- self.archival_database = utils.prepare_archival_index_from_files(
- self.archival_storage_files
- )
+ self.archival_database = utils.prepare_archival_index_from_files(self.archival_storage_files)
def to_dict(self):
return {
@@ -216,15 +195,11 @@ class Config:
configs_dir = Config.configs_dir
os.makedirs(configs_dir, exist_ok=True)
if self.config_file is None:
- filename = os.path.join(
- configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
- )
+ filename = os.path.join(configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_"))
self.config_file = f"{filename}.json"
with open(self.config_file, "wt") as f:
json.dump(self.to_dict(), f, indent=4)
- print(
- f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
- )
+ print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}")
@staticmethod
def is_valid_config_file(file: str):
@@ -233,9 +208,7 @@ class Config:
cfg.load_config(file)
except Exception:
return False
- return (
- cfg.memgpt_persona is not None and cfg.human_persona is not None
- ) # TODO: more validation for configs
+ return cfg.memgpt_persona is not None and cfg.human_persona is not None # TODO: more validation for configs
@staticmethod
def get_memgpt_personas():
@@ -330,8 +303,7 @@ class Config:
files = [
os.path.join(configs_dir, f)
for f in os.listdir(configs_dir)
- if os.path.isfile(os.path.join(configs_dir, f))
- and Config.is_valid_config_file(os.path.join(configs_dir, f))
+ if os.path.isfile(os.path.join(configs_dir, f)) and Config.is_valid_config_file(os.path.join(configs_dir, f))
]
# Return the file with the most recent modification time
if len(files) == 0:
diff --git a/memgpt/connectors/connector.py b/memgpt/connectors/connector.py
index 549c2d7d..4b4c399a 100644
--- a/memgpt/connectors/connector.py
+++ b/memgpt/connectors/connector.py
@@ -1,7 +1,7 @@
-"""
+"""
This file contains functions for loading data into MemGPT's archival storage.
-Data can be loaded with the following command, once a load function is defined:
+Data can be loaded with the following command, once a load function is defined:
```
memgpt load --name [ADDITIONAL ARGS]
```
@@ -18,14 +18,13 @@ from memgpt.utils import estimate_openai_cost, get_index, save_index
app = typer.Typer()
-
@app.command("directory")
def load_directory(
name: str = typer.Option(help="Name of dataset to load."),
input_dir: str = typer.Option(None, help="Path to directory containing dataset."),
input_files: List[str] = typer.Option(None, help="List of paths to files containing dataset."),
recursive: bool = typer.Option(False, help="Recursively search for files in directory."),
-):
+):
from llama_index import SimpleDirectoryReader
if recursive:
@@ -35,34 +34,35 @@ def load_directory(
recursive=True,
)
else:
- reader = SimpleDirectoryReader(
- input_files=input_files
- )
+ reader = SimpleDirectoryReader(input_files=input_files)
# load docs
print("Loading data...")
docs = reader.load_data()
- # embed docs
+ # embed docs
print("Indexing documents...")
index = get_index(name, docs)
# save connector information into .memgpt metadata file
save_index(index, name)
+
@app.command("webpage")
def load_webpage(
name: str = typer.Option(help="Name of dataset to load."),
urls: List[str] = typer.Option(None, help="List of urls to load."),
-):
+):
from llama_index import SimpleWebPageReader
+
docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
- # embed docs
+ # embed docs
print("Indexing documents...")
index = get_index(docs)
# save connector information into .memgpt metadata file
save_index(index, name)
+
@app.command("database")
def load_database(
name: str = typer.Option(help="Name of dataset to load."),
@@ -76,12 +76,14 @@ def load_database(
dbname: str = typer.Option(None, help="Database name."),
):
from llama_index.readers.database import DatabaseReader
+
print(dump_path, scheme)
- if dump_path is not None:
+ if dump_path is not None:
# read from database dump file
from sqlalchemy import create_engine, MetaData
- engine = create_engine(f'sqlite:///{dump_path}')
+
+ engine = create_engine(f"sqlite:///{dump_path}")
db = DatabaseReader(engine=engine)
else:
@@ -104,8 +106,6 @@ def load_database(
# load data
docs = db.load_data(query=query)
-
+
index = get_index(name, docs)
save_index(index, name)
-
-
diff --git a/memgpt/constants.py b/memgpt/constants.py
index bd83f7fc..aae904c4 100644
--- a/memgpt/constants.py
+++ b/memgpt/constants.py
@@ -7,9 +7,7 @@ DEFAULT_MEMGPT_MODEL = "gpt-4"
FIRST_MESSAGE_ATTEMPTS = 10
INITIAL_BOOT_MESSAGE = "Boot sequence complete. Persona activated."
-INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = (
- "Bootup sequence complete. Persona activated. Testing messaging functionality."
-)
+INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = "Bootup sequence complete. Persona activated. Testing messaging functionality."
STARTUP_QUOTES = [
"I think, therefore I am.",
"All those moments will be lost in time, like tears in rain.",
@@ -28,9 +26,7 @@ CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000
MAX_PAUSE_HEARTBEATS = 360 # in min
MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo"
-MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = (
- "You are a helpful assistant. Keep your responses short and concise."
-)
+MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep your responses short and concise."
#### Functions related
diff --git a/memgpt/interface.py b/memgpt/interface.py
index 0e66af08..b9b95be6 100644
--- a/memgpt/interface.py
+++ b/memgpt/interface.py
@@ -29,15 +29,11 @@ async def assistant_message(msg):
async def memory_message(msg):
- print(
- f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
- )
+ print(f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}")
async def system_message(msg):
- printd(
- f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}")
async def user_message(msg, raw=False):
@@ -50,9 +46,7 @@ async def user_message(msg, raw=False):
msg_json = json.loads(msg)
except:
printd(f"Warning: failed to parse user message into json")
- printd(
- f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}")
return
if msg_json["type"] == "user_message":
@@ -61,9 +55,7 @@ async def user_message(msg, raw=False):
elif msg_json["type"] == "heartbeat":
if DEBUG:
msg_json.pop("type")
- printd(
- f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
elif msg_json["type"] == "system_message":
msg_json.pop("type")
printd(f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
@@ -77,33 +69,23 @@ async def function_message(msg):
return
if msg.startswith("Success: "):
- printd(
- f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}")
elif msg.startswith("Error: "):
- printd(
- f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}")
elif msg.startswith("Running "):
if DEBUG:
- printd(
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
else:
if "memory" in msg:
match = re.search(r"Running (\w+)\((.*)\)", msg)
if match:
function_name = match.group(1)
function_args = match.group(2)
- print(
- f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:"
- )
+ print(f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:")
try:
msg_dict = eval(function_args)
if function_name == "archival_memory_search":
- print(
- f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
- )
+ print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}')
else:
print(
f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}'
@@ -114,28 +96,20 @@ async def function_message(msg):
pass
else:
printd(f"Warning: did not recognize function message")
- printd(
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
elif "send_message" in msg:
# ignore in debug mode
pass
else:
- printd(
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
else:
try:
msg_dict = json.loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
- printd(
- f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}")
except Exception:
printd(f"Warning: did not recognize function message {type(msg)} {msg}")
- printd(
- f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
- )
+ printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
async def print_messages(message_sequence):
diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
index 60f8ee6b..0b2100fd 100644
--- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
+++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -190,9 +190,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
function_parameters = function_json_output["params"]
if self.clean_func_args:
- function_name, function_parameters = self.clean_function_args(
- function_name, function_parameters
- )
+ function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
message = {
"role": "assistant",
@@ -275,9 +273,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
func_str += f"\n description: {schema['description']}"
func_str += f"\n params:"
if add_inner_thoughts:
- func_str += (
- f"\n inner_thoughts: Deep inner monologue private to you only."
- )
+ func_str += f"\n inner_thoughts: Deep inner monologue private to you only."
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"
diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py
index 0ce5d4b1..40d98579 100644
--- a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py
+++ b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py
@@ -152,9 +152,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
try:
content_json = json.loads(message["content"])
content_simple = content_json["message"]
- prompt += (
- f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
- )
+ prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
# prompt += f"\nUSER: {content_simple}"
except:
prompt += f"\n{IM_START_TOKEN}user\n{message['content']}{IM_END_TOKEN}"
@@ -227,9 +225,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
function_parameters = function_json_output["params"]
if self.clean_func_args:
- function_name, function_parameters = self.clean_function_args(
- function_name, function_parameters
- )
+ function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
message = {
"role": "assistant",
diff --git a/memgpt/main.py b/memgpt/main.py
index 93df2329..3df4ee0e 100644
--- a/memgpt/main.py
+++ b/memgpt/main.py
@@ -84,12 +84,8 @@ def load(memgpt_agent, filename):
print(f"Loading {filename} failed with: {e}")
else:
# Load the latest file
- print(
- f"/load warning: no checkpoint specified, loading most recent checkpoint instead"
- )
- json_files = glob.glob(
- "saved_state/*.json"
- ) # This will list all .json files in the current directory.
+ print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
+ json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
# Check if there are any json files.
if not json_files:
@@ -111,27 +107,17 @@ def load(memgpt_agent, filename):
) # TODO(fixme):for different types of persistence managers that require different load/save methods
print(f"Loaded persistence manager from {filename}")
except Exception as e:
- print(
- f"/load warning: loading persistence manager from {filename} failed with: {e}"
- )
+ print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
@app.command()
def run(
persona: str = typer.Option(None, help="Specify persona"),
human: str = typer.Option(None, help="Specify human"),
- model: str = typer.Option(
- constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"
- ),
- first: bool = typer.Option(
- False, "--first", help="Use --first to send the first message in the sequence"
- ),
- debug: bool = typer.Option(
- False, "--debug", help="Use --debug to enable debugging output"
- ),
- no_verify: bool = typer.Option(
- False, "--no_verify", help="Bypass message verification"
- ),
+ model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"),
+ first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
+ debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
+ no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
archival_storage_faiss_path: str = typer.Option(
"",
"--archival_storage_faiss_path",
@@ -201,9 +187,7 @@ async def main(
else:
azure_vars = get_set_azure_env_vars()
if len(azure_vars) > 0:
- print(
- f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
- )
+ print(f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False")
return
if any(
@@ -296,23 +280,17 @@ async def main(
else:
cfg = await Config.config_init()
- memgpt.interface.important_message(
- "Running... [exit by typing '/exit', list available commands with '/help']"
- )
+ memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']")
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
memgpt.interface.warning_message(
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)
if cfg.index:
- persistence_manager = InMemoryStateManagerWithFaiss(
- cfg.index, cfg.archival_database
- )
+ persistence_manager = InMemoryStateManagerWithFaiss(cfg.index, cfg.archival_database)
elif cfg.archival_storage_files:
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
- persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
- cfg.archival_database
- )
+ persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(cfg.archival_database)
else:
persistence_manager = InMemoryStateManager()
@@ -356,9 +334,7 @@ async def main(
print(f"Database loaded into archival memory.")
if cfg.agent_save_file:
- load_save_file = await questionary.confirm(
- f"Load in saved agent '{cfg.agent_save_file}'?"
- ).ask_async()
+ load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async()
if load_save_file:
load(memgpt_agent, cfg.agent_save_file)
@@ -367,9 +343,7 @@ async def main(
return
if not USER_GOES_FIRST:
- console.input(
- "[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]"
- )
+ console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]")
clear_line()
print()
@@ -405,9 +379,7 @@ async def main(
break
elif user_input.lower() == "/savechat":
- filename = (
- utils.get_local_time().replace(" ", "_").replace(":", "_")
- )
+ filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
filename = f"{filename}.pkl"
directory = os.path.join(MEMGPT_DIR, "saved_chats")
try:
@@ -424,9 +396,7 @@ async def main(
save(memgpt_agent=memgpt_agent, cfg=cfg)
continue
- elif user_input.lower() == "/load" or user_input.lower().startswith(
- "/load "
- ):
+ elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
command = user_input.strip().split()
filename = command[1] if len(command) > 1 else None
load(memgpt_agent=memgpt_agent, filename=filename)
@@ -459,16 +429,10 @@ async def main(
print(f"Updated model to:\n{str(memgpt_agent.model)}")
continue
- elif user_input.lower() == "/pop" or user_input.lower().startswith(
- "/pop "
- ):
+ elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
- amount = (
- int(command[1])
- if len(command) > 1 and command[1].isdigit()
- else 2
- )
+ amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 2
print(f"Popping last {amount} messages from stack")
for _ in range(min(amount, len(memgpt_agent.messages))):
memgpt_agent.messages.pop()
@@ -513,18 +477,14 @@ async def main(
heartbeat_request,
function_failed,
token_warning,
- ) = await memgpt_agent.step(
- user_message, first_message=False, skip_verify=no_verify
- )
+ ) = await memgpt_agent.step(user_message, first_message=False, skip_verify=no_verify)
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
if token_warning:
user_message = system.get_token_limit_warning()
skip_next_user_input = True
elif function_failed:
- user_message = system.get_heartbeat(
- constants.FUNC_FAILED_HEARTBEAT_MESSAGE
- )
+ user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
skip_next_user_input = True
elif heartbeat_request:
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
diff --git a/memgpt/memory.py b/memgpt/memory.py
index f4fa09ec..157255b7 100644
--- a/memgpt/memory.py
+++ b/memgpt/memory.py
@@ -40,20 +40,17 @@ class CoreMemory(object):
self.archival_memory_exists = archival_memory_exists
def __repr__(self) -> str:
- return \
- f"\n### CORE MEMORY ###" + \
- f"\n=== Persona ===\n{self.persona}" + \
- f"\n\n=== Human ===\n{self.human}"
+ return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
def to_dict(self):
return {
- 'persona': self.persona,
- 'human': self.human,
+ "persona": self.persona,
+ "human": self.human,
}
@classmethod
def load(cls, state):
- return cls(state['persona'], state['human'])
+ return cls(state["persona"], state["human"])
def edit_persona(self, new_persona):
if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
@@ -76,53 +73,55 @@ class CoreMemory(object):
return len(self.human)
def edit(self, field, content):
- if field == 'persona':
+ if field == "persona":
return self.edit_persona(content)
- elif field == 'human':
+ elif field == "human":
return self.edit_human(content)
else:
raise KeyError
- def edit_append(self, field, content, sep='\n'):
- if field == 'persona':
+ def edit_append(self, field, content, sep="\n"):
+ if field == "persona":
new_content = self.persona + sep + content
return self.edit_persona(new_content)
- elif field == 'human':
+ elif field == "human":
new_content = self.human + sep + content
return self.edit_human(new_content)
else:
raise KeyError
def edit_replace(self, field, old_content, new_content):
- if field == 'persona':
+ if field == "persona":
if old_content in self.persona:
new_persona = self.persona.replace(old_content, new_content)
return self.edit_persona(new_persona)
else:
- raise ValueError('Content not found in persona (make sure to use exact string)')
- elif field == 'human':
+ raise ValueError("Content not found in persona (make sure to use exact string)")
+ elif field == "human":
if old_content in self.human:
new_human = self.human.replace(old_content, new_content)
return self.edit_human(new_human)
else:
- raise ValueError('Content not found in human (make sure to use exact string)')
+ raise ValueError("Content not found in human (make sure to use exact string)")
else:
raise KeyError
async def summarize_messages(
- model,
- message_sequence_to_summarize,
- ):
+ model,
+ message_sequence_to_summarize,
+):
"""Summarize a message sequence using GPT"""
summary_prompt = SUMMARY_PROMPT_SYSTEM
summary_input = str(message_sequence_to_summarize)
summary_input_tkns = count_tokens(summary_input, model)
if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS:
- trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
+ trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
- summary_input = str([await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:])
+ summary_input = str(
+ [await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]
+ )
message_sequence = [
{"role": "system", "content": summary_prompt},
{"role": "user", "content": summary_input},
@@ -139,10 +138,9 @@ async def summarize_messages(
class ArchivalMemory(ABC):
-
@abstractmethod
def insert(self, memory_string):
- """ Insert new archival memory
+ """Insert new archival memory
:param memory_string: Memory string to insert
:type memory_string: str
@@ -151,7 +149,7 @@ class ArchivalMemory(ABC):
@abstractmethod
def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
- """ Search archival memory
+ """Search archival memory
:param query_string: Query string
:type query_string: str
@@ -159,7 +157,7 @@ class ArchivalMemory(ABC):
:type count: Optional[int]
:param start: Offset to start returning results from (None if 0)
:type start: Optional[int]
-
+
:return: Tuple of (list of results, total number of results)
"""
pass
@@ -178,7 +176,7 @@ class DummyArchivalMemory(ArchivalMemory):
"""
def __init__(self, archival_memory_database=None):
- self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
+ self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
def __len__(self):
return len(self._archive)
@@ -187,31 +185,33 @@ class DummyArchivalMemory(ArchivalMemory):
if len(self._archive) == 0:
memory_str = ""
else:
- memory_str = "\n".join([d['content'] for d in self._archive])
- return \
- f"\n### ARCHIVAL MEMORY ###" + \
- f"\n{memory_str}"
+ memory_str = "\n".join([d["content"] for d in self._archive])
+ return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"
async def insert(self, memory_string, embedding=None):
if embedding is not None:
- raise ValueError('Basic text-based archival memory does not support embeddings')
- self._archive.append({
- # can eventually upgrade to adding semantic tags, etc
- 'timestamp': get_local_time(),
- 'content': memory_string,
- })
+ raise ValueError("Basic text-based archival memory does not support embeddings")
+ self._archive.append(
+ {
+ # can eventually upgrade to adding semantic tags, etc
+ "timestamp": get_local_time(),
+ "content": memory_string,
+ }
+ )
async def search(self, query_string, count=None, start=None):
"""Simple text-based search"""
# in the dummy version, run an (inefficient) case-insensitive match search
# printd(f"query_string: {query_string}")
- matches = [s for s in self._archive if query_string.lower() in s['content'].lower()]
+ matches = [s for s in self._archive if query_string.lower() in s["content"].lower()]
# printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[str(d['content']) d in matches[:5]]}")
- printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}")
+ printd(
+ f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}"
+ )
# start/count support paging through results
if start is not None and count is not None:
- return matches[start:start+count], len(matches)
+ return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -223,8 +223,8 @@ class DummyArchivalMemory(ArchivalMemory):
class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
"""Same as dummy in-memory archival memory, but with bare-bones embedding support"""
- def __init__(self, archival_memory_database=None, embedding_model='text-embedding-ada-002'):
- self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
+ def __init__(self, archival_memory_database=None, embedding_model="text-embedding-ada-002"):
+ self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self.embedding_model = embedding_model
def __len__(self):
@@ -234,15 +234,17 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
# Get the embedding
if embedding is None:
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
- embedding_meta = {'model': self.embedding_model}
+ embedding_meta = {"model": self.embedding_model}
printd(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
- self._archive.append({
- 'timestamp': get_local_time(),
- 'content': memory_string,
- 'embedding': embedding,
- 'embedding_metadata': embedding_meta,
- })
+ self._archive.append(
+ {
+ "timestamp": get_local_time(),
+ "content": memory_string,
+ "embedding": embedding,
+ "embedding_metadata": embedding_meta,
+ }
+ )
async def search(self, query_string, count=None, start=None):
"""Simple embedding-based search (inefficient, no caching)"""
@@ -251,22 +253,24 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# our wrapped version supports backoff/rate-limits
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
- similarity_scores = [cosine_similarity(memory['embedding'], query_embedding) for memory in self._archive]
+ similarity_scores = [cosine_similarity(memory["embedding"], query_embedding) for memory in self._archive]
# Sort the archive based on similarity scores
sorted_archive_with_scores = sorted(
zip(self._archive, similarity_scores),
key=lambda pair: pair[1], # Sort by the similarity score
- reverse=True # We want the highest similarity first
+ reverse=True, # We want the highest similarity first
+ )
+ printd(
+ f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
)
- printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}")
# Extract the sorted archive without the scores
matches = [item[0] for item in sorted_archive_with_scores]
# start/count support paging through results
if start is not None and count is not None:
- return matches[start:start+count], len(matches)
+ return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -287,13 +291,13 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
is essential enough not to be left only to the recall memory.
"""
- def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100):
+ def __init__(self, index=None, archival_memory_database=None, embedding_model="text-embedding-ada-002", k=100):
if index is None:
- self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
+ self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
else:
self.index = index
self.k = k
- self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
+ self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self.embedding_model = embedding_model
self.embeddings_dict = {}
self.search_results = {}
@@ -307,12 +311,14 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
- self._archive.append({
- # can eventually upgrade to adding semantic tags, etc
- 'timestamp': get_local_time(),
- 'content': memory_string,
- })
- embedding = np.array([embedding]).astype('float32')
+ self._archive.append(
+ {
+ # can eventually upgrade to adding semantic tags, etc
+ "timestamp": get_local_time(),
+ "content": memory_string,
+ }
+ )
+ embedding = np.array([embedding]).astype("float32")
self.index.add(embedding)
async def search(self, query_string, count=None, start=None):
@@ -332,20 +338,22 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
self.search_results[query_string] = search_result
if start is not None and count is not None:
- toprint = search_result[start:start+count]
+ toprint = search_result[start : start + count]
else:
if len(search_result) >= 5:
toprint = search_result[:5]
else:
toprint = search_result
- printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}")
+ printd(
+ f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}"
+ )
# Extract the sorted archive without the scores
matches = search_result
# start/count support paging through results
if start is not None and count is not None:
- return matches[start:start+count], len(matches)
+ return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -355,7 +363,6 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
class RecallMemory(ABC):
-
@abstractmethod
def text_search(self, query_string, count=None, start=None):
pass
@@ -393,42 +400,46 @@ class DummyRecallMemory(RecallMemory):
# don't dump all the conversations, just statistics
system_count = user_count = assistant_count = function_count = other_count = 0
for msg in self._message_logs:
- role = msg['message']['role']
- if role == 'system':
+ role = msg["message"]["role"]
+ if role == "system":
system_count += 1
- elif role == 'user':
+ elif role == "user":
user_count += 1
- elif role == 'assistant':
+ elif role == "assistant":
assistant_count += 1
- elif role == 'function':
+ elif role == "function":
function_count += 1
else:
other_count += 1
- memory_str = f"Statistics:" + \
- f"\n{len(self._message_logs)} total messages" + \
- f"\n{system_count} system" + \
- f"\n{user_count} user" + \
- f"\n{assistant_count} assistant" + \
- f"\n{function_count} function" + \
- f"\n{other_count} other"
- return \
- f"\n### RECALL MEMORY ###" + \
- f"\n{memory_str}"
+ memory_str = (
+ f"Statistics:"
+ + f"\n{len(self._message_logs)} total messages"
+ + f"\n{system_count} system"
+ + f"\n{user_count} user"
+ + f"\n{assistant_count} assistant"
+ + f"\n{function_count} function"
+ + f"\n{other_count} other"
+ )
+ return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
async def insert(self, message):
- raise NotImplementedError('This should be handled by the PersistenceManager, recall memory is just a search layer on top')
+ raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")
async def text_search(self, query_string, count=None, start=None):
# in the dummy version, run an (inefficient) case-insensitive match search
- message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
+ message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
- printd(f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages")
- matches = [d for d in message_pool if d['message']['content'] is not None and query_string.lower() in d['message']['content'].lower()]
+ printd(
+ f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages"
+ )
+ matches = [
+ d for d in message_pool if d["message"]["content"] is not None and query_string.lower() in d["message"]["content"].lower()
+ ]
printd(f"recall_memory - matches:\n{matches[start:start+count]}")
# start/count support paging through results
if start is not None and count is not None:
- return matches[start:start+count], len(matches)
+ return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -439,7 +450,7 @@ class DummyRecallMemory(RecallMemory):
def _validate_date_format(self, date_str):
"""Validate the given date string in the format 'YYYY-MM-DD'."""
try:
- datetime.datetime.strptime(date_str, '%Y-%m-%d')
+ datetime.datetime.strptime(date_str, "%Y-%m-%d")
return True
except ValueError:
return False
@@ -451,25 +462,26 @@ class DummyRecallMemory(RecallMemory):
return match.group(1) if match else None
async def date_search(self, start_date, end_date, count=None, start=None):
- message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
+ message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
# First, validate the start_date and end_date format
if not self._validate_date_format(start_date) or not self._validate_date_format(end_date):
raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
# Convert dates to datetime objects for comparison
- start_date_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
- end_date_dt = datetime.datetime.strptime(end_date, '%Y-%m-%d')
+ start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d")
+ end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d")
# Next, match items inside self._message_logs
matches = [
- d for d in message_pool
- if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d['timestamp']), '%Y-%m-%d') <= end_date_dt
+ d
+ for d in message_pool
+ if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
]
# start/count support paging through results
if start is not None and count is not None:
- return matches[start:start+count], len(matches)
+ return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -484,17 +496,17 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.embeddings = dict()
- self.embedding_model = 'text-embedding-ada-002'
+ self.embedding_model = "text-embedding-ada-002"
self.only_use_preloaded_embeddings = False
async def text_search(self, query_string, count=None, start=None):
# in the dummy version, run an (inefficient) case-insensitive match search
- message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
+ message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
# first, go through and make sure we have all the embeddings we need
message_pool_filtered = []
for d in message_pool:
- message_str = d['message']['content']
+ message_str = d["message"]["content"]
if self.only_use_preloaded_embeddings:
if message_str not in self.embeddings:
printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, skipping.")
@@ -505,24 +517,26 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
self.embeddings[message_str] = await async_get_embedding_with_backoff(message_str, model=self.embedding_model)
message_pool_filtered.append(d)
- # our wrapped version supports backoff/rate-limits
+ # our wrapped version supports backoff/rate-limits
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
- similarity_scores = [cosine_similarity(self.embeddings[d['message']['content']], query_embedding) for d in message_pool_filtered]
+ similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered]
# Sort the archive based on similarity scores
sorted_archive_with_scores = sorted(
zip(message_pool_filtered, similarity_scores),
key=lambda pair: pair[1], # Sort by the similarity score
- reverse=True # We want the highest similarity first
+ reverse=True, # We want the highest similarity first
+ )
+ printd(
+ f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
)
- printd(f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}")
# Extract the sorted archive without the scores
matches = [item[0] for item in sorted_archive_with_scores]
# start/count support paging through results
if start is not None and count is not None:
- return matches[start:start+count], len(matches)
+ return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -531,55 +545,49 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
return matches, len(matches)
-class LocalArchivalMemory(ArchivalMemory):
- """ Archival memory built on top of Llama Index """
+class LocalArchivalMemory(ArchivalMemory):
+ """Archival memory built on top of Llama Index"""
- def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100):
- """ Init function for archival memory
+ def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100):
+ """Init function for archival memory
- :param archiva_memory_database: name of dataset to pre-fill archival with
+ :param archiva_memory_database: name of dataset to pre-fill archival with
:type archival_memory_database: str
"""
- if archival_memory_database is not None:
+ if archival_memory_database is not None:
# TODO: load form ~/.memgpt/archival
directory = f"{MEMGPT_DIR}/archival/{archival_memory_database}"
assert os.path.exists(directory), f"Archival memory database {archival_memory_database} does not exist"
storage_context = StorageContext.from_defaults(persist_dir=directory)
self.index = load_index_from_storage(storage_context)
- else:
+ else:
self.index = VectorIndex()
self.top_k = top_k
self.retriever = VectorIndexRetriever(
- index=self.index, # does this get refreshed?
+ index=self.index, # does this get refreshed?
similarity_top_k=self.top_k,
)
# TODO: have some mechanism for cleanup otherwise will lead to OOM
- self.cache = {}
+ self.cache = {}
async def insert(self, memory_string):
self.index.insert(memory_string)
-
- async def search(self, query_string, count=None, start=None):
- start = start if start else 0
- count = count if count else self.top_k
+ async def search(self, query_string, count=None, start=None):
+ start = start if start else 0
+ count = count if count else self.top_k
count = min(count + start, self.top_k)
if query_string not in self.cache:
self.cache[query_string] = self.retriever.retrieve(query_string)
- results = self.cache[query_string][start:start+count]
- results = [
- {'timestamp': get_local_time(), 'content': node.node.text}
- for node in results
- ]
- #from pprint import pprint
- #pprint(results)
+ results = self.cache[query_string][start : start + count]
+ results = [{"timestamp": get_local_time(), "content": node.node.text} for node in results]
+ # from pprint import pprint
+ # pprint(results)
return results, len(results)
-
+
def __repr__(self) -> str:
print(self.index.ref_doc_info)
return ""
-
-
diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py
index 20ad5f9a..6b97b06a 100644
--- a/memgpt/openai_tools.py
+++ b/memgpt/openai_tools.py
@@ -41,9 +41,7 @@ def retry_with_exponential_backoff(
# Check if max retries has been reached
if num_retries > max_retries:
- raise Exception(
- f"Maximum number of retries ({max_retries}) exceeded."
- )
+ raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
@@ -91,9 +89,7 @@ def aretry_with_exponential_backoff(
# Check if max retries has been reached
if num_retries > max_retries:
- raise Exception(
- f"Maximum number of retries ({max_retries}) exceeded."
- )
+ raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
@@ -184,9 +180,7 @@ def configure_azure_support():
azure_openai_endpoint,
azure_openai_version,
]:
- print(
- f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
- )
+ print(f"Error: missing Azure OpenAI environment variables. Please see README section on Azure.")
return
openai.api_type = "azure"
@@ -199,10 +193,7 @@ def configure_azure_support():
def check_azure_embeddings():
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
azure_openai_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
- if (
- azure_openai_deployment is not None
- and azure_openai_embedding_deployment is None
- ):
+ if azure_openai_deployment is not None and azure_openai_embedding_deployment is None:
raise ValueError(
f"Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
)
diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py
index 74f8d1d9..2c38d8c1 100644
--- a/memgpt/persistence_manager.py
+++ b/memgpt/persistence_manager.py
@@ -1,12 +1,18 @@
from abc import ABC, abstractmethod
import pickle
-from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss, LocalArchivalMemory
+from .memory import (
+ DummyRecallMemory,
+ DummyRecallMemoryWithEmbeddings,
+ DummyArchivalMemory,
+ DummyArchivalMemoryWithEmbeddings,
+ DummyArchivalMemoryWithFaiss,
+ LocalArchivalMemory,
+)
from .utils import get_local_time, printd
class PersistenceManager(ABC):
-
@abstractmethod
def trim_messages(self, num):
pass
@@ -27,6 +33,7 @@ class PersistenceManager(ABC):
def update_memory(self, new_memory):
pass
+
class InMemoryStateManager(PersistenceManager):
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
@@ -41,17 +48,17 @@ class InMemoryStateManager(PersistenceManager):
@staticmethod
def load(filename):
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
return pickle.load(f)
def save(self, filename):
- with open(filename, 'wb') as fh:
+ with open(filename, "wb") as fh:
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
def init(self, agent):
printd(f"Initializing InMemoryStateManager with agent object")
- self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
- self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+ self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
+ self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")
@@ -67,7 +74,7 @@ class InMemoryStateManager(PersistenceManager):
def prepend_to_messages(self, added_messages):
# first tag with timestamps
- added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
+ added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"InMemoryStateManager.prepend_to_message")
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
@@ -75,7 +82,7 @@ class InMemoryStateManager(PersistenceManager):
def append_to_messages(self, added_messages):
# first tag with timestamps
- added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
+ added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"InMemoryStateManager.append_to_messages")
self.messages = self.messages + added_messages
@@ -83,7 +90,7 @@ class InMemoryStateManager(PersistenceManager):
def swap_system_message(self, new_system_message):
# first tag with timestamps
- new_system_message = {'timestamp': get_local_time(), 'message': new_system_message}
+ new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
printd(f"InMemoryStateManager.swap_system_message")
self.messages[0] = new_system_message
@@ -93,11 +100,12 @@ class InMemoryStateManager(PersistenceManager):
printd(f"InMemoryStateManager.update_memory")
self.memory = new_memory
+
class LocalStateManager(PersistenceManager):
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
recall_memory_cls = DummyRecallMemory
- archival_memory_cls = LocalArchivalMemory
+ archival_memory_cls = LocalArchivalMemory
def __init__(self, archival_memory_db=None):
# Memory held in-state useful for debugging stateful versions
@@ -108,17 +116,17 @@ class LocalStateManager(PersistenceManager):
@staticmethod
def load(filename):
- with open(filename, 'rb') as f:
+ with open(filename, "rb") as f:
return pickle.load(f)
def save(self, filename):
- with open(filename, 'wb') as fh:
+ with open(filename, "wb") as fh:
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
def init(self, agent):
printd(f"Initializing InMemoryStateManager with agent object")
- self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
- self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+ self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
+ self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")
@@ -126,7 +134,7 @@ class LocalStateManager(PersistenceManager):
# Persistence manager also handles DB-related state
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
- # TODO: init archival memory here?
+ # TODO: init archival memory here?
def trim_messages(self, num):
# printd(f"InMemoryStateManager.trim_messages")
@@ -134,7 +142,7 @@ class LocalStateManager(PersistenceManager):
def prepend_to_messages(self, added_messages):
# first tag with timestamps
- added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
+ added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"InMemoryStateManager.prepend_to_message")
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
@@ -142,7 +150,7 @@ class LocalStateManager(PersistenceManager):
def append_to_messages(self, added_messages):
# first tag with timestamps
- added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
+ added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"InMemoryStateManager.append_to_messages")
self.messages = self.messages + added_messages
@@ -150,7 +158,7 @@ class LocalStateManager(PersistenceManager):
def swap_system_message(self, new_system_message):
# first tag with timestamps
- new_system_message = {'timestamp': get_local_time(), 'message': new_system_message}
+ new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
printd(f"InMemoryStateManager.swap_system_message")
self.messages[0] = new_system_message
@@ -170,8 +178,8 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
- self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
- self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+ self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
+ self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
@@ -199,13 +207,14 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager):
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
- self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
- self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+ self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
+ self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
# Persistence manager also handles DB-related state
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
- self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k)
-
+ self.archival_memory = self.archival_memory_cls(
+ index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k
+ )
diff --git a/memgpt/personas/examples/docqa/build_index.py b/memgpt/personas/examples/docqa/build_index.py
index 2dd94708..94802395 100644
--- a/memgpt/personas/examples/docqa/build_index.py
+++ b/memgpt/personas/examples/docqa/build_index.py
@@ -5,15 +5,14 @@ import numpy as np
import argparse
import json
-def build_index(embedding_files: str,
- index_name: str):
+def build_index(embedding_files: str, index_name: str):
index = faiss.IndexFlatL2(1536)
file_list = sorted(glob(embedding_files))
for embedding_file in file_list:
print(embedding_file)
- with open(embedding_file, 'rt', encoding='utf-8') as file:
+ with open(embedding_file, "rt", encoding="utf-8") as file:
embeddings = []
l = 0
for line in tqdm(file):
@@ -21,7 +20,7 @@ def build_index(embedding_files: str,
data = json.loads(line)
embeddings.append(data)
l += 1
- data = np.array(embeddings).astype('float32')
+ data = np.array(embeddings).astype("float32")
print(data.shape)
try:
index.add(data)
@@ -32,14 +31,11 @@ def build_index(embedding_files: str,
faiss.write_index(index, index_name)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--embedding_files', type=str, help='embedding_filepaths glob expression')
- parser.add_argument('--output_index_file', type=str, help='output filepath')
- args = parser.parse_args()
+ parser.add_argument("--embedding_files", type=str, help="embedding_filepaths glob expression")
+ parser.add_argument("--output_index_file", type=str, help="output filepath")
+ args = parser.parse_args()
- build_index(
- embedding_files=args.embedding_files,
- index_name=args.output_index_file
- )
\ No newline at end of file
+ build_index(embedding_files=args.embedding_files, index_name=args.output_index_file)
diff --git a/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py b/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py
index f377ce27..0c2d9479 100644
--- a/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py
+++ b/memgpt/personas/examples/docqa/generate_embeddings_for_docs.py
@@ -7,12 +7,14 @@ import argparse
from tqdm import tqdm
import openai
+
try:
from dotenv import load_dotenv
+
load_dotenv()
except ModuleNotFoundError:
pass
-openai.api_key = os.getenv('OPENAI_API_KEY')
+openai.api_key = os.getenv("OPENAI_API_KEY")
sys.path.append("../../../")
from openai_tools import async_get_embedding_with_backoff
@@ -24,8 +26,8 @@ from openai_parallel_request_processor import process_api_requests_from_file
TPM_LIMIT = 1000000
RPM_LIMIT = 3000
-DEFAULT_FILE = 'iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz'
-EMBEDDING_MODEL = 'text-embedding-ada-002'
+DEFAULT_FILE = "iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz"
+EMBEDDING_MODEL = "text-embedding-ada-002"
async def generate_requests_file(filename):
@@ -33,36 +35,33 @@ async def generate_requests_file(filename):
base_name = os.path.splitext(filename)[0]
requests_filename = f"{base_name}_embedding_requests.jsonl"
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
all_data = [json.loads(line) for line in f]
- with open(requests_filename, 'w') as f:
+ with open(requests_filename, "w") as f:
for data in all_data:
documents = data
for idx, doc in enumerate(documents):
title = doc["title"]
text = doc["text"]
document_string = f"Document [{idx+1}] (Title: {title}) {text}"
- request = {
- "model": EMBEDDING_MODEL,
- "input": document_string
- }
+ request = {"model": EMBEDDING_MODEL, "input": document_string}
json_string = json.dumps(request)
f.write(json_string + "\n")
# Run your parallel processing function
input(f"Generated requests file ({requests_filename}), continue with embedding batch requests? (hit enter)")
await process_api_requests_from_file(
- requests_filepath=requests_filename,
- save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary
- request_url="https://api.openai.com/v1/embeddings",
- api_key=os.getenv('OPENAI_API_KEY'),
- max_requests_per_minute=RPM_LIMIT,
- max_tokens_per_minute=TPM_LIMIT,
- token_encoding_name=EMBEDDING_MODEL,
- max_attempts=5,
- logging_level=logging.INFO,
- )
+ requests_filepath=requests_filename,
+ save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary
+ request_url="https://api.openai.com/v1/embeddings",
+ api_key=os.getenv("OPENAI_API_KEY"),
+ max_requests_per_minute=RPM_LIMIT,
+ max_tokens_per_minute=TPM_LIMIT,
+ token_encoding_name=EMBEDDING_MODEL,
+ max_attempts=5,
+ logging_level=logging.INFO,
+ )
async def generate_embedding_file(filename, parallel_mode=False):
@@ -72,7 +71,7 @@ async def generate_embedding_file(filename, parallel_mode=False):
# Derive the sister filename
# base_name = os.path.splitext(filename)[0]
- base_name = filename.rsplit('.jsonl', 1)[0]
+ base_name = filename.rsplit(".jsonl", 1)[0]
sister_filename = f"{base_name}.embeddings.jsonl"
# Check if the sister file already exists
@@ -80,7 +79,7 @@ async def generate_embedding_file(filename, parallel_mode=False):
print(f"{sister_filename} already exists. Skipping embedding generation.")
return
- with open(filename, 'rt') as f:
+ with open(filename, "rt") as f:
all_data = [json.loads(line) for line in f]
embedding_data = []
@@ -90,7 +89,9 @@ async def generate_embedding_file(filename, parallel_mode=False):
for i, data in enumerate(tqdm(all_data, desc="Processing data", total=len(all_data))):
documents = data
# Inner loop progress bar
- for idx, doc in enumerate(tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)):
+ for idx, doc in enumerate(
+ tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)
+ ):
title = doc["title"]
text = doc["text"]
document_string = f"[Title: {title}] {text}"
@@ -103,10 +104,10 @@ async def generate_embedding_file(filename, parallel_mode=False):
# Save the embeddings to the sister file
# with gzip.open(sister_filename, 'wt') as f:
- with open(sister_filename, 'wb') as f:
+ with open(sister_filename, "wb") as f:
for embedding in embedding_data:
# f.write(json.dumps(embedding) + '\n')
- f.write((json.dumps(embedding) + '\n').encode('utf-8'))
+ f.write((json.dumps(embedding) + "\n").encode("utf-8"))
print(f"Embeddings saved to {sister_filename}")
@@ -118,6 +119,7 @@ async def main():
filename = DEFAULT_FILE
await generate_embedding_file(filename)
+
async def main():
parser = argparse.ArgumentParser()
parser.add_argument("filename", nargs="?", default=DEFAULT_FILE, help="Path to the input file")
@@ -129,4 +131,4 @@ async def main():
if __name__ == "__main__":
loop = asyncio.get_event_loop()
- loop.run_until_complete(main())
\ No newline at end of file
+ loop.run_until_complete(main())
diff --git a/memgpt/personas/examples/docqa/openai_parallel_request_processor.py b/memgpt/personas/examples/docqa/openai_parallel_request_processor.py
index 4b9a1aae..169bfd37 100644
--- a/memgpt/personas/examples/docqa/openai_parallel_request_processor.py
+++ b/memgpt/personas/examples/docqa/openai_parallel_request_processor.py
@@ -121,9 +121,7 @@ async def process_api_requests_from_file(
"""Processes API requests in parallel, throttling to stay under rate limits."""
# constants
seconds_to_pause_after_rate_limit_error = 15
- seconds_to_sleep_each_loop = (
- 0.001 # 1 ms limits max throughput to 1,000 requests per second
- )
+ seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second
# initialize logging
logging.basicConfig(level=logging_level)
@@ -135,12 +133,8 @@ async def process_api_requests_from_file(
# initialize trackers
queue_of_requests_to_retry = asyncio.Queue()
- task_id_generator = (
- task_id_generator_function()
- ) # generates integer IDs of 1, 2, 3, ...
- status_tracker = (
- StatusTracker()
- ) # single instance to track a collection of variables
+ task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ...
+ status_tracker = StatusTracker() # single instance to track a collection of variables
next_request = None # variable to hold the next request to call
# initialize available capacity counts
@@ -163,9 +157,7 @@ async def process_api_requests_from_file(
if next_request is None:
if not queue_of_requests_to_retry.empty():
next_request = queue_of_requests_to_retry.get_nowait()
- logging.debug(
- f"Retrying request {next_request.task_id}: {next_request}"
- )
+ logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
elif file_not_finished:
try:
# get new request
@@ -173,17 +165,13 @@ async def process_api_requests_from_file(
next_request = APIRequest(
task_id=next(task_id_generator),
request_json=request_json,
- token_consumption=num_tokens_consumed_from_request(
- request_json, api_endpoint, token_encoding_name
- ),
+ token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name),
attempts_left=max_attempts,
metadata=request_json.pop("metadata", None),
)
status_tracker.num_tasks_started += 1
status_tracker.num_tasks_in_progress += 1
- logging.debug(
- f"Reading request {next_request.task_id}: {next_request}"
- )
+ logging.debug(f"Reading request {next_request.task_id}: {next_request}")
except StopIteration:
# if file runs out, set flag to stop reading it
logging.debug("Read file exhausted")
@@ -193,13 +181,11 @@ async def process_api_requests_from_file(
current_time = time.time()
seconds_since_update = current_time - last_update_time
available_request_capacity = min(
- available_request_capacity
- + max_requests_per_minute * seconds_since_update / 60.0,
+ available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0,
max_requests_per_minute,
)
available_token_capacity = min(
- available_token_capacity
- + max_tokens_per_minute * seconds_since_update / 60.0,
+ available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0,
max_tokens_per_minute,
)
last_update_time = current_time
@@ -207,10 +193,7 @@ async def process_api_requests_from_file(
# if enough capacity available, call API
if next_request:
next_request_tokens = next_request.token_consumption
- if (
- available_request_capacity >= 1
- and available_token_capacity >= next_request_tokens
- ):
+ if available_request_capacity >= 1 and available_token_capacity >= next_request_tokens:
# update counters
available_request_capacity -= 1
available_token_capacity -= next_request_tokens
@@ -237,17 +220,9 @@ async def process_api_requests_from_file(
await asyncio.sleep(seconds_to_sleep_each_loop)
# if a rate limit error was hit recently, pause to cool down
- seconds_since_rate_limit_error = (
- time.time() - status_tracker.time_of_last_rate_limit_error
- )
- if (
- seconds_since_rate_limit_error
- < seconds_to_pause_after_rate_limit_error
- ):
- remaining_seconds_to_pause = (
- seconds_to_pause_after_rate_limit_error
- - seconds_since_rate_limit_error
- )
+ seconds_since_rate_limit_error = time.time() - status_tracker.time_of_last_rate_limit_error
+ if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
+ remaining_seconds_to_pause = seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
await asyncio.sleep(remaining_seconds_to_pause)
# ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
logging.warn(
@@ -255,17 +230,13 @@ async def process_api_requests_from_file(
)
# after finishing, log final status
- logging.info(
- f"""Parallel processing complete. Results saved to {save_filepath}"""
- )
+ logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""")
if status_tracker.num_tasks_failed > 0:
logging.warning(
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
)
if status_tracker.num_rate_limit_errors > 0:
- logging.warning(
- f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
- )
+ logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.")
# dataclasses
@@ -309,26 +280,18 @@ class APIRequest:
logging.info(f"Starting request #{self.task_id}")
error = None
try:
- async with session.post(
- url=request_url, headers=request_header, json=self.request_json
- ) as response:
+ async with session.post(url=request_url, headers=request_header, json=self.request_json) as response:
response = await response.json()
if "error" in response:
- logging.warning(
- f"Request {self.task_id} failed with error {response['error']}"
- )
+ logging.warning(f"Request {self.task_id} failed with error {response['error']}")
status_tracker.num_api_errors += 1
error = response
if "Rate limit" in response["error"].get("message", ""):
status_tracker.time_of_last_rate_limit_error = time.time()
status_tracker.num_rate_limit_errors += 1
- status_tracker.num_api_errors -= (
- 1 # rate limit errors are counted separately
- )
+ status_tracker.num_api_errors -= 1 # rate limit errors are counted separately
- except (
- Exception
- ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
+ except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
logging.warning(f"Request {self.task_id} failed with Exception {e}")
status_tracker.num_other_errors += 1
error = e
@@ -337,9 +300,7 @@ class APIRequest:
if self.attempts_left:
retry_queue.put_nowait(self)
else:
- logging.error(
- f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
- )
+ logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}")
data = (
[self.request_json, [str(e) for e in self.result], self.metadata]
if self.metadata
@@ -349,11 +310,7 @@ class APIRequest:
status_tracker.num_tasks_in_progress -= 1
status_tracker.num_tasks_failed += 1
else:
- data = (
- [self.request_json, response, self.metadata]
- if self.metadata
- else [self.request_json, response]
- )
+ data = [self.request_json, response, self.metadata] if self.metadata else [self.request_json, response]
append_to_jsonl(data, save_filepath)
status_tracker.num_tasks_in_progress -= 1
status_tracker.num_tasks_succeeded += 1
@@ -382,8 +339,8 @@ def num_tokens_consumed_from_request(
token_encoding_name: str,
):
"""Count the number of tokens in the request. Only supports completion and embedding requests."""
- if token_encoding_name == 'text-embedding-ada-002':
- encoding = tiktoken.get_encoding('cl100k_base')
+ if token_encoding_name == "text-embedding-ada-002":
+ encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding(token_encoding_name)
# if completions request, tokens = prompt + n * max_tokens
@@ -415,9 +372,7 @@ def num_tokens_consumed_from_request(
num_tokens = prompt_tokens + completion_tokens * len(prompt)
return num_tokens
else:
- raise TypeError(
- 'Expecting either string or list of strings for "prompt" field in completion request'
- )
+ raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
# if embeddings request, tokens = input tokens
elif api_endpoint == "embeddings":
input = request_json["input"]
@@ -428,14 +383,10 @@ def num_tokens_consumed_from_request(
num_tokens = sum([len(encoding.encode(i)) for i in input])
return num_tokens
else:
- raise TypeError(
- 'Expecting either string or list of strings for "inputs" field in embedding request'
- )
+ raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request')
# more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
else:
- raise NotImplementedError(
- f'API endpoint "{api_endpoint}" not implemented in this script'
- )
+ raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script')
def task_id_generator_function():
@@ -502,4 +453,4 @@ with open(filename, "w") as f:
```
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
-"""
\ No newline at end of file
+"""
diff --git a/memgpt/personas/examples/docqa/scrape_docs.py b/memgpt/personas/examples/docqa/scrape_docs.py
index 66682694..f02df414 100644
--- a/memgpt/personas/examples/docqa/scrape_docs.py
+++ b/memgpt/personas/examples/docqa/scrape_docs.py
@@ -4,69 +4,65 @@ import tiktoken
import json
# Define the directory where the documentation resides
-docs_dir = 'text'
+docs_dir = "text"
encoding = tiktoken.encoding_for_model("gpt-4")
PASSAGE_TOKEN_LEN = 800
+
def extract_text_from_sphinx_txt(file_path):
lines = []
title = ""
- with open(file_path, 'r', encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if not title:
title = line.strip()
continue
- if line and re.match(r'^.*\S.*$', line) and not re.match(r'^[-=*]+$', line):
+ if line and re.match(r"^.*\S.*$", line) and not re.match(r"^[-=*]+$", line):
lines.append(line)
- passages = []
+ passages = []
curr_passage = []
curr_token_ct = 0
for line in lines:
try:
- line_token_ct = len(encoding.encode(line, allowed_special={'<|endoftext|>'}))
+ line_token_ct = len(encoding.encode(line, allowed_special={"<|endoftext|>"}))
except Exception as e:
print("line", line)
raise e
if line_token_ct > PASSAGE_TOKEN_LEN:
- passages.append({
- 'title': title,
- 'text': line[:3200],
- 'num_tokens': curr_token_ct,
- })
+ passages.append(
+ {
+ "title": title,
+ "text": line[:3200],
+ "num_tokens": curr_token_ct,
+ }
+ )
continue
curr_token_ct += line_token_ct
curr_passage.append(line)
if curr_token_ct > PASSAGE_TOKEN_LEN:
- passages.append({
- 'title': title,
- 'text': ''.join(curr_passage),
- 'num_tokens': curr_token_ct
- })
+ passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
curr_passage = []
curr_token_ct = 0
if len(curr_passage) > 0:
- passages.append({
- 'title': title,
- 'text': ''.join(curr_passage),
- 'num_tokens': curr_token_ct
- })
+ passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
return passages
+
# Iterate over all files in the directory and its subdirectories
passages = []
total_files = 0
for subdir, _, files in os.walk(docs_dir):
for file in files:
- if file.endswith('.txt'):
+ if file.endswith(".txt"):
file_path = os.path.join(subdir, file)
passages.append(extract_text_from_sphinx_txt(file_path))
total_files += 1
print("total .txt files:", total_files)
# Save to a new text file or process as needed
-with open('all_docs.jsonl', 'w', encoding='utf-8') as file:
+with open("all_docs.jsonl", "w", encoding="utf-8") as file:
for p in passages:
file.write(json.dumps(p))
- file.write('\n')
+ file.write("\n")
diff --git a/memgpt/presets.py b/memgpt/presets.py
index 4fad1ed8..76ff8fae 100644
--- a/memgpt/presets.py
+++ b/memgpt/presets.py
@@ -1,30 +1,33 @@
-
from .prompts import gpt_functions
from .prompts import gpt_system
from .agent import AgentAsync
from .utils import printd
-DEFAULT = 'memgpt_chat'
+DEFAULT = "memgpt_chat"
+
def use_preset(preset_name, model, persona, human, interface, persistence_manager):
"""Storing combinations of SYSTEM + FUNCTION prompts"""
- if preset_name == 'memgpt_chat':
-
+ if preset_name == "memgpt_chat":
functions = [
- 'send_message', 'pause_heartbeats',
- 'core_memory_append', 'core_memory_replace',
- 'conversation_search', 'conversation_search_date',
- 'archival_memory_insert', 'archival_memory_search',
+ "send_message",
+ "pause_heartbeats",
+ "core_memory_append",
+ "core_memory_replace",
+ "conversation_search",
+ "conversation_search_date",
+ "archival_memory_insert",
+ "archival_memory_search",
]
- available_functions = [v for k,v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
- printd(f"Available functions:\n", [x['name'] for x in available_functions])
+ available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
+ printd(f"Available functions:\n", [x["name"] for x in available_functions])
assert len(functions) == len(available_functions)
- if 'gpt-3.5' in model:
+ if "gpt-3.5" in model:
# use a different system message for gpt-3.5
- preset_name = 'memgpt_gpt35_extralong'
+ preset_name = "memgpt_gpt35_extralong"
return AgentAsync(
model=model,
@@ -35,8 +38,8 @@ def use_preset(preset_name, model, persona, human, interface, persistence_manage
persona_notes=persona,
human_notes=human,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
- first_message_verify_mono=True if 'gpt-4' in model else False,
+ first_message_verify_mono=True if "gpt-4" in model else False,
)
else:
- raise ValueError(preset_name)
\ No newline at end of file
+ raise ValueError(preset_name)
diff --git a/memgpt/prompts/gpt_functions.py b/memgpt/prompts/gpt_functions.py
index a32a545e..060b50c7 100644
--- a/memgpt/prompts/gpt_functions.py
+++ b/memgpt/prompts/gpt_functions.py
@@ -2,9 +2,7 @@ from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
# FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1]
FUNCTIONS_CHAINING = {
-
- 'send_message':
- {
+ "send_message": {
"name": "send_message",
"description": "Sends a message to the human user",
"parameters": {
@@ -17,11 +15,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["message"],
- }
+ },
},
-
- 'pause_heartbeats':
- {
+ "pause_heartbeats": {
"name": "pause_heartbeats",
"description": "Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.",
"parameters": {
@@ -34,11 +30,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["minutes"],
- }
+ },
},
-
- 'message_chatgpt':
- {
+ "message_chatgpt": {
"name": "message_chatgpt",
"description": "Send a message to a more basic AI, ChatGPT. A useful resource for asking questions. ChatGPT does not retain memory of previous interactions.",
"parameters": {
@@ -55,11 +49,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["message", "request_heartbeat"],
- }
+ },
},
-
- 'core_memory_append':
- {
+ "core_memory_append": {
"name": "core_memory_append",
"description": "Append to the contents of core memory.",
"parameters": {
@@ -79,11 +71,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "content", "request_heartbeat"],
- }
+ },
},
-
- 'core_memory_replace':
- {
+ "core_memory_replace": {
"name": "core_memory_replace",
"description": "Replace to the contents of core memory. To delete memories, use an empty string for new_content.",
"parameters": {
@@ -107,11 +97,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "old_content", "new_content", "request_heartbeat"],
- }
+ },
},
-
- 'recall_memory_search':
- {
+ "recall_memory_search": {
"name": "recall_memory_search",
"description": "Search prior conversation history using a string.",
"parameters": {
@@ -131,11 +119,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
- }
+ },
},
-
- 'conversation_search':
- {
+ "conversation_search": {
"name": "conversation_search",
"description": "Search prior conversation history using case-insensitive string matching.",
"parameters": {
@@ -155,11 +141,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
- }
+ },
},
-
- 'recall_memory_search_date':
- {
+ "recall_memory_search_date": {
"name": "recall_memory_search_date",
"description": "Search prior conversation history using a date range.",
"parameters": {
@@ -183,11 +167,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
- }
+ },
},
-
- 'conversation_search_date':
- {
+ "conversation_search_date": {
"name": "conversation_search_date",
"description": "Search prior conversation history using a date range.",
"parameters": {
@@ -211,11 +193,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
- }
+ },
},
-
- 'archival_memory_insert':
- {
+ "archival_memory_insert": {
"name": "archival_memory_insert",
"description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.",
"parameters": {
@@ -231,11 +211,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "content", "request_heartbeat"],
- }
+ },
},
-
- 'archival_memory_search':
- {
+ "archival_memory_search": {
"name": "archival_memory_search",
"description": "Search archival memory using semantic (embedding-based) search.",
"parameters": {
@@ -255,7 +233,6 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "query", "page", "request_heartbeat"],
- }
+ },
},
-
-}
\ No newline at end of file
+}
diff --git a/memgpt/prompts/gpt_summarize.py b/memgpt/prompts/gpt_summarize.py
index 619dbf83..95c0e199 100644
--- a/memgpt/prompts/gpt_summarize.py
+++ b/memgpt/prompts/gpt_summarize.py
@@ -1,6 +1,5 @@
WORD_LIMIT = 100
-SYSTEM = \
-f"""
+SYSTEM = f"""
Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
The conversation you are given is a from a fixed context window and may not be complete.
Messages sent by the AI are marked with the 'assistant' role.
@@ -12,4 +11,4 @@ The 'user' role is also used for important system events, such as login events a
Summarize what happened in the conversation from the perspective of the AI (use the first person).
Keep your summary less than {WORD_LIMIT} words, do NOT exceed this word limit.
Only output the summary, do NOT include anything else in your output.
-"""
\ No newline at end of file
+"""
diff --git a/memgpt/prompts/gpt_system.py b/memgpt/prompts/gpt_system.py
index 2ee8edec..8100b6ee 100644
--- a/memgpt/prompts/gpt_system.py
+++ b/memgpt/prompts/gpt_system.py
@@ -2,11 +2,11 @@ import os
def get_system_text(key):
- filename = f'{key}.txt'
- file_path = os.path.join(os.path.dirname(__file__), 'system', filename)
+ filename = f"{key}.txt"
+ file_path = os.path.join(os.path.dirname(__file__), "system", filename)
if os.path.exists(file_path):
- with open(file_path, 'r') as file:
+ with open(file_path, "r") as file:
return file.read().strip()
else:
raise FileNotFoundError(f"No file found for key {key}, path={file_path}")
diff --git a/memgpt/system.py b/memgpt/system.py
index 116090a5..5993a007 100644
--- a/memgpt/system.py
+++ b/memgpt/system.py
@@ -1,18 +1,22 @@
import json
from .utils import get_local_time
-from .constants import INITIAL_BOOT_MESSAGE, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG, MESSAGE_SUMMARY_WARNING_STR
+from .constants import (
+ INITIAL_BOOT_MESSAGE,
+ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
+ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG,
+ MESSAGE_SUMMARY_WARNING_STR,
+)
-def get_initial_boot_messages(version='startup'):
-
- if version == 'startup':
+def get_initial_boot_messages(version="startup"):
+ if version == "startup":
initial_boot_message = INITIAL_BOOT_MESSAGE
messages = [
{"role": "assistant", "content": initial_boot_message},
]
- elif version == 'startup_with_send_message':
+ elif version == "startup_with_send_message":
messages = [
# first message includes both inner monologue and function call to send_message
{
@@ -20,34 +24,23 @@ def get_initial_boot_messages(version='startup'):
"content": INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
"function_call": {
"name": "send_message",
- "arguments": "{\n \"message\": \"" + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + "\"\n}"
- }
+ "arguments": '{\n "message": "' + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + '"\n}',
+ },
},
# obligatory function return message
- {
- "role": "function",
- "name": "send_message",
- "content": package_function_response(True, None)
- }
+ {"role": "function", "name": "send_message", "content": package_function_response(True, None)},
]
- elif version == 'startup_with_send_message_gpt35':
+ elif version == "startup_with_send_message_gpt35":
messages = [
# first message includes both inner monologue and function call to send_message
{
"role": "assistant",
"content": "*inner thoughts* Still waiting on the user. Sending a message with function.",
- "function_call": {
- "name": "send_message",
- "arguments": "{\n \"message\": \"" + f"Hi, is anyone there?" + "\"\n}"
- }
+ "function_call": {"name": "send_message", "arguments": '{\n "message": "' + f"Hi, is anyone there?" + '"\n}'},
},
# obligatory function return message
- {
- "role": "function",
- "name": "send_message",
- "content": package_function_response(True, None)
- }
+ {"role": "function", "name": "send_message", "content": package_function_response(True, None)},
]
else:
@@ -56,12 +49,11 @@ def get_initial_boot_messages(version='startup'):
return messages
-def get_heartbeat(reason='Automated timer', include_location=False, location_name='San Francisco, CA, USA'):
-
+def get_heartbeat(reason="Automated timer", include_location=False, location_name="San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = get_local_time()
packaged_message = {
- "type": 'heartbeat',
+ "type": "heartbeat",
"reason": reason,
"time": formatted_time,
}
@@ -72,12 +64,11 @@ def get_heartbeat(reason='Automated timer', include_location=False, location_nam
return json.dumps(packaged_message)
-def get_login_event(last_login='Never (first login)', include_location=False, location_name='San Francisco, CA, USA'):
-
+def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = get_local_time()
packaged_message = {
- "type": 'login',
+ "type": "login",
"last_login": last_login,
"time": formatted_time,
}
@@ -88,12 +79,11 @@ def get_login_event(last_login='Never (first login)', include_location=False, lo
return json.dumps(packaged_message)
-def package_user_message(user_message, time=None, include_location=False, location_name='San Francisco, CA, USA'):
-
+def package_user_message(user_message, time=None, include_location=False, location_name="San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = time if time else get_local_time()
packaged_message = {
- "type": 'user_message',
+ "type": "user_message",
"message": user_message,
"time": formatted_time,
}
@@ -103,11 +93,11 @@ def package_user_message(user_message, time=None, include_location=False, locati
return json.dumps(packaged_message)
-def package_function_response(was_success, response_string, timestamp=None):
+def package_function_response(was_success, response_string, timestamp=None):
formatted_time = get_local_time() if timestamp is None else timestamp
packaged_message = {
- "status": 'OK' if was_success else 'Failed',
+ "status": "OK" if was_success else "Failed",
"message": response_string,
"time": formatted_time,
}
@@ -116,14 +106,14 @@ def package_function_response(was_success, response_string, timestamp=None):
def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
-
- context_message = \
- f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n" \
+ context_message = (
+ f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"
+ f"The following is a summary of the previous {summary_length} messages:\n {summary}"
+ )
formatted_time = get_local_time() if timestamp is None else timestamp
packaged_message = {
- "type": 'system_alert',
+ "type": "system_alert",
"message": context_message,
"time": formatted_time,
}
@@ -136,10 +126,13 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
# Package the message with time and location
formatted_time = get_local_time() if timestamp is None else timestamp
- context_message = message if message else \
- f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions."
+ context_message = (
+ message
+ if message
+ else f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions."
+ )
packaged_message = {
- "type": 'system_alert',
+ "type": "system_alert",
"message": context_message,
"time": formatted_time,
}
@@ -148,12 +141,11 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
def get_token_limit_warning():
-
formatted_time = get_local_time()
packaged_message = {
- "type": 'system_alert',
+ "type": "system_alert",
"message": MESSAGE_SUMMARY_WARNING_STR,
"time": formatted_time,
}
- return json.dumps(packaged_message)
\ No newline at end of file
+ return json.dumps(packaged_message)
diff --git a/memgpt/utils.py b/memgpt/utils.py
index 560f544f..f9fa614a 100644
--- a/memgpt/utils.py
+++ b/memgpt/utils.py
@@ -19,6 +19,7 @@ from memgpt.constants import MEMGPT_DIR
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
from llama_index.embeddings import OpenAIEmbedding
+
def count_tokens(s: str, model: str = "gpt-4") -> int:
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(s))
@@ -169,9 +170,7 @@ def chunk_file(file, tkns_per_chunk=300, model="gpt-4"):
line_token_ct = len(encoding.encode(line))
except Exception as e:
line_token_ct = len(line.split(" ")) / 0.75
- print(
- f"Could not encode line {i}, estimating it to be {line_token_ct} tokens"
- )
+ print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
print(e)
if line_token_ct > tkns_per_chunk:
if len(curr_chunk) > 0:
@@ -195,9 +194,7 @@ def chunk_files(files, tkns_per_chunk=300, model="gpt-4"):
archival_database = []
for file in files:
timestamp = os.path.getmtime(file)
- formatted_time = datetime.fromtimestamp(timestamp).strftime(
- "%Y-%m-%d %I:%M:%S %p %Z%z"
- )
+ formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
file_stem = file.split("/")[-1]
chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
for i, chunk in enumerate(chunks):
@@ -244,9 +241,7 @@ async def process_concurrently(archival_database, model, concurrency=10):
# Create a list of tasks for chunks
embedding_data = [0 for _ in archival_database]
- tasks = [
- bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)
- ]
+ tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]
for future in tqdm(
asyncio.as_completed(tasks),
@@ -268,15 +263,12 @@ async def prepare_archival_index_from_files_compute_embeddings(
files = sorted(glob.glob(glob_pattern))
save_dir = os.path.join(
MEMGPT_DIR,
- "archival_index_from_files_"
- + get_local_time().replace(" ", "_").replace(":", "_"),
+ "archival_index_from_files_" + get_local_time().replace(" ", "_").replace(":", "_"),
)
os.makedirs(save_dir, exist_ok=True)
total_tokens = total_bytes(glob_pattern) / 3
price_estimate = total_tokens / 1000 * 0.0001
- confirm = input(
- f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] "
- )
+ confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
if confirm != "y":
raise Exception("embeddings were not computed")
@@ -292,9 +284,7 @@ async def prepare_archival_index_from_files_compute_embeddings(
archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
with open(archival_storage_file, "w") as f:
- print(
- f"Saving archival storage with preloaded files to {archival_storage_file}"
- )
+ print(f"Saving archival storage with preloaded files to {archival_storage_file}")
for c in chunks_by_file:
json.dump(c, f)
f.write("\n")
@@ -341,9 +331,8 @@ def read_database_as_list(database_name):
return result_list
-
-def estimate_openai_cost(docs):
- """ Estimate OpenAI embedding cost
+def estimate_openai_cost(docs):
+ """Estimate OpenAI embedding cost
:param docs: Documents to be embedded
:type docs: List[Document]
@@ -356,18 +345,11 @@ def estimate_openai_cost(docs):
embed_model = MockEmbedding(embed_dim=1536)
- token_counter = TokenCountingHandler(
- tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
- )
+ token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
callback_manager = CallbackManager([token_counter])
- set_global_service_context(
- ServiceContext.from_defaults(
- embed_model=embed_model,
- callback_manager=callback_manager
- )
- )
+ set_global_service_context(ServiceContext.from_defaults(embed_model=embed_model, callback_manager=callback_manager))
index = VectorStoreIndex.from_documents(docs)
# estimate cost
@@ -377,8 +359,7 @@ def estimate_openai_cost(docs):
def get_index(name, docs):
-
- """ Index documents
+ """Index documents
:param docs: Documents to be embedded
:type docs: List[Document]
@@ -398,38 +379,40 @@ def get_index(name, docs):
estimated_cost = estimate_openai_cost(docs)
# TODO: prettier cost formatting
- confirm = typer.confirm(typer.style(f"Open AI embedding cost will be approximately ${estimated_cost} - continue?", fg="yellow"), default=True)
+ confirm = typer.confirm(
+ typer.style(f"Open AI embedding cost will be approximately ${estimated_cost} - continue?", fg="yellow"), default=True
+ )
if not confirm:
typer.secho("Aborting.", fg="red")
exit()
-
+
embed_model = OpenAIEmbedding()
- service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size = 300)
+ service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=300)
set_global_service_context(service_context)
# index documents
index = VectorStoreIndex.from_documents(docs)
return index
-def save_index(index, name):
- """ Save index to a specificed name in ~/.memgpt
+def save_index(index, name):
+ """Save index to a specificed name in ~/.memgpt
:param index: Index to save
:type index: VectorStoreIndex
:param name: Name of index
:type name: str
"""
- # save
- # TODO: load directory from config
+ # save
+ # TODO: load directory from config
# TODO: save to vectordb/local depending on config
dir = f"{MEMGPT_DIR}/archival/{name}"
## Avoid overwriting
## check if directory exists
- #if os.path.exists(dir):
+ # if os.path.exists(dir):
# confirm = typer.confirm(typer.style(f"Index with name {name} already exists -- overwrite?", fg="red"), default=False)
# if not confirm:
# typer.secho("Aborting.", fg="red")
diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py
index dc857372..95bec5ce 100644
--- a/tests/test_load_archival.py
+++ b/tests/test_load_archival.py
@@ -9,10 +9,7 @@ import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
-from memgpt.persistence_manager import (
- InMemoryStateManager,
- LocalStateManager
-)
+from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
from memgpt.connectors import connector
@@ -20,6 +17,7 @@ import memgpt.interface # for printing to terminal
import asyncio
from datasets import load_dataset
+
def test_load_directory():
# downloading hugging face dataset (if does not exist)
dataset = load_dataset("MemGPT/example_short_stories")
@@ -30,12 +28,12 @@ def test_load_directory():
# Construct the default path if the environment variable is not set.
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
- # load directory
+ # load directory
print("Loading dataset into index...")
print(cache_dir)
load_directory(
name="tmp_hf_dataset",
- input_dir=cache_dir,
+ input_dir=cache_dir,
recursive=True,
)
@@ -51,23 +49,25 @@ def test_load_directory():
memgpt.interface,
persistence_manager,
)
- def query(q):
+
+ def query(q):
res = asyncio.run(memgpt_agent.archival_memory_search(q))
return res
results = query("cinderella be getting sick")
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
-def test_load_webpage():
+
+def test_load_webpage():
pass
-def test_load_database():
+def test_load_database():
from sqlalchemy import create_engine, MetaData
import pandas as pd
db_path = "memgpt/personas/examples/sqldb/test.db"
- engine = create_engine(f'sqlite:///{db_path}')
+ engine = create_engine(f"sqlite:///{db_path}")
# Create a MetaData object and reflect the database to get table information.
metadata = MetaData()
@@ -87,7 +87,7 @@ def test_load_database():
load_database(
name="tmp_db_dataset",
- #engine=engine,
+ # engine=engine,
dump_path=db_path,
query=f"SELECT * FROM {list(table_names)[0]}",
)
@@ -107,7 +107,5 @@ def test_load_database():
assert True
-
-
-#test_load_directory()
-test_load_database()
\ No newline at end of file
+# test_load_directory()
+test_load_database()