add black to poetry and reformat

This commit is contained in:
Sarah Wooders
2023-10-26 15:33:50 -07:00
parent 86bfafdf31
commit 5c44790ad0
26 changed files with 668 additions and 766 deletions

View File

@@ -1,2 +1,3 @@
from .main import app
app()

View File

@@ -11,10 +11,15 @@ from .system import get_heartbeat, get_login_event, package_function_response, p
from .memory import CoreMemory as Memory, summarize_messages
from .openai_tools import acompletions_with_backoff as acreate
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
from .constants import \
FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, \
MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_TOKENS, \
CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
from .constants import (
FIRST_MESSAGE_ATTEMPTS,
MAX_PAUSE_HEARTBEATS,
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
MESSAGE_SUMMARY_WARNING_TOKENS,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
)
def initialize_memory(ai_notes, human_notes):
@@ -28,52 +33,57 @@ def initialize_memory(ai_notes, human_notes):
return memory
def construct_system_with_memory(
system, memory, memory_edit_timestamp,
archival_memory=None, recall_memory=None
):
full_system_message = "\n".join([
system,
"\n",
f"### Memory [last modified: {memory_edit_timestamp}",
f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
"<persona>",
memory.persona,
"</persona>",
"<human>",
memory.human,
"</human>",
])
def construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=None, recall_memory=None):
full_system_message = "\n".join(
[
system,
"\n",
f"### Memory [last modified: {memory_edit_timestamp}",
f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
"<persona>",
memory.persona,
"</persona>",
"<human>",
memory.human,
"</human>",
]
)
return full_system_message
def initialize_message_sequence(
model,
system,
memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=None,
include_initial_boot_message=True,
):
model,
system,
memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=None,
include_initial_boot_message=True,
):
if memory_edit_timestamp is None:
memory_edit_timestamp = get_local_time()
full_system_message = construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory)
full_system_message = construct_system_with_memory(
system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
)
first_user_message = get_login_event() # event letting MemGPT know the user just logged in
if include_initial_boot_message:
if 'gpt-3.5' in model:
initial_boot_messages = get_initial_boot_messages('startup_with_send_message_gpt35')
if "gpt-3.5" in model:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
else:
initial_boot_messages = get_initial_boot_messages('startup_with_send_message')
messages = [
{"role": "system", "content": full_system_message},
] + initial_boot_messages + [
{"role": "user", "content": first_user_message},
]
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
messages = (
[
{"role": "system", "content": full_system_message},
]
+ initial_boot_messages
+ [
{"role": "user", "content": first_user_message},
]
)
else:
messages = [
@@ -85,11 +95,11 @@ def initialize_message_sequence(
async def get_ai_reply_async(
model,
message_sequence,
functions,
function_call="auto",
):
model,
message_sequence,
functions,
function_call="auto",
):
"""Base call to GPT API w/ functions"""
try:
@@ -101,11 +111,11 @@ async def get_ai_reply_async(
)
# special case for 'length'
if response.choices[0].finish_reason == 'length':
raise Exception('Finish reason was length (maximum context length)')
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
# catches for soft errors
if response.choices[0].finish_reason not in ['stop', 'function_call']:
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")
# unpack with response.choices[0].message.content
@@ -118,7 +128,19 @@ async def get_ai_reply_async(
class AgentAsync(object):
"""Core logic for a MemGPT agent"""
def __init__(self, model, system, functions, interface, persistence_manager, persona_notes, human_notes, messages_total=None, persistence_manager_init=True, first_message_verify_mono=True):
def __init__(
self,
model,
system,
functions,
interface,
persistence_manager,
persona_notes,
human_notes,
messages_total=None,
persistence_manager_init=True,
first_message_verify_mono=True,
):
# gpt-4, gpt-3.5-turbo
self.model = model
# Store the system instructions (used to rebuild memory)
@@ -173,7 +195,7 @@ class AgentAsync(object):
@messages.setter
def messages(self, value):
raise Exception('Modifying message list directly not allowed')
raise Exception("Modifying message list directly not allowed")
def trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
@@ -196,16 +218,16 @@ class AgentAsync(object):
# strip extra metadata if it exists
for msg in added_messages:
msg.pop('api_response', None)
msg.pop('api_args', None)
msg.pop("api_response", None)
msg.pop("api_args", None)
new_messages = self.messages + added_messages # append
self._messages = new_messages
self.messages_total += len(added_messages)
def swap_system_message(self, new_system_message):
assert new_system_message['role'] == 'system', new_system_message
assert self.messages[0]['role'] == 'system', self.messages
assert new_system_message["role"] == "system", new_system_message
assert self.messages[0]["role"] == "system", self.messages
self.persistence_manager.swap_system_message(new_system_message)
@@ -223,7 +245,7 @@ class AgentAsync(object):
recall_memory=self.persistence_manager.recall_memory,
)[0]
diff = united_diff(curr_system_message['content'], new_system_message['content'])
diff = united_diff(curr_system_message["content"], new_system_message["content"])
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Store the memory change (if stateful)
@@ -235,32 +257,32 @@ class AgentAsync(object):
### Local state management
def to_dict(self):
return {
'model': self.model,
'system': self.system,
'functions': self.functions,
'messages': self.messages,
'messages_total': self.messages_total,
'memory': self.memory.to_dict(),
"model": self.model,
"system": self.system,
"functions": self.functions,
"messages": self.messages,
"messages_total": self.messages_total,
"memory": self.memory.to_dict(),
}
def save_to_json_file(self, filename):
with open(filename, 'w') as file:
with open(filename, "w") as file:
json.dump(self.to_dict(), file)
@classmethod
def load(cls, state, interface, persistence_manager):
model = state['model']
system = state['system']
functions = state['functions']
messages = state['messages']
model = state["model"]
system = state["system"]
functions = state["functions"]
messages = state["messages"]
try:
messages_total = state['messages_total']
messages_total = state["messages_total"]
except KeyError:
messages_total = len(messages) - 1
# memory requires a nested load
memory_dict = state['memory']
persona_notes = memory_dict['persona']
human_notes = memory_dict['human']
memory_dict = state["memory"]
persona_notes = memory_dict["persona"]
human_notes = memory_dict["human"]
# Two-part load
new_agent = cls(
@@ -278,18 +300,18 @@ class AgentAsync(object):
return new_agent
def load_inplace(self, state):
self.model = state['model']
self.system = state['system']
self.functions = state['functions']
self.model = state["model"]
self.system = state["system"]
self.functions = state["functions"]
# memory requires a nested load
memory_dict = state['memory']
persona_notes = memory_dict['persona']
human_notes = memory_dict['human']
memory_dict = state["memory"]
persona_notes = memory_dict["persona"]
human_notes = memory_dict["human"]
self.memory = initialize_memory(persona_notes, human_notes)
# messages also
self._messages = state['messages']
self._messages = state["messages"]
try:
self.messages_total = state['messages_total']
self.messages_total = state["messages_total"]
except KeyError:
self.messages_total = len(self.messages) - 1 # -system
@@ -300,14 +322,14 @@ class AgentAsync(object):
@classmethod
def load_from_json_file(cls, json_file, interface, persistence_manager):
with open(json_file, 'r') as file:
with open(json_file, "r") as file:
state = json.load(file)
return cls.load(state, interface, persistence_manager)
def load_from_json_file_inplace(self, json_file):
# Load in-place
# No interface arg needed, we can use the current one
with open(json_file, 'r') as file:
with open(json_file, "r") as file:
state = json.load(file)
self.load_inplace(state)
@@ -317,7 +339,6 @@ class AgentAsync(object):
# Step 2: check if LLM wanted to call a function
if response_message.get("function_call"):
# The content if then internal monologue, not chat
await self.interface.internal_monologue(response_message.content)
messages.append(response_message) # extend conversation with assistant's reply
@@ -348,7 +369,7 @@ class AgentAsync(object):
try:
function_to_call = available_functions[function_name]
except KeyError as e:
error_msg = f'No function named {function_name}'
error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
{
@@ -357,7 +378,7 @@ class AgentAsync(object):
"content": function_response,
}
) # extend conversation with function response
await self.interface.function_message(f'Error: {error_msg}')
await self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# Failure case 2: function name is OK, but function args are bad JSON
@@ -374,18 +395,20 @@ class AgentAsync(object):
"content": function_response,
}
) # extend conversation with function response
await self.interface.function_message(f'Error: {error_msg}')
await self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# (Still parsing function args)
# Handle requests for immediate heartbeat
heartbeat_request = function_args.pop('request_heartbeat', None)
heartbeat_request = function_args.pop("request_heartbeat", None)
if not (isinstance(heartbeat_request, bool) or heartbeat_request is None):
printd(f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}")
printd(
f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
)
heartbeat_request = None
# Failure case 3: function failed during execution
await self.interface.function_message(f'Running {function_name}({function_args})')
await self.interface.function_message(f"Running {function_name}({function_args})")
try:
function_response_string = await function_to_call(**function_args)
function_response = package_function_response(True, function_response_string)
@@ -401,12 +424,12 @@ class AgentAsync(object):
"content": function_response,
}
) # extend conversation with function response
await self.interface.function_message(f'Error: {error_msg}')
await self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# If no failures happened along the way: ...
# Step 4: send the info on the function call and function response to GPT
await self.interface.function_message(f'Success: {function_response_string}')
await self.interface.function_message(f"Success: {function_response_string}")
messages.append(
{
"role": "function",
@@ -434,25 +457,29 @@ class AgentAsync(object):
return False
function_name = response_message["function_call"]["name"]
if require_send_message and function_name != 'send_message':
if require_send_message and function_name != "send_message":
printd(f"First message function call wasn't send_message: {response_message}")
return False
if require_monologue and (not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""):
if require_monologue and (
not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""
):
printd(f"First message missing internal monologue: {response_message}")
return False
if response_message.get("content"):
### Extras
monologue = response_message.get("content")
def contains_special_characters(s):
special_characters = '(){}[]"'
return any(char in s for char in special_characters)
if contains_special_characters(monologue):
printd(f"First message internal monologue contained special characters: {response_message}")
return False
# if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower():
if 'functions' in monologue or 'send_message' in monologue:
if "functions" in monologue or "send_message" in monologue:
# Sometimes the syntax won't be correct and internal syntax will leak into message.context
printd(f"First message internal monologue contained reserved words: {response_message}")
return False
@@ -466,12 +493,12 @@ class AgentAsync(object):
# Step 0: add user message
if user_message is not None:
await self.interface.user_message(user_message)
packed_user_message = {'role': 'user', 'content': user_message}
packed_user_message = {"role": "user", "content": user_message}
input_message_sequence = self.messages + [packed_user_message]
else:
input_message_sequence = self.messages
if len(input_message_sequence) > 1 and input_message_sequence[-1]['role'] != 'user':
if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
# Step 1: send the conversation and available functions to GPT
@@ -479,14 +506,13 @@ class AgentAsync(object):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
counter += 1
if counter > first_message_retry_limit:
raise Exception(f'Hit first message retry limit ({first_message_retry_limit})')
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
else:
response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
@@ -500,13 +526,13 @@ class AgentAsync(object):
# Add the extra metadata to the assistant response
# (e.g. enough metadata to enable recreating the API call)
assert 'api_response' not in all_response_messages[0]
all_response_messages[0]['api_response'] = response_message_copy
assert 'api_args' not in all_response_messages[0]
all_response_messages[0]['api_args'] = {
'model': self.model,
'messages': input_message_sequence,
'functions': self.functions,
assert "api_response" not in all_response_messages[0]
all_response_messages[0]["api_response"] = response_message_copy
assert "api_args" not in all_response_messages[0]
all_response_messages[0]["api_args"] = {
"model": self.model,
"messages": input_message_sequence,
"functions": self.functions,
}
# Step 4: extend the message history
@@ -516,7 +542,7 @@ class AgentAsync(object):
all_new_messages = all_response_messages
# Check the memory pressure and potentially issue a memory pressure warning
current_total_tokens = response['usage']['total_tokens']
current_total_tokens = response["usage"]["total_tokens"]
active_memory_warning = False
if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
@@ -534,7 +560,7 @@ class AgentAsync(object):
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
# If we got a context alert, try trimming the messages length, then try again
if 'maximum context length' in str(e):
if "maximum context length" in str(e):
# A separate API call to run a summarizer
await self.summarize_messages_inplace()
@@ -546,21 +572,21 @@ class AgentAsync(object):
async def summarize_messages_inplace(self, cutoff=None):
if cutoff is None:
tokens_so_far = 0 # Smart cutoff -- just below the max.
tokens_so_far = 0 # Smart cutoff -- just below the max.
cutoff = len(self.messages) - 1
for m in reversed(self.messages):
tokens_so_far += count_tokens(str(m), self.model)
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS*0.2:
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS * 0.2:
break
cutoff -= 1
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
# Try to make an assistant message come after the cutoff
try:
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
if self.messages[cutoff]['role'] == 'user':
if self.messages[cutoff]["role"] == "user":
new_cutoff = cutoff + 1
if self.messages[new_cutoff]['role'] == 'user':
if self.messages[new_cutoff]["role"] == "user":
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
cutoff = new_cutoff
except IndexError:
@@ -600,11 +626,11 @@ class AgentAsync(object):
while limit is None or step_count < limit:
if function_failed:
user_message = get_heartbeat('Function call failed')
user_message = get_heartbeat("Function call failed")
new_messages, heartbeat_request, function_failed = await self.step(user_message)
step_count += 1
elif heartbeat_request:
user_message = get_heartbeat('AI requested')
user_message = get_heartbeat("AI requested")
new_messages, heartbeat_request, function_failed = await self.step(user_message)
step_count += 1
else:
@@ -638,7 +664,7 @@ class AgentAsync(object):
return None
async def recall_memory_search(self, query, count=5, page=0):
results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count)
results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -649,7 +675,7 @@ class AgentAsync(object):
return results_str
async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count)
results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -664,7 +690,7 @@ class AgentAsync(object):
return None
async def archival_memory_search(self, query, count=5, page=0):
results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count)
results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -683,7 +709,7 @@ class AgentAsync(object):
# And record how long the pause should go for
self.pause_heartbeats_minutes = int(minutes)
return f'Pausing timed heartbeats for {minutes} min'
return f"Pausing timed heartbeats for {minutes} min"
def heartbeat_is_paused(self):
"""Check if there's a requested pause on timed heartbeats"""
@@ -700,8 +726,8 @@ class AgentAsync(object):
"""Base call to GPT API w/ functions"""
message_sequence = [
{'role': 'system', 'content': MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
{'role': 'user', 'content': str(message)},
{"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
{"role": "user", "content": str(message)},
]
response = await acreate(
model=MESSAGE_CHATGPT_FUNCTION_MODEL,

View File

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
class AgentAsyncBase(ABC):
@abstractmethod
async def step(self, user_message):
pass
pass

View File

@@ -68,41 +68,25 @@ class AutoGenInterface(object):
print(f"inner thoughts :: {msg}")
if not self.show_inner_thoughts:
return
message = (
f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}"
if self.fancy
else f"[inner thoughts] {msg}"
)
message = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" if self.fancy else f"[inner thoughts] {msg}"
self.message_list.append(message)
async def assistant_message(self, msg):
if self.debug:
print(f"assistant :: {msg}")
message = (
f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}"
if self.fancy
else msg
)
message = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" if self.fancy else msg
self.message_list.append(message)
async def memory_message(self, msg):
if self.debug:
print(f"memory :: {msg}")
message = (
f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[memory] {msg}"
)
message = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}" if self.fancy else f"[memory] {msg}"
self.message_list.append(message)
async def system_message(self, msg):
if self.debug:
print(f"system :: {msg}")
message = (
f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[system] {msg}"
)
message = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}"
self.message_list.append(message)
async def user_message(self, msg, raw=False):
@@ -113,11 +97,7 @@ class AutoGenInterface(object):
if isinstance(msg, str):
if raw:
message = (
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[user] {msg}"
)
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
self.message_list.append(message)
return
else:
@@ -125,42 +105,24 @@ class AutoGenInterface(object):
msg_json = json.loads(msg)
except:
print(f"Warning: failed to parse user message into json")
message = (
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[user] {msg}"
)
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
self.message_list.append(message)
return
if msg_json["type"] == "user_message":
msg_json.pop("type")
message = (
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
if self.fancy
else f"[user] {msg}"
)
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
elif msg_json["type"] == "heartbeat":
if True or DEBUG:
msg_json.pop("type")
message = (
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
if self.fancy
else f"[system heartbeat] {msg}"
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system heartbeat] {msg}"
)
elif msg_json["type"] == "system_message":
msg_json.pop("type")
message = (
f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
if self.fancy
else f"[system] {msg}"
)
message = f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}"
else:
message = (
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
if self.fancy
else f"[user] {msg}"
)
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
self.message_list.append(message)
@@ -171,31 +133,19 @@ class AutoGenInterface(object):
return
if isinstance(msg, dict):
message = (
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
self.message_list.append(message)
return
if msg.startswith("Success: "):
message = (
f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function - OK] {msg}"
)
message = f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - OK] {msg}"
elif msg.startswith("Error: "):
message = (
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function - error] {msg}"
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - error] {msg}"
)
elif msg.startswith("Running "):
if DEBUG:
message = (
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function] {msg}"
)
message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
else:
if "memory" in msg:
match = re.search(r"Running (\w+)\((.*)\)", msg)
@@ -227,35 +177,25 @@ class AutoGenInterface(object):
else:
print(f"Warning: did not recognize function message")
message = (
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function] {msg}"
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
)
elif "send_message" in msg:
# ignore in debug mode
message = None
else:
message = (
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function] {msg}"
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
)
else:
try:
msg_dict = json.loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
message = (
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function] {msg}"
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
)
except Exception:
print(f"Warning: did not recognize function message {type(msg)} {msg}")
message = (
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
if self.fancy
else f"[function] {msg}"
)
message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
if message:
self.message_list.append(message)

View File

@@ -55,11 +55,7 @@ def create_autogen_memgpt_agent(
```
"""
interface = AutoGenInterface(**interface_kwargs) if interface is None else interface
persistence_manager = (
InMemoryStateManager(**persistence_manager_kwargs)
if persistence_manager is None
else persistence_manager
)
persistence_manager = InMemoryStateManager(**persistence_manager_kwargs) if persistence_manager is None else persistence_manager
memgpt_agent = presets.use_preset(
preset,
@@ -89,9 +85,7 @@ class MemGPTAgent(ConversableAgent):
self.agent = agent
self.skip_verify = skip_verify
self.concat_other_agent_messages = concat_other_agent_messages
self.register_reply(
[Agent, None], MemGPTAgent._a_generate_reply_for_user_message
)
self.register_reply([Agent, None], MemGPTAgent._a_generate_reply_for_user_message)
self.register_reply([Agent, None], MemGPTAgent._generate_reply_for_user_message)
self.messages_processed_up_to_idx = 0
@@ -119,11 +113,7 @@ class MemGPTAgent(ConversableAgent):
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
return asyncio.run(
self._a_generate_reply_for_user_message(
messages=messages, sender=sender, config=config
)
)
return asyncio.run(self._a_generate_reply_for_user_message(messages=messages, sender=sender, config=config))
async def _a_generate_reply_for_user_message(
self,
@@ -137,9 +127,7 @@ class MemGPTAgent(ConversableAgent):
if len(new_messages) > 1:
if self.concat_other_agent_messages:
# Combine all the other messages into one message
user_message = "\n".join(
[self.format_other_agent_message(m) for m in new_messages]
)
user_message = "\n".join([self.format_other_agent_message(m) for m in new_messages])
else:
# Extend the MemGPT message list with multiple 'user' messages, then push the last one with agent.step()
self.agent.messages.extend(new_messages[:-1])
@@ -157,16 +145,12 @@ class MemGPTAgent(ConversableAgent):
heartbeat_request,
function_failed,
token_warning,
) = await self.agent.step(
user_message, first_message=False, skip_verify=self.skip_verify
)
) = await self.agent.step(user_message, first_message=False, skip_verify=self.skip_verify)
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
if token_warning:
user_message = system.get_token_limit_warning()
elif function_failed:
user_message = system.get_heartbeat(
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
)
user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
elif heartbeat_request:
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
else:

View File

@@ -79,12 +79,8 @@ class Config:
cfg = Config.get_most_recent_config()
use_cfg = False
if cfg:
print(
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
)
use_cfg = await questionary.confirm(
f"Use most recent config file '{cfg}'?"
).ask_async()
print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}")
use_cfg = await questionary.confirm(f"Use most recent config file '{cfg}'?").ask_async()
if use_cfg:
self.config_file = cfg
@@ -105,9 +101,7 @@ class Config:
return self
# print("No settings file found, configuring MemGPT...")
print(
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
)
print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}")
self.model = await questionary.select(
"Which model would you like to use?",
@@ -127,9 +121,7 @@ class Config:
).ask_async()
self.archival_storage_index = None
self.preload_archival = await questionary.confirm(
"Would you like to preload anything into MemGPT's archival memory?"
).ask_async()
self.preload_archival = await questionary.confirm("Would you like to preload anything into MemGPT's archival memory?").ask_async()
if self.preload_archival:
self.load_type = await questionary.select(
"What would you like to load?",
@@ -140,19 +132,13 @@ class Config:
],
).ask_async()
if self.load_type == "folder" or self.load_type == "sql":
archival_storage_path = await questionary.path(
"Please enter the folder or file (tab for autocomplete):"
).ask_async()
archival_storage_path = await questionary.path("Please enter the folder or file (tab for autocomplete):").ask_async()
if os.path.isdir(archival_storage_path):
self.archival_storage_files = os.path.join(
archival_storage_path, "*"
)
self.archival_storage_files = os.path.join(archival_storage_path, "*")
else:
self.archival_storage_files = archival_storage_path
else:
self.archival_storage_files = await questionary.path(
"Please enter the glob pattern (tab for autocomplete):"
).ask_async()
self.archival_storage_files = await questionary.path("Please enter the glob pattern (tab for autocomplete):").ask_async()
self.compute_embeddings = await questionary.confirm(
"Would you like to compute embeddings over these files to enable embeddings search?"
).ask_async()
@@ -168,19 +154,11 @@ class Config:
"⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search."
)
else:
self.archival_storage_index = (
await utils.prepare_archival_index_from_files_compute_embeddings(
self.archival_storage_files
)
)
self.archival_storage_index = await utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files)
if self.compute_embeddings and self.archival_storage_index:
self.index, self.archival_database = utils.prepare_archival_index(
self.archival_storage_index
)
self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index)
else:
self.archival_database = utils.prepare_archival_index_from_files(
self.archival_storage_files
)
self.archival_database = utils.prepare_archival_index_from_files(self.archival_storage_files)
def to_dict(self):
return {
@@ -217,15 +195,11 @@ class Config:
configs_dir = Config.configs_dir
os.makedirs(configs_dir, exist_ok=True)
if self.config_file is None:
filename = os.path.join(
configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
)
filename = os.path.join(configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_"))
self.config_file = f"{filename}.json"
with open(self.config_file, "wt") as f:
json.dump(self.to_dict(), f, indent=4)
print(
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
)
print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}")
@staticmethod
def is_valid_config_file(file: str):
@@ -234,9 +208,7 @@ class Config:
cfg.load_config(file)
except Exception:
return False
return (
cfg.memgpt_persona is not None and cfg.human_persona is not None
) # TODO: more validation for configs
return cfg.memgpt_persona is not None and cfg.human_persona is not None # TODO: more validation for configs
@staticmethod
def get_memgpt_personas():
@@ -331,8 +303,7 @@ class Config:
files = [
os.path.join(configs_dir, f)
for f in os.listdir(configs_dir)
if os.path.isfile(os.path.join(configs_dir, f))
and Config.is_valid_config_file(os.path.join(configs_dir, f))
if os.path.isfile(os.path.join(configs_dir, f)) and Config.is_valid_config_file(os.path.join(configs_dir, f))
]
# Return the file with the most recent modification time
if len(files) == 0:

View File

@@ -7,9 +7,7 @@ DEFAULT_MEMGPT_MODEL = "gpt-4"
FIRST_MESSAGE_ATTEMPTS = 10
INITIAL_BOOT_MESSAGE = "Boot sequence complete. Persona activated."
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = (
"Bootup sequence complete. Persona activated. Testing messaging functionality."
)
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = "Bootup sequence complete. Persona activated. Testing messaging functionality."
STARTUP_QUOTES = [
"I think, therefore I am.",
"All those moments will be lost in time, like tears in rain.",
@@ -28,9 +26,7 @@ CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000
MAX_PAUSE_HEARTBEATS = 360 # in min
MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo"
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = (
"You are a helpful assistant. Keep your responses short and concise."
)
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep your responses short and concise."
#### Functions related

View File

@@ -29,15 +29,11 @@ async def assistant_message(msg):
async def memory_message(msg):
print(
f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
)
print(f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}")
async def system_message(msg):
printd(
f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}")
async def user_message(msg, raw=False):
@@ -50,9 +46,7 @@ async def user_message(msg, raw=False):
msg_json = json.loads(msg)
except:
printd(f"Warning: failed to parse user message into json")
printd(
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}")
return
if msg_json["type"] == "user_message":
@@ -61,9 +55,7 @@ async def user_message(msg, raw=False):
elif msg_json["type"] == "heartbeat":
if DEBUG:
msg_json.pop("type")
printd(
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
)
printd(f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
elif msg_json["type"] == "system_message":
msg_json.pop("type")
printd(f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
@@ -77,33 +69,23 @@ async def function_message(msg):
return
if msg.startswith("Success: "):
printd(
f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}")
elif msg.startswith("Error: "):
printd(
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}")
elif msg.startswith("Running "):
if DEBUG:
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
else:
if "memory" in msg:
match = re.search(r"Running (\w+)\((.*)\)", msg)
if match:
function_name = match.group(1)
function_args = match.group(2)
print(
f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:"
)
print(f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:")
try:
msg_dict = eval(function_args)
if function_name == "archival_memory_search":
print(
f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
)
print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}')
else:
print(
f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}{msg_dict["new_content"]}'
@@ -114,28 +96,20 @@ async def function_message(msg):
pass
else:
printd(f"Warning: did not recognize function message")
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
elif "send_message" in msg:
# ignore in debug mode
pass
else:
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
else:
try:
msg_dict = json.loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
printd(
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}")
except Exception:
printd(f"Warning: did not recognize function message {type(msg)} {msg}")
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
async def print_messages(message_sequence):

View File

@@ -190,9 +190,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
function_parameters = function_json_output["params"]
if self.clean_func_args:
function_name, function_parameters = self.clean_function_args(
function_name, function_parameters
)
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
message = {
"role": "assistant",
@@ -275,9 +273,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
func_str += f"\n description: {schema['description']}"
func_str += f"\n params:"
if add_inner_thoughts:
func_str += (
f"\n inner_thoughts: Deep inner monologue private to you only."
)
func_str += f"\n inner_thoughts: Deep inner monologue private to you only."
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"

View File

@@ -152,9 +152,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
try:
content_json = json.loads(message["content"])
content_simple = content_json["message"]
prompt += (
f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
)
prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
# prompt += f"\nUSER: {content_simple}"
except:
prompt += f"\n{IM_START_TOKEN}user\n{message['content']}{IM_END_TOKEN}"
@@ -227,9 +225,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
function_parameters = function_json_output["params"]
if self.clean_func_args:
function_name, function_parameters = self.clean_function_args(
function_name, function_parameters
)
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
message = {
"role": "assistant",

View File

@@ -83,12 +83,8 @@ def load(memgpt_agent, filename):
print(f"Loading {filename} failed with: {e}")
else:
# Load the latest file
print(
f"/load warning: no checkpoint specified, loading most recent checkpoint instead"
)
json_files = glob.glob(
"saved_state/*.json"
) # This will list all .json files in the current directory.
print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
# Check if there are any json files.
if not json_files:
@@ -110,27 +106,17 @@ def load(memgpt_agent, filename):
) # TODO(fixme):for different types of persistence managers that require different load/save methods
print(f"Loaded persistence manager from {filename}")
except Exception as e:
print(
f"/load warning: loading persistence manager from {filename} failed with: {e}"
)
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
@app.command()
def run(
persona: str = typer.Option(None, help="Specify persona"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(
constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"
),
first: bool = typer.Option(
False, "--first", help="Use --first to send the first message in the sequence"
),
debug: bool = typer.Option(
False, "--debug", help="Use --debug to enable debugging output"
),
no_verify: bool = typer.Option(
False, "--no_verify", help="Bypass message verification"
),
model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"),
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
archival_storage_faiss_path: str = typer.Option(
"",
"--archival_storage_faiss_path",
@@ -200,9 +186,7 @@ async def main(
else:
azure_vars = get_set_azure_env_vars()
if len(azure_vars) > 0:
print(
f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
)
print(f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False")
return
if any(
@@ -295,23 +279,17 @@ async def main(
else:
cfg = await Config.config_init()
memgpt.interface.important_message(
"Running... [exit by typing '/exit', list available commands with '/help']"
)
memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']")
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
memgpt.interface.warning_message(
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)
if cfg.index:
persistence_manager = InMemoryStateManagerWithFaiss(
cfg.index, cfg.archival_database
)
persistence_manager = InMemoryStateManagerWithFaiss(cfg.index, cfg.archival_database)
elif cfg.archival_storage_files:
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
cfg.archival_database
)
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(cfg.archival_database)
else:
persistence_manager = InMemoryStateManager()
@@ -355,9 +333,7 @@ async def main(
print(f"Database loaded into archival memory.")
if cfg.agent_save_file:
load_save_file = await questionary.confirm(
f"Load in saved agent '{cfg.agent_save_file}'?"
).ask_async()
load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async()
if load_save_file:
load(memgpt_agent, cfg.agent_save_file)
@@ -366,9 +342,7 @@ async def main(
return
if not USER_GOES_FIRST:
console.input(
"[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]"
)
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]")
clear_line()
print()
@@ -404,9 +378,7 @@ async def main(
break
elif user_input.lower() == "/savechat":
filename = (
utils.get_local_time().replace(" ", "_").replace(":", "_")
)
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
filename = f"{filename}.pkl"
directory = os.path.join(MEMGPT_DIR, "saved_chats")
try:
@@ -423,9 +395,7 @@ async def main(
save(memgpt_agent=memgpt_agent, cfg=cfg)
continue
elif user_input.lower() == "/load" or user_input.lower().startswith(
"/load "
):
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
command = user_input.strip().split()
filename = command[1] if len(command) > 1 else None
load(memgpt_agent=memgpt_agent, filename=filename)
@@ -458,16 +428,10 @@ async def main(
print(f"Updated model to:\n{str(memgpt_agent.model)}")
continue
elif user_input.lower() == "/pop" or user_input.lower().startswith(
"/pop "
):
elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
amount = (
int(command[1])
if len(command) > 1 and command[1].isdigit()
else 2
)
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 2
print(f"Popping last {amount} messages from stack")
for _ in range(min(amount, len(memgpt_agent.messages))):
memgpt_agent.messages.pop()
@@ -512,18 +476,14 @@ async def main(
heartbeat_request,
function_failed,
token_warning,
) = await memgpt_agent.step(
user_message, first_message=False, skip_verify=no_verify
)
) = await memgpt_agent.step(user_message, first_message=False, skip_verify=no_verify)
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
if token_warning:
user_message = system.get_token_limit_warning()
skip_next_user_input = True
elif function_failed:
user_message = system.get_heartbeat(
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
)
user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
skip_next_user_input = True
elif heartbeat_request:
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)

View File

@@ -28,20 +28,17 @@ class CoreMemory(object):
self.archival_memory_exists = archival_memory_exists
def __repr__(self) -> str:
return \
f"\n### CORE MEMORY ###" + \
f"\n=== Persona ===\n{self.persona}" + \
f"\n\n=== Human ===\n{self.human}"
return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
def to_dict(self):
return {
'persona': self.persona,
'human': self.human,
"persona": self.persona,
"human": self.human,
}
@classmethod
def load(cls, state):
return cls(state['persona'], state['human'])
return cls(state["persona"], state["human"])
def edit_persona(self, new_persona):
if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
@@ -64,53 +61,55 @@ class CoreMemory(object):
return len(self.human)
def edit(self, field, content):
if field == 'persona':
if field == "persona":
return self.edit_persona(content)
elif field == 'human':
elif field == "human":
return self.edit_human(content)
else:
raise KeyError
def edit_append(self, field, content, sep='\n'):
if field == 'persona':
def edit_append(self, field, content, sep="\n"):
if field == "persona":
new_content = self.persona + sep + content
return self.edit_persona(new_content)
elif field == 'human':
elif field == "human":
new_content = self.human + sep + content
return self.edit_human(new_content)
else:
raise KeyError
def edit_replace(self, field, old_content, new_content):
if field == 'persona':
if field == "persona":
if old_content in self.persona:
new_persona = self.persona.replace(old_content, new_content)
return self.edit_persona(new_persona)
else:
raise ValueError('Content not found in persona (make sure to use exact string)')
elif field == 'human':
raise ValueError("Content not found in persona (make sure to use exact string)")
elif field == "human":
if old_content in self.human:
new_human = self.human.replace(old_content, new_content)
return self.edit_human(new_human)
else:
raise ValueError('Content not found in human (make sure to use exact string)')
raise ValueError("Content not found in human (make sure to use exact string)")
else:
raise KeyError
async def summarize_messages(
model,
message_sequence_to_summarize,
):
model,
message_sequence_to_summarize,
):
"""Summarize a message sequence using GPT"""
summary_prompt = SUMMARY_PROMPT_SYSTEM
summary_input = str(message_sequence_to_summarize)
summary_input_tkns = count_tokens(summary_input, model)
if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS:
trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
summary_input = str([await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:])
summary_input = str(
[await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]
)
message_sequence = [
{"role": "system", "content": summary_prompt},
{"role": "user", "content": summary_input},
@@ -127,7 +126,6 @@ async def summarize_messages(
class ArchivalMemory(ABC):
@abstractmethod
def insert(self, memory_string):
pass
@@ -150,7 +148,7 @@ class DummyArchivalMemory(ArchivalMemory):
"""
def __init__(self, archival_memory_database=None):
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
def __len__(self):
return len(self._archive)
@@ -159,31 +157,33 @@ class DummyArchivalMemory(ArchivalMemory):
if len(self._archive) == 0:
memory_str = "<empty>"
else:
memory_str = "\n".join([d['content'] for d in self._archive])
return \
f"\n### ARCHIVAL MEMORY ###" + \
f"\n{memory_str}"
memory_str = "\n".join([d["content"] for d in self._archive])
return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"
async def insert(self, memory_string, embedding=None):
if embedding is not None:
raise ValueError('Basic text-based archival memory does not support embeddings')
self._archive.append({
# can eventually upgrade to adding semantic tags, etc
'timestamp': get_local_time(),
'content': memory_string,
})
raise ValueError("Basic text-based archival memory does not support embeddings")
self._archive.append(
{
# can eventually upgrade to adding semantic tags, etc
"timestamp": get_local_time(),
"content": memory_string,
}
)
async def search(self, query_string, count=None, start=None):
"""Simple text-based search"""
# in the dummy version, run an (inefficient) case-insensitive match search
# printd(f"query_string: {query_string}")
matches = [s for s in self._archive if query_string.lower() in s['content'].lower()]
matches = [s for s in self._archive if query_string.lower() in s["content"].lower()]
# printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[str(d['content']) d in matches[:5]]}")
printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}")
printd(
f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}"
)
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -195,8 +195,8 @@ class DummyArchivalMemory(ArchivalMemory):
class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
"""Same as dummy in-memory archival memory, but with bare-bones embedding support"""
def __init__(self, archival_memory_database=None, embedding_model='text-embedding-ada-002'):
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
def __init__(self, archival_memory_database=None, embedding_model="text-embedding-ada-002"):
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self.embedding_model = embedding_model
def __len__(self):
@@ -206,15 +206,17 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
# Get the embedding
if embedding is None:
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
embedding_meta = {'model': self.embedding_model}
embedding_meta = {"model": self.embedding_model}
printd(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
self._archive.append({
'timestamp': get_local_time(),
'content': memory_string,
'embedding': embedding,
'embedding_metadata': embedding_meta,
})
self._archive.append(
{
"timestamp": get_local_time(),
"content": memory_string,
"embedding": embedding,
"embedding_metadata": embedding_meta,
}
)
async def search(self, query_string, count=None, start=None):
"""Simple embedding-based search (inefficient, no caching)"""
@@ -223,22 +225,24 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# our wrapped version supports backoff/rate-limits
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
similarity_scores = [cosine_similarity(memory['embedding'], query_embedding) for memory in self._archive]
similarity_scores = [cosine_similarity(memory["embedding"], query_embedding) for memory in self._archive]
# Sort the archive based on similarity scores
sorted_archive_with_scores = sorted(
zip(self._archive, similarity_scores),
key=lambda pair: pair[1], # Sort by the similarity score
reverse=True # We want the highest similarity first
reverse=True, # We want the highest similarity first
)
printd(
f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
)
printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}")
# Extract the sorted archive without the scores
matches = [item[0] for item in sorted_archive_with_scores]
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -259,13 +263,13 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
is essential enough not to be left only to the recall memory.
"""
def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100):
def __init__(self, index=None, archival_memory_database=None, embedding_model="text-embedding-ada-002", k=100):
if index is None:
self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
else:
self.index = index
self.k = k
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
self.embedding_model = embedding_model
self.embeddings_dict = {}
self.search_results = {}
@@ -279,12 +283,14 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
self._archive.append({
# can eventually upgrade to adding semantic tags, etc
'timestamp': get_local_time(),
'content': memory_string,
})
embedding = np.array([embedding]).astype('float32')
self._archive.append(
{
# can eventually upgrade to adding semantic tags, etc
"timestamp": get_local_time(),
"content": memory_string,
}
)
embedding = np.array([embedding]).astype("float32")
self.index.add(embedding)
async def search(self, query_string, count=None, start=None):
@@ -304,20 +310,22 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
self.search_results[query_string] = search_result
if start is not None and count is not None:
toprint = search_result[start:start+count]
toprint = search_result[start : start + count]
else:
if len(search_result) >= 5:
toprint = search_result[:5]
else:
toprint = search_result
printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}")
printd(
f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}"
)
# Extract the sorted archive without the scores
matches = search_result
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -327,7 +335,6 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
class RecallMemory(ABC):
@abstractmethod
def text_search(self, query_string, count=None, start=None):
pass
@@ -365,42 +372,46 @@ class DummyRecallMemory(RecallMemory):
# don't dump all the conversations, just statistics
system_count = user_count = assistant_count = function_count = other_count = 0
for msg in self._message_logs:
role = msg['message']['role']
if role == 'system':
role = msg["message"]["role"]
if role == "system":
system_count += 1
elif role == 'user':
elif role == "user":
user_count += 1
elif role == 'assistant':
elif role == "assistant":
assistant_count += 1
elif role == 'function':
elif role == "function":
function_count += 1
else:
other_count += 1
memory_str = f"Statistics:" + \
f"\n{len(self._message_logs)} total messages" + \
f"\n{system_count} system" + \
f"\n{user_count} user" + \
f"\n{assistant_count} assistant" + \
f"\n{function_count} function" + \
f"\n{other_count} other"
return \
f"\n### RECALL MEMORY ###" + \
f"\n{memory_str}"
memory_str = (
f"Statistics:"
+ f"\n{len(self._message_logs)} total messages"
+ f"\n{system_count} system"
+ f"\n{user_count} user"
+ f"\n{assistant_count} assistant"
+ f"\n{function_count} function"
+ f"\n{other_count} other"
)
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
async def insert(self, message):
raise NotImplementedError('This should be handled by the PersistenceManager, recall memory is just a search layer on top')
raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")
async def text_search(self, query_string, count=None, start=None):
# in the dummy version, run an (inefficient) case-insensitive match search
message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
printd(f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages")
matches = [d for d in message_pool if d['message']['content'] is not None and query_string.lower() in d['message']['content'].lower()]
printd(
f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages"
)
matches = [
d for d in message_pool if d["message"]["content"] is not None and query_string.lower() in d["message"]["content"].lower()
]
printd(f"recall_memory - matches:\n{matches[start:start+count]}")
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -411,7 +422,7 @@ class DummyRecallMemory(RecallMemory):
def _validate_date_format(self, date_str):
"""Validate the given date string in the format 'YYYY-MM-DD'."""
try:
datetime.datetime.strptime(date_str, '%Y-%m-%d')
datetime.datetime.strptime(date_str, "%Y-%m-%d")
return True
except ValueError:
return False
@@ -423,25 +434,26 @@ class DummyRecallMemory(RecallMemory):
return match.group(1) if match else None
async def date_search(self, start_date, end_date, count=None, start=None):
message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
# First, validate the start_date and end_date format
if not self._validate_date_format(start_date) or not self._validate_date_format(end_date):
raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
# Convert dates to datetime objects for comparison
start_date_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
end_date_dt = datetime.datetime.strptime(end_date, '%Y-%m-%d')
start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d")
end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d")
# Next, match items inside self._message_logs
matches = [
d for d in message_pool
if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d['timestamp']), '%Y-%m-%d') <= end_date_dt
d
for d in message_pool
if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
]
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:
@@ -456,17 +468,17 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.embeddings = dict()
self.embedding_model = 'text-embedding-ada-002'
self.embedding_model = "text-embedding-ada-002"
self.only_use_preloaded_embeddings = False
async def text_search(self, query_string, count=None, start=None):
# in the dummy version, run an (inefficient) case-insensitive match search
message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
# first, go through and make sure we have all the embeddings we need
message_pool_filtered = []
for d in message_pool:
message_str = d['message']['content']
message_str = d["message"]["content"]
if self.only_use_preloaded_embeddings:
if message_str not in self.embeddings:
printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, skipping.")
@@ -477,24 +489,26 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
self.embeddings[message_str] = await async_get_embedding_with_backoff(message_str, model=self.embedding_model)
message_pool_filtered.append(d)
# our wrapped version supports backoff/rate-limits
# our wrapped version supports backoff/rate-limits
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
similarity_scores = [cosine_similarity(self.embeddings[d['message']['content']], query_embedding) for d in message_pool_filtered]
similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered]
# Sort the archive based on similarity scores
sorted_archive_with_scores = sorted(
zip(message_pool_filtered, similarity_scores),
key=lambda pair: pair[1], # Sort by the similarity score
reverse=True # We want the highest similarity first
reverse=True, # We want the highest similarity first
)
printd(
f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
)
printd(f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}")
# Extract the sorted archive without the scores
matches = [item[0] for item in sorted_archive_with_scores]
# start/count support paging through results
if start is not None and count is not None:
return matches[start:start+count], len(matches)
return matches[start : start + count], len(matches)
elif start is None and count is not None:
return matches[:count], len(matches)
elif start is not None and count is None:

View File

@@ -41,9 +41,7 @@ def retry_with_exponential_backoff(
# Check if max retries has been reached
if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
@@ -91,9 +89,7 @@ def aretry_with_exponential_backoff(
# Check if max retries has been reached
if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
@@ -184,9 +180,7 @@ def configure_azure_support():
azure_openai_endpoint,
azure_openai_version,
]:
print(
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
)
print(f"Error: missing Azure OpenAI environment variables. Please see README section on Azure.")
return
openai.api_type = "azure"
@@ -199,10 +193,7 @@ def configure_azure_support():
def check_azure_embeddings():
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
azure_openai_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
if (
azure_openai_deployment is not None
and azure_openai_embedding_deployment is None
):
if azure_openai_deployment is not None and azure_openai_embedding_deployment is None:
raise ValueError(
f"Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
)

View File

@@ -1,12 +1,17 @@
from abc import ABC, abstractmethod
import pickle
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
from .memory import (
DummyRecallMemory,
DummyRecallMemoryWithEmbeddings,
DummyArchivalMemory,
DummyArchivalMemoryWithEmbeddings,
DummyArchivalMemoryWithFaiss,
)
from .utils import get_local_time, printd
class PersistenceManager(ABC):
@abstractmethod
def trim_messages(self, num):
pass
@@ -42,17 +47,17 @@ class InMemoryStateManager(PersistenceManager):
@staticmethod
def load(filename):
with open(filename, 'rb') as f:
with open(filename, "rb") as f:
return pickle.load(f)
def save(self, filename):
with open(filename, 'wb') as fh:
with open(filename, "wb") as fh:
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
def init(self, agent):
printd(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")
@@ -68,7 +73,7 @@ class InMemoryStateManager(PersistenceManager):
def prepend_to_messages(self, added_messages):
# first tag with timestamps
added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"InMemoryStateManager.prepend_to_message")
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
@@ -76,7 +81,7 @@ class InMemoryStateManager(PersistenceManager):
def append_to_messages(self, added_messages):
# first tag with timestamps
added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"InMemoryStateManager.append_to_messages")
self.messages = self.messages + added_messages
@@ -84,7 +89,7 @@ class InMemoryStateManager(PersistenceManager):
def swap_system_message(self, new_system_message):
# first tag with timestamps
new_system_message = {'timestamp': get_local_time(), 'message': new_system_message}
new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
printd(f"InMemoryStateManager.swap_system_message")
self.messages[0] = new_system_message
@@ -104,8 +109,8 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
@@ -133,12 +138,14 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager):
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
# Persistence manager also handles DB-related state
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k)
self.archival_memory = self.archival_memory_cls(
index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k
)

View File

@@ -5,15 +5,14 @@ import numpy as np
import argparse
import json
def build_index(embedding_files: str,
index_name: str):
def build_index(embedding_files: str, index_name: str):
index = faiss.IndexFlatL2(1536)
file_list = sorted(glob(embedding_files))
for embedding_file in file_list:
print(embedding_file)
with open(embedding_file, 'rt', encoding='utf-8') as file:
with open(embedding_file, "rt", encoding="utf-8") as file:
embeddings = []
l = 0
for line in tqdm(file):
@@ -21,7 +20,7 @@ def build_index(embedding_files: str,
data = json.loads(line)
embeddings.append(data)
l += 1
data = np.array(embeddings).astype('float32')
data = np.array(embeddings).astype("float32")
print(data.shape)
try:
index.add(data)
@@ -32,14 +31,11 @@ def build_index(embedding_files: str,
faiss.write_index(index, index_name)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--embedding_files', type=str, help='embedding_filepaths glob expression')
parser.add_argument('--output_index_file', type=str, help='output filepath')
args = parser.parse_args()
parser.add_argument("--embedding_files", type=str, help="embedding_filepaths glob expression")
parser.add_argument("--output_index_file", type=str, help="output filepath")
args = parser.parse_args()
build_index(
embedding_files=args.embedding_files,
index_name=args.output_index_file
)
build_index(embedding_files=args.embedding_files, index_name=args.output_index_file)

View File

@@ -7,12 +7,14 @@ import argparse
from tqdm import tqdm
import openai
try:
from dotenv import load_dotenv
load_dotenv()
except ModuleNotFoundError:
pass
openai.api_key = os.getenv('OPENAI_API_KEY')
openai.api_key = os.getenv("OPENAI_API_KEY")
sys.path.append("../../../")
from openai_tools import async_get_embedding_with_backoff
@@ -24,8 +26,8 @@ from openai_parallel_request_processor import process_api_requests_from_file
TPM_LIMIT = 1000000
RPM_LIMIT = 3000
DEFAULT_FILE = 'iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz'
EMBEDDING_MODEL = 'text-embedding-ada-002'
DEFAULT_FILE = "iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz"
EMBEDDING_MODEL = "text-embedding-ada-002"
async def generate_requests_file(filename):
@@ -33,36 +35,33 @@ async def generate_requests_file(filename):
base_name = os.path.splitext(filename)[0]
requests_filename = f"{base_name}_embedding_requests.jsonl"
with open(filename, 'r') as f:
with open(filename, "r") as f:
all_data = [json.loads(line) for line in f]
with open(requests_filename, 'w') as f:
with open(requests_filename, "w") as f:
for data in all_data:
documents = data
for idx, doc in enumerate(documents):
title = doc["title"]
text = doc["text"]
document_string = f"Document [{idx+1}] (Title: {title}) {text}"
request = {
"model": EMBEDDING_MODEL,
"input": document_string
}
request = {"model": EMBEDDING_MODEL, "input": document_string}
json_string = json.dumps(request)
f.write(json_string + "\n")
# Run your parallel processing function
input(f"Generated requests file ({requests_filename}), continue with embedding batch requests? (hit enter)")
await process_api_requests_from_file(
requests_filepath=requests_filename,
save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary
request_url="https://api.openai.com/v1/embeddings",
api_key=os.getenv('OPENAI_API_KEY'),
max_requests_per_minute=RPM_LIMIT,
max_tokens_per_minute=TPM_LIMIT,
token_encoding_name=EMBEDDING_MODEL,
max_attempts=5,
logging_level=logging.INFO,
)
requests_filepath=requests_filename,
save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary
request_url="https://api.openai.com/v1/embeddings",
api_key=os.getenv("OPENAI_API_KEY"),
max_requests_per_minute=RPM_LIMIT,
max_tokens_per_minute=TPM_LIMIT,
token_encoding_name=EMBEDDING_MODEL,
max_attempts=5,
logging_level=logging.INFO,
)
async def generate_embedding_file(filename, parallel_mode=False):
@@ -72,7 +71,7 @@ async def generate_embedding_file(filename, parallel_mode=False):
# Derive the sister filename
# base_name = os.path.splitext(filename)[0]
base_name = filename.rsplit('.jsonl', 1)[0]
base_name = filename.rsplit(".jsonl", 1)[0]
sister_filename = f"{base_name}.embeddings.jsonl"
# Check if the sister file already exists
@@ -80,7 +79,7 @@ async def generate_embedding_file(filename, parallel_mode=False):
print(f"{sister_filename} already exists. Skipping embedding generation.")
return
with open(filename, 'rt') as f:
with open(filename, "rt") as f:
all_data = [json.loads(line) for line in f]
embedding_data = []
@@ -90,7 +89,9 @@ async def generate_embedding_file(filename, parallel_mode=False):
for i, data in enumerate(tqdm(all_data, desc="Processing data", total=len(all_data))):
documents = data
# Inner loop progress bar
for idx, doc in enumerate(tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)):
for idx, doc in enumerate(
tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)
):
title = doc["title"]
text = doc["text"]
document_string = f"[Title: {title}] {text}"
@@ -103,10 +104,10 @@ async def generate_embedding_file(filename, parallel_mode=False):
# Save the embeddings to the sister file
# with gzip.open(sister_filename, 'wt') as f:
with open(sister_filename, 'wb') as f:
with open(sister_filename, "wb") as f:
for embedding in embedding_data:
# f.write(json.dumps(embedding) + '\n')
f.write((json.dumps(embedding) + '\n').encode('utf-8'))
f.write((json.dumps(embedding) + "\n").encode("utf-8"))
print(f"Embeddings saved to {sister_filename}")
@@ -118,6 +119,7 @@ async def main():
filename = DEFAULT_FILE
await generate_embedding_file(filename)
async def main():
parser = argparse.ArgumentParser()
parser.add_argument("filename", nargs="?", default=DEFAULT_FILE, help="Path to the input file")
@@ -129,4 +131,4 @@ async def main():
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.run_until_complete(main())

View File

@@ -121,9 +121,7 @@ async def process_api_requests_from_file(
"""Processes API requests in parallel, throttling to stay under rate limits."""
# constants
seconds_to_pause_after_rate_limit_error = 15
seconds_to_sleep_each_loop = (
0.001 # 1 ms limits max throughput to 1,000 requests per second
)
seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second
# initialize logging
logging.basicConfig(level=logging_level)
@@ -135,12 +133,8 @@ async def process_api_requests_from_file(
# initialize trackers
queue_of_requests_to_retry = asyncio.Queue()
task_id_generator = (
task_id_generator_function()
) # generates integer IDs of 1, 2, 3, ...
status_tracker = (
StatusTracker()
) # single instance to track a collection of variables
task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ...
status_tracker = StatusTracker() # single instance to track a collection of variables
next_request = None # variable to hold the next request to call
# initialize available capacity counts
@@ -163,9 +157,7 @@ async def process_api_requests_from_file(
if next_request is None:
if not queue_of_requests_to_retry.empty():
next_request = queue_of_requests_to_retry.get_nowait()
logging.debug(
f"Retrying request {next_request.task_id}: {next_request}"
)
logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
elif file_not_finished:
try:
# get new request
@@ -173,17 +165,13 @@ async def process_api_requests_from_file(
next_request = APIRequest(
task_id=next(task_id_generator),
request_json=request_json,
token_consumption=num_tokens_consumed_from_request(
request_json, api_endpoint, token_encoding_name
),
token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name),
attempts_left=max_attempts,
metadata=request_json.pop("metadata", None),
)
status_tracker.num_tasks_started += 1
status_tracker.num_tasks_in_progress += 1
logging.debug(
f"Reading request {next_request.task_id}: {next_request}"
)
logging.debug(f"Reading request {next_request.task_id}: {next_request}")
except StopIteration:
# if file runs out, set flag to stop reading it
logging.debug("Read file exhausted")
@@ -193,13 +181,11 @@ async def process_api_requests_from_file(
current_time = time.time()
seconds_since_update = current_time - last_update_time
available_request_capacity = min(
available_request_capacity
+ max_requests_per_minute * seconds_since_update / 60.0,
available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0,
max_requests_per_minute,
)
available_token_capacity = min(
available_token_capacity
+ max_tokens_per_minute * seconds_since_update / 60.0,
available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0,
max_tokens_per_minute,
)
last_update_time = current_time
@@ -207,10 +193,7 @@ async def process_api_requests_from_file(
# if enough capacity available, call API
if next_request:
next_request_tokens = next_request.token_consumption
if (
available_request_capacity >= 1
and available_token_capacity >= next_request_tokens
):
if available_request_capacity >= 1 and available_token_capacity >= next_request_tokens:
# update counters
available_request_capacity -= 1
available_token_capacity -= next_request_tokens
@@ -237,17 +220,9 @@ async def process_api_requests_from_file(
await asyncio.sleep(seconds_to_sleep_each_loop)
# if a rate limit error was hit recently, pause to cool down
seconds_since_rate_limit_error = (
time.time() - status_tracker.time_of_last_rate_limit_error
)
if (
seconds_since_rate_limit_error
< seconds_to_pause_after_rate_limit_error
):
remaining_seconds_to_pause = (
seconds_to_pause_after_rate_limit_error
- seconds_since_rate_limit_error
)
seconds_since_rate_limit_error = time.time() - status_tracker.time_of_last_rate_limit_error
if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
remaining_seconds_to_pause = seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
await asyncio.sleep(remaining_seconds_to_pause)
# ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
logging.warn(
@@ -255,17 +230,13 @@ async def process_api_requests_from_file(
)
# after finishing, log final status
logging.info(
f"""Parallel processing complete. Results saved to {save_filepath}"""
)
logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""")
if status_tracker.num_tasks_failed > 0:
logging.warning(
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
)
if status_tracker.num_rate_limit_errors > 0:
logging.warning(
f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
)
logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.")
# dataclasses
@@ -309,26 +280,18 @@ class APIRequest:
logging.info(f"Starting request #{self.task_id}")
error = None
try:
async with session.post(
url=request_url, headers=request_header, json=self.request_json
) as response:
async with session.post(url=request_url, headers=request_header, json=self.request_json) as response:
response = await response.json()
if "error" in response:
logging.warning(
f"Request {self.task_id} failed with error {response['error']}"
)
logging.warning(f"Request {self.task_id} failed with error {response['error']}")
status_tracker.num_api_errors += 1
error = response
if "Rate limit" in response["error"].get("message", ""):
status_tracker.time_of_last_rate_limit_error = time.time()
status_tracker.num_rate_limit_errors += 1
status_tracker.num_api_errors -= (
1 # rate limit errors are counted separately
)
status_tracker.num_api_errors -= 1 # rate limit errors are counted separately
except (
Exception
) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
logging.warning(f"Request {self.task_id} failed with Exception {e}")
status_tracker.num_other_errors += 1
error = e
@@ -337,9 +300,7 @@ class APIRequest:
if self.attempts_left:
retry_queue.put_nowait(self)
else:
logging.error(
f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
)
logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}")
data = (
[self.request_json, [str(e) for e in self.result], self.metadata]
if self.metadata
@@ -349,11 +310,7 @@ class APIRequest:
status_tracker.num_tasks_in_progress -= 1
status_tracker.num_tasks_failed += 1
else:
data = (
[self.request_json, response, self.metadata]
if self.metadata
else [self.request_json, response]
)
data = [self.request_json, response, self.metadata] if self.metadata else [self.request_json, response]
append_to_jsonl(data, save_filepath)
status_tracker.num_tasks_in_progress -= 1
status_tracker.num_tasks_succeeded += 1
@@ -382,8 +339,8 @@ def num_tokens_consumed_from_request(
token_encoding_name: str,
):
"""Count the number of tokens in the request. Only supports completion and embedding requests."""
if token_encoding_name == 'text-embedding-ada-002':
encoding = tiktoken.get_encoding('cl100k_base')
if token_encoding_name == "text-embedding-ada-002":
encoding = tiktoken.get_encoding("cl100k_base")
else:
encoding = tiktoken.get_encoding(token_encoding_name)
# if completions request, tokens = prompt + n * max_tokens
@@ -415,9 +372,7 @@ def num_tokens_consumed_from_request(
num_tokens = prompt_tokens + completion_tokens * len(prompt)
return num_tokens
else:
raise TypeError(
'Expecting either string or list of strings for "prompt" field in completion request'
)
raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
# if embeddings request, tokens = input tokens
elif api_endpoint == "embeddings":
input = request_json["input"]
@@ -428,14 +383,10 @@ def num_tokens_consumed_from_request(
num_tokens = sum([len(encoding.encode(i)) for i in input])
return num_tokens
else:
raise TypeError(
'Expecting either string or list of strings for "inputs" field in embedding request'
)
raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request')
# more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
else:
raise NotImplementedError(
f'API endpoint "{api_endpoint}" not implemented in this script'
)
raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script')
def task_id_generator_function():
@@ -502,4 +453,4 @@ with open(filename, "w") as f:
```
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
"""
"""

View File

@@ -4,69 +4,65 @@ import tiktoken
import json
# Define the directory where the documentation resides
docs_dir = 'text'
docs_dir = "text"
encoding = tiktoken.encoding_for_model("gpt-4")
PASSAGE_TOKEN_LEN = 800
def extract_text_from_sphinx_txt(file_path):
lines = []
title = ""
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if not title:
title = line.strip()
continue
if line and re.match(r'^.*\S.*$', line) and not re.match(r'^[-=*]+$', line):
if line and re.match(r"^.*\S.*$", line) and not re.match(r"^[-=*]+$", line):
lines.append(line)
passages = []
passages = []
curr_passage = []
curr_token_ct = 0
for line in lines:
try:
line_token_ct = len(encoding.encode(line, allowed_special={'<|endoftext|>'}))
line_token_ct = len(encoding.encode(line, allowed_special={"<|endoftext|>"}))
except Exception as e:
print("line", line)
raise e
if line_token_ct > PASSAGE_TOKEN_LEN:
passages.append({
'title': title,
'text': line[:3200],
'num_tokens': curr_token_ct,
})
passages.append(
{
"title": title,
"text": line[:3200],
"num_tokens": curr_token_ct,
}
)
continue
curr_token_ct += line_token_ct
curr_passage.append(line)
if curr_token_ct > PASSAGE_TOKEN_LEN:
passages.append({
'title': title,
'text': ''.join(curr_passage),
'num_tokens': curr_token_ct
})
passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
curr_passage = []
curr_token_ct = 0
if len(curr_passage) > 0:
passages.append({
'title': title,
'text': ''.join(curr_passage),
'num_tokens': curr_token_ct
})
passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
return passages
# Iterate over all files in the directory and its subdirectories
passages = []
total_files = 0
for subdir, _, files in os.walk(docs_dir):
for file in files:
if file.endswith('.txt'):
if file.endswith(".txt"):
file_path = os.path.join(subdir, file)
passages.append(extract_text_from_sphinx_txt(file_path))
total_files += 1
print("total .txt files:", total_files)
# Save to a new text file or process as needed
with open('all_docs.jsonl', 'w', encoding='utf-8') as file:
with open("all_docs.jsonl", "w", encoding="utf-8") as file:
for p in passages:
file.write(json.dumps(p))
file.write('\n')
file.write("\n")

View File

@@ -1,30 +1,33 @@
from .prompts import gpt_functions
from .prompts import gpt_system
from .agent import AgentAsync
from .utils import printd
DEFAULT = 'memgpt_chat'
DEFAULT = "memgpt_chat"
def use_preset(preset_name, model, persona, human, interface, persistence_manager):
"""Storing combinations of SYSTEM + FUNCTION prompts"""
if preset_name == 'memgpt_chat':
if preset_name == "memgpt_chat":
functions = [
'send_message', 'pause_heartbeats',
'core_memory_append', 'core_memory_replace',
'conversation_search', 'conversation_search_date',
'archival_memory_insert', 'archival_memory_search',
"send_message",
"pause_heartbeats",
"core_memory_append",
"core_memory_replace",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
"archival_memory_search",
]
available_functions = [v for k,v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
printd(f"Available functions:\n", [x['name'] for x in available_functions])
available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
printd(f"Available functions:\n", [x["name"] for x in available_functions])
assert len(functions) == len(available_functions)
if 'gpt-3.5' in model:
if "gpt-3.5" in model:
# use a different system message for gpt-3.5
preset_name = 'memgpt_gpt35_extralong'
preset_name = "memgpt_gpt35_extralong"
return AgentAsync(
model=model,
@@ -35,8 +38,8 @@ def use_preset(preset_name, model, persona, human, interface, persistence_manage
persona_notes=persona,
human_notes=human,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if 'gpt-4' in model else False,
first_message_verify_mono=True if "gpt-4" in model else False,
)
else:
raise ValueError(preset_name)
raise ValueError(preset_name)

View File

@@ -2,9 +2,7 @@ from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
# FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1]
FUNCTIONS_CHAINING = {
'send_message':
{
"send_message": {
"name": "send_message",
"description": "Sends a message to the human user",
"parameters": {
@@ -17,11 +15,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["message"],
}
},
},
'pause_heartbeats':
{
"pause_heartbeats": {
"name": "pause_heartbeats",
"description": "Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.",
"parameters": {
@@ -34,11 +30,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["minutes"],
}
},
},
'message_chatgpt':
{
"message_chatgpt": {
"name": "message_chatgpt",
"description": "Send a message to a more basic AI, ChatGPT. A useful resource for asking questions. ChatGPT does not retain memory of previous interactions.",
"parameters": {
@@ -55,11 +49,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["message", "request_heartbeat"],
}
},
},
'core_memory_append':
{
"core_memory_append": {
"name": "core_memory_append",
"description": "Append to the contents of core memory.",
"parameters": {
@@ -79,11 +71,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "content", "request_heartbeat"],
}
},
},
'core_memory_replace':
{
"core_memory_replace": {
"name": "core_memory_replace",
"description": "Replace to the contents of core memory. To delete memories, use an empty string for new_content.",
"parameters": {
@@ -107,11 +97,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "old_content", "new_content", "request_heartbeat"],
}
},
},
'recall_memory_search':
{
"recall_memory_search": {
"name": "recall_memory_search",
"description": "Search prior conversation history using a string.",
"parameters": {
@@ -131,11 +119,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
}
},
},
'conversation_search':
{
"conversation_search": {
"name": "conversation_search",
"description": "Search prior conversation history using case-insensitive string matching.",
"parameters": {
@@ -155,11 +141,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
}
},
},
'recall_memory_search_date':
{
"recall_memory_search_date": {
"name": "recall_memory_search_date",
"description": "Search prior conversation history using a date range.",
"parameters": {
@@ -183,11 +167,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
}
},
},
'conversation_search_date':
{
"conversation_search_date": {
"name": "conversation_search_date",
"description": "Search prior conversation history using a date range.",
"parameters": {
@@ -211,11 +193,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "page", "request_heartbeat"],
}
},
},
'archival_memory_insert':
{
"archival_memory_insert": {
"name": "archival_memory_insert",
"description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.",
"parameters": {
@@ -231,11 +211,9 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "content", "request_heartbeat"],
}
},
},
'archival_memory_search':
{
"archival_memory_search": {
"name": "archival_memory_search",
"description": "Search archival memory using semantic (embedding-based) search.",
"parameters": {
@@ -255,7 +233,6 @@ FUNCTIONS_CHAINING = {
},
},
"required": ["name", "query", "page", "request_heartbeat"],
}
},
},
}
}

View File

@@ -1,6 +1,5 @@
WORD_LIMIT = 100
SYSTEM = \
f"""
SYSTEM = f"""
Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
The conversation you are given is a from a fixed context window and may not be complete.
Messages sent by the AI are marked with the 'assistant' role.
@@ -12,4 +11,4 @@ The 'user' role is also used for important system events, such as login events a
Summarize what happened in the conversation from the perspective of the AI (use the first person).
Keep your summary less than {WORD_LIMIT} words, do NOT exceed this word limit.
Only output the summary, do NOT include anything else in your output.
"""
"""

View File

@@ -2,11 +2,11 @@ import os
def get_system_text(key):
filename = f'{key}.txt'
file_path = os.path.join(os.path.dirname(__file__), 'system', filename)
filename = f"{key}.txt"
file_path = os.path.join(os.path.dirname(__file__), "system", filename)
if os.path.exists(file_path):
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
return file.read().strip()
else:
raise FileNotFoundError(f"No file found for key {key}, path={file_path}")

View File

@@ -1,18 +1,22 @@
import json
from .utils import get_local_time
from .constants import INITIAL_BOOT_MESSAGE, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG, MESSAGE_SUMMARY_WARNING_STR
from .constants import (
INITIAL_BOOT_MESSAGE,
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG,
MESSAGE_SUMMARY_WARNING_STR,
)
def get_initial_boot_messages(version='startup'):
if version == 'startup':
def get_initial_boot_messages(version="startup"):
if version == "startup":
initial_boot_message = INITIAL_BOOT_MESSAGE
messages = [
{"role": "assistant", "content": initial_boot_message},
]
elif version == 'startup_with_send_message':
elif version == "startup_with_send_message":
messages = [
# first message includes both inner monologue and function call to send_message
{
@@ -20,34 +24,23 @@ def get_initial_boot_messages(version='startup'):
"content": INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
"function_call": {
"name": "send_message",
"arguments": "{\n \"message\": \"" + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + "\"\n}"
}
"arguments": '{\n "message": "' + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + '"\n}',
},
},
# obligatory function return message
{
"role": "function",
"name": "send_message",
"content": package_function_response(True, None)
}
{"role": "function", "name": "send_message", "content": package_function_response(True, None)},
]
elif version == 'startup_with_send_message_gpt35':
elif version == "startup_with_send_message_gpt35":
messages = [
# first message includes both inner monologue and function call to send_message
{
"role": "assistant",
"content": "*inner thoughts* Still waiting on the user. Sending a message with function.",
"function_call": {
"name": "send_message",
"arguments": "{\n \"message\": \"" + f"Hi, is anyone there?" + "\"\n}"
}
"function_call": {"name": "send_message", "arguments": '{\n "message": "' + f"Hi, is anyone there?" + '"\n}'},
},
# obligatory function return message
{
"role": "function",
"name": "send_message",
"content": package_function_response(True, None)
}
{"role": "function", "name": "send_message", "content": package_function_response(True, None)},
]
else:
@@ -56,12 +49,11 @@ def get_initial_boot_messages(version='startup'):
return messages
def get_heartbeat(reason='Automated timer', include_location=False, location_name='San Francisco, CA, USA'):
def get_heartbeat(reason="Automated timer", include_location=False, location_name="San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = get_local_time()
packaged_message = {
"type": 'heartbeat',
"type": "heartbeat",
"reason": reason,
"time": formatted_time,
}
@@ -72,12 +64,11 @@ def get_heartbeat(reason='Automated timer', include_location=False, location_nam
return json.dumps(packaged_message)
def get_login_event(last_login='Never (first login)', include_location=False, location_name='San Francisco, CA, USA'):
def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = get_local_time()
packaged_message = {
"type": 'login',
"type": "login",
"last_login": last_login,
"time": formatted_time,
}
@@ -88,12 +79,11 @@ def get_login_event(last_login='Never (first login)', include_location=False, lo
return json.dumps(packaged_message)
def package_user_message(user_message, time=None, include_location=False, location_name='San Francisco, CA, USA'):
def package_user_message(user_message, time=None, include_location=False, location_name="San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = time if time else get_local_time()
packaged_message = {
"type": 'user_message',
"type": "user_message",
"message": user_message,
"time": formatted_time,
}
@@ -103,11 +93,11 @@ def package_user_message(user_message, time=None, include_location=False, locati
return json.dumps(packaged_message)
def package_function_response(was_success, response_string, timestamp=None):
def package_function_response(was_success, response_string, timestamp=None):
formatted_time = get_local_time() if timestamp is None else timestamp
packaged_message = {
"status": 'OK' if was_success else 'Failed',
"status": "OK" if was_success else "Failed",
"message": response_string,
"time": formatted_time,
}
@@ -116,14 +106,14 @@ def package_function_response(was_success, response_string, timestamp=None):
def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
context_message = \
f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n" \
context_message = (
f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"
+ f"The following is a summary of the previous {summary_length} messages:\n {summary}"
)
formatted_time = get_local_time() if timestamp is None else timestamp
packaged_message = {
"type": 'system_alert',
"type": "system_alert",
"message": context_message,
"time": formatted_time,
}
@@ -136,10 +126,13 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
# Package the message with time and location
formatted_time = get_local_time() if timestamp is None else timestamp
context_message = message if message else \
f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions."
context_message = (
message
if message
else f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions."
)
packaged_message = {
"type": 'system_alert',
"type": "system_alert",
"message": context_message,
"time": formatted_time,
}
@@ -148,12 +141,11 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
def get_token_limit_warning():
formatted_time = get_local_time()
packaged_message = {
"type": 'system_alert',
"type": "system_alert",
"message": MESSAGE_SUMMARY_WARNING_STR,
"time": formatted_time,
}
return json.dumps(packaged_message)
return json.dumps(packaged_message)

View File

@@ -168,9 +168,7 @@ def chunk_file(file, tkns_per_chunk=300, model="gpt-4"):
line_token_ct = len(encoding.encode(line))
except Exception as e:
line_token_ct = len(line.split(" ")) / 0.75
print(
f"Could not encode line {i}, estimating it to be {line_token_ct} tokens"
)
print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
print(e)
if line_token_ct > tkns_per_chunk:
if len(curr_chunk) > 0:
@@ -194,9 +192,7 @@ def chunk_files(files, tkns_per_chunk=300, model="gpt-4"):
archival_database = []
for file in files:
timestamp = os.path.getmtime(file)
formatted_time = datetime.fromtimestamp(timestamp).strftime(
"%Y-%m-%d %I:%M:%S %p %Z%z"
)
formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
file_stem = file.split("/")[-1]
chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
for i, chunk in enumerate(chunks):
@@ -243,9 +239,7 @@ async def process_concurrently(archival_database, model, concurrency=10):
# Create a list of tasks for chunks
embedding_data = [0 for _ in archival_database]
tasks = [
bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)
]
tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]
for future in tqdm(
asyncio.as_completed(tasks),
@@ -267,15 +261,12 @@ async def prepare_archival_index_from_files_compute_embeddings(
files = sorted(glob.glob(glob_pattern))
save_dir = os.path.join(
MEMGPT_DIR,
"archival_index_from_files_"
+ get_local_time().replace(" ", "_").replace(":", "_"),
"archival_index_from_files_" + get_local_time().replace(" ", "_").replace(":", "_"),
)
os.makedirs(save_dir, exist_ok=True)
total_tokens = total_bytes(glob_pattern) / 3
price_estimate = total_tokens / 1000 * 0.0001
confirm = input(
f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] "
)
confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
if confirm != "y":
raise Exception("embeddings were not computed")
@@ -291,9 +282,7 @@ async def prepare_archival_index_from_files_compute_embeddings(
archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
with open(archival_storage_file, "w") as f:
print(
f"Saving archival storage with preloaded files to {archival_storage_file}"
)
print(f"Saving archival storage with preloaded files to {archival_storage_file}")
for c in chunks_by_file:
json.dump(c, f)
f.write("\n")

147
poetry.lock generated
View File

@@ -1,9 +1,10 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
[[package]]
name = "aiohttp"
version = "3.8.6"
description = "Async http client/server framework (asyncio)"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@@ -112,6 +113,7 @@ speedups = ["Brotli", "aiodns", "cchardet"]
name = "aiosignal"
version = "1.3.1"
description = "aiosignal: a list of registered asynchronous callbacks"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -126,6 +128,7 @@ frozenlist = ">=1.1.0"
name = "async-timeout"
version = "4.0.3"
description = "Timeout context manager for asyncio programs"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -137,6 +140,7 @@ files = [
name = "attrs"
version = "23.1.0"
description = "Classes Without Boilerplate"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -151,10 +155,54 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
[[package]]
name = "black"
version = "23.10.1"
description = "The uncompromising code formatter."
category = "main"
optional = false
python-versions = ">=3.8"
files = [
{file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"},
{file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"},
{file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"},
{file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"},
{file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"},
{file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"},
{file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"},
{file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"},
{file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"},
{file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"},
{file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"},
{file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"},
{file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"},
{file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"},
{file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"},
{file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"},
{file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"},
{file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"},
]
[package.dependencies]
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "certifi"
version = "2023.7.22"
description = "Python package for providing Mozilla's CA Bundle."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@@ -166,6 +214,7 @@ files = [
name = "charset-normalizer"
version = "3.3.1"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
@@ -265,6 +314,7 @@ files = [
name = "click"
version = "8.1.7"
description = "Composable command line interface toolkit"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -279,6 +329,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
@@ -290,6 +341,7 @@ files = [
name = "demjson3"
version = "3.0.6"
description = "encoder, decoder, and lint/validator for JSON (JavaScript Object Notation) compliant with RFC 7159"
category = "main"
optional = false
python-versions = "*"
files = [
@@ -300,6 +352,7 @@ files = [
name = "faiss-cpu"
version = "1.7.4"
description = "A library for efficient similarity search and clustering of dense vectors."
category = "main"
optional = false
python-versions = "*"
files = [
@@ -334,6 +387,7 @@ files = [
name = "frozenlist"
version = "1.4.0"
description = "A list-like structure which implements collections.abc.MutableSequence"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -404,6 +458,7 @@ files = [
name = "idna"
version = "3.4"
description = "Internationalized Domain Names in Applications (IDNA)"
category = "main"
optional = false
python-versions = ">=3.5"
files = [
@@ -415,6 +470,7 @@ files = [
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -439,6 +495,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -450,6 +507,7 @@ files = [
name = "multidict"
version = "6.0.4"
description = "multidict implementation"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -529,10 +587,23 @@ files = [
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
]
[[package]]
name = "mypy-extensions"
version = "1.0.0"
description = "Type system extensions for programs checked with the mypy type checker."
category = "main"
optional = false
python-versions = ">=3.5"
files = [
{file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
]
[[package]]
name = "numpy"
version = "1.26.1"
description = "Fundamental package for array computing in Python"
category = "main"
optional = false
python-versions = "<3.13,>=3.9"
files = [
@@ -574,6 +645,7 @@ files = [
name = "openai"
version = "0.28.1"
description = "Python client library for the OpenAI API"
category = "main"
optional = false
python-versions = ">=3.7.1"
files = [
@@ -588,14 +660,55 @@ tqdm = "*"
[package.extras]
datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"]
dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"]
embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"]
wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"]
[[package]]
name = "packaging"
version = "23.2"
description = "Core utilities for Python packages"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
]
[[package]]
name = "pathspec"
version = "0.11.2"
description = "Utility library for gitignore style pattern matching of file paths."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"},
{file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"},
]
[[package]]
name = "platformdirs"
version = "3.11.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"},
{file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"},
]
[package.extras]
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
[[package]]
name = "prompt-toolkit"
version = "3.0.36"
description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false
python-versions = ">=3.6.2"
files = [
@@ -610,6 +723,7 @@ wcwidth = "*"
name = "pygments"
version = "2.16.1"
description = "Pygments is a syntax highlighting package written in Python."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -624,6 +738,7 @@ plugins = ["importlib-metadata"]
name = "pymupdf"
version = "1.23.5"
description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -667,6 +782,7 @@ PyMuPDFb = "1.23.5"
name = "pymupdfb"
version = "1.23.5"
description = "MuPDF shared libraries for PyMuPDF."
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -682,6 +798,7 @@ files = [
name = "pytz"
version = "2023.3.post1"
description = "World timezone definitions, modern and historical"
category = "main"
optional = false
python-versions = "*"
files = [
@@ -693,6 +810,7 @@ files = [
name = "questionary"
version = "2.0.1"
description = "Python library to build pretty command line user prompts ⭐️"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -707,6 +825,7 @@ prompt_toolkit = ">=2.0,<=3.0.36"
name = "regex"
version = "2023.10.3"
description = "Alternative regular expression module, to replace re."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -804,6 +923,7 @@ files = [
name = "requests"
version = "2.31.0"
description = "Python HTTP for Humans."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -825,6 +945,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
name = "rich"
version = "13.6.0"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
@@ -843,6 +964,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
name = "shellingham"
version = "1.5.3"
description = "Tool to Detect Surrounding Shell"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -854,6 +976,7 @@ files = [
name = "tiktoken"
version = "0.5.1"
description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -895,10 +1018,23 @@ requests = ">=2.26.0"
[package.extras]
blobfile = ["blobfile (>=2)"]
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]]
name = "tqdm"
version = "4.66.1"
description = "Fast, Extensible Progress Meter"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -919,6 +1055,7 @@ telegram = ["requests"]
name = "typer"
version = "0.9.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@@ -943,6 +1080,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.
name = "typing-extensions"
version = "4.8.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -954,6 +1092,7 @@ files = [
name = "urllib3"
version = "2.0.7"
description = "HTTP library with thread-safe connection pooling, file post, and more."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -971,6 +1110,7 @@ zstd = ["zstandard (>=0.18.0)"]
name = "wcwidth"
version = "0.2.8"
description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false
python-versions = "*"
files = [
@@ -982,6 +1122,7 @@ files = [
name = "yarl"
version = "1.9.2"
description = "Yet another URL library"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -1068,4 +1209,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.9"
content-hash = "d5c3f3a0c8149f3ba08cca638c1a5154f2d7d42ddeb4a7ed2d87bc9545841f23"
content-hash = "9c19a9cd0487a85fa947ec3f53e765b47a03b2a1c6ae1a46de95b25a893690b2"

View File

@@ -30,6 +30,7 @@ tiktoken = "^0.5.1"
pymupdf = "^1.23.5"
tqdm = "^4.66.1"
openai = "^0.28.1"
black = "^23.10.1"
[build-system]