add black to poetry and reformat
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
from .main import app
|
||||
|
||||
app()
|
||||
|
||||
260
memgpt/agent.py
260
memgpt/agent.py
@@ -11,10 +11,15 @@ from .system import get_heartbeat, get_login_event, package_function_response, p
|
||||
from .memory import CoreMemory as Memory, summarize_messages
|
||||
from .openai_tools import acompletions_with_backoff as acreate
|
||||
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
|
||||
from .constants import \
|
||||
FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, \
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_TOKENS, \
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
|
||||
from .constants import (
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
MAX_PAUSE_HEARTBEATS,
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
|
||||
MESSAGE_SUMMARY_WARNING_TOKENS,
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
)
|
||||
|
||||
|
||||
def initialize_memory(ai_notes, human_notes):
|
||||
@@ -28,52 +33,57 @@ def initialize_memory(ai_notes, human_notes):
|
||||
return memory
|
||||
|
||||
|
||||
def construct_system_with_memory(
|
||||
system, memory, memory_edit_timestamp,
|
||||
archival_memory=None, recall_memory=None
|
||||
):
|
||||
full_system_message = "\n".join([
|
||||
system,
|
||||
"\n",
|
||||
f"### Memory [last modified: {memory_edit_timestamp}",
|
||||
f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
"<persona>",
|
||||
memory.persona,
|
||||
"</persona>",
|
||||
"<human>",
|
||||
memory.human,
|
||||
"</human>",
|
||||
])
|
||||
def construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=None, recall_memory=None):
|
||||
full_system_message = "\n".join(
|
||||
[
|
||||
system,
|
||||
"\n",
|
||||
f"### Memory [last modified: {memory_edit_timestamp}",
|
||||
f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
"<persona>",
|
||||
memory.persona,
|
||||
"</persona>",
|
||||
"<human>",
|
||||
memory.human,
|
||||
"</human>",
|
||||
]
|
||||
)
|
||||
return full_system_message
|
||||
|
||||
|
||||
def initialize_message_sequence(
|
||||
model,
|
||||
system,
|
||||
memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=None,
|
||||
include_initial_boot_message=True,
|
||||
):
|
||||
model,
|
||||
system,
|
||||
memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=None,
|
||||
include_initial_boot_message=True,
|
||||
):
|
||||
if memory_edit_timestamp is None:
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
full_system_message = construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory)
|
||||
full_system_message = construct_system_with_memory(
|
||||
system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
|
||||
)
|
||||
first_user_message = get_login_event() # event letting MemGPT know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
if 'gpt-3.5' in model:
|
||||
initial_boot_messages = get_initial_boot_messages('startup_with_send_message_gpt35')
|
||||
if "gpt-3.5" in model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages('startup_with_send_message')
|
||||
messages = [
|
||||
{"role": "system", "content": full_system_message},
|
||||
] + initial_boot_messages + [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
messages = [
|
||||
@@ -85,11 +95,11 @@ def initialize_message_sequence(
|
||||
|
||||
|
||||
async def get_ai_reply_async(
|
||||
model,
|
||||
message_sequence,
|
||||
functions,
|
||||
function_call="auto",
|
||||
):
|
||||
model,
|
||||
message_sequence,
|
||||
functions,
|
||||
function_call="auto",
|
||||
):
|
||||
"""Base call to GPT API w/ functions"""
|
||||
|
||||
try:
|
||||
@@ -101,11 +111,11 @@ async def get_ai_reply_async(
|
||||
)
|
||||
|
||||
# special case for 'length'
|
||||
if response.choices[0].finish_reason == 'length':
|
||||
raise Exception('Finish reason was length (maximum context length)')
|
||||
if response.choices[0].finish_reason == "length":
|
||||
raise Exception("Finish reason was length (maximum context length)")
|
||||
|
||||
# catches for soft errors
|
||||
if response.choices[0].finish_reason not in ['stop', 'function_call']:
|
||||
if response.choices[0].finish_reason not in ["stop", "function_call"]:
|
||||
raise Exception(f"API call finish with bad finish reason: {response}")
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
@@ -118,7 +128,19 @@ async def get_ai_reply_async(
|
||||
class AgentAsync(object):
|
||||
"""Core logic for a MemGPT agent"""
|
||||
|
||||
def __init__(self, model, system, functions, interface, persistence_manager, persona_notes, human_notes, messages_total=None, persistence_manager_init=True, first_message_verify_mono=True):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
system,
|
||||
functions,
|
||||
interface,
|
||||
persistence_manager,
|
||||
persona_notes,
|
||||
human_notes,
|
||||
messages_total=None,
|
||||
persistence_manager_init=True,
|
||||
first_message_verify_mono=True,
|
||||
):
|
||||
# gpt-4, gpt-3.5-turbo
|
||||
self.model = model
|
||||
# Store the system instructions (used to rebuild memory)
|
||||
@@ -173,7 +195,7 @@ class AgentAsync(object):
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value):
|
||||
raise Exception('Modifying message list directly not allowed')
|
||||
raise Exception("Modifying message list directly not allowed")
|
||||
|
||||
def trim_messages(self, num):
|
||||
"""Trim messages from the front, not including the system message"""
|
||||
@@ -196,16 +218,16 @@ class AgentAsync(object):
|
||||
|
||||
# strip extra metadata if it exists
|
||||
for msg in added_messages:
|
||||
msg.pop('api_response', None)
|
||||
msg.pop('api_args', None)
|
||||
msg.pop("api_response", None)
|
||||
msg.pop("api_args", None)
|
||||
new_messages = self.messages + added_messages # append
|
||||
|
||||
self._messages = new_messages
|
||||
self.messages_total += len(added_messages)
|
||||
|
||||
def swap_system_message(self, new_system_message):
|
||||
assert new_system_message['role'] == 'system', new_system_message
|
||||
assert self.messages[0]['role'] == 'system', self.messages
|
||||
assert new_system_message["role"] == "system", new_system_message
|
||||
assert self.messages[0]["role"] == "system", self.messages
|
||||
|
||||
self.persistence_manager.swap_system_message(new_system_message)
|
||||
|
||||
@@ -223,7 +245,7 @@ class AgentAsync(object):
|
||||
recall_memory=self.persistence_manager.recall_memory,
|
||||
)[0]
|
||||
|
||||
diff = united_diff(curr_system_message['content'], new_system_message['content'])
|
||||
diff = united_diff(curr_system_message["content"], new_system_message["content"])
|
||||
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Store the memory change (if stateful)
|
||||
@@ -235,32 +257,32 @@ class AgentAsync(object):
|
||||
### Local state management
|
||||
def to_dict(self):
|
||||
return {
|
||||
'model': self.model,
|
||||
'system': self.system,
|
||||
'functions': self.functions,
|
||||
'messages': self.messages,
|
||||
'messages_total': self.messages_total,
|
||||
'memory': self.memory.to_dict(),
|
||||
"model": self.model,
|
||||
"system": self.system,
|
||||
"functions": self.functions,
|
||||
"messages": self.messages,
|
||||
"messages_total": self.messages_total,
|
||||
"memory": self.memory.to_dict(),
|
||||
}
|
||||
|
||||
def save_to_json_file(self, filename):
|
||||
with open(filename, 'w') as file:
|
||||
with open(filename, "w") as file:
|
||||
json.dump(self.to_dict(), file)
|
||||
|
||||
@classmethod
|
||||
def load(cls, state, interface, persistence_manager):
|
||||
model = state['model']
|
||||
system = state['system']
|
||||
functions = state['functions']
|
||||
messages = state['messages']
|
||||
model = state["model"]
|
||||
system = state["system"]
|
||||
functions = state["functions"]
|
||||
messages = state["messages"]
|
||||
try:
|
||||
messages_total = state['messages_total']
|
||||
messages_total = state["messages_total"]
|
||||
except KeyError:
|
||||
messages_total = len(messages) - 1
|
||||
# memory requires a nested load
|
||||
memory_dict = state['memory']
|
||||
persona_notes = memory_dict['persona']
|
||||
human_notes = memory_dict['human']
|
||||
memory_dict = state["memory"]
|
||||
persona_notes = memory_dict["persona"]
|
||||
human_notes = memory_dict["human"]
|
||||
|
||||
# Two-part load
|
||||
new_agent = cls(
|
||||
@@ -278,18 +300,18 @@ class AgentAsync(object):
|
||||
return new_agent
|
||||
|
||||
def load_inplace(self, state):
|
||||
self.model = state['model']
|
||||
self.system = state['system']
|
||||
self.functions = state['functions']
|
||||
self.model = state["model"]
|
||||
self.system = state["system"]
|
||||
self.functions = state["functions"]
|
||||
# memory requires a nested load
|
||||
memory_dict = state['memory']
|
||||
persona_notes = memory_dict['persona']
|
||||
human_notes = memory_dict['human']
|
||||
memory_dict = state["memory"]
|
||||
persona_notes = memory_dict["persona"]
|
||||
human_notes = memory_dict["human"]
|
||||
self.memory = initialize_memory(persona_notes, human_notes)
|
||||
# messages also
|
||||
self._messages = state['messages']
|
||||
self._messages = state["messages"]
|
||||
try:
|
||||
self.messages_total = state['messages_total']
|
||||
self.messages_total = state["messages_total"]
|
||||
except KeyError:
|
||||
self.messages_total = len(self.messages) - 1 # -system
|
||||
|
||||
@@ -300,14 +322,14 @@ class AgentAsync(object):
|
||||
|
||||
@classmethod
|
||||
def load_from_json_file(cls, json_file, interface, persistence_manager):
|
||||
with open(json_file, 'r') as file:
|
||||
with open(json_file, "r") as file:
|
||||
state = json.load(file)
|
||||
return cls.load(state, interface, persistence_manager)
|
||||
|
||||
def load_from_json_file_inplace(self, json_file):
|
||||
# Load in-place
|
||||
# No interface arg needed, we can use the current one
|
||||
with open(json_file, 'r') as file:
|
||||
with open(json_file, "r") as file:
|
||||
state = json.load(file)
|
||||
self.load_inplace(state)
|
||||
|
||||
@@ -317,7 +339,6 @@ class AgentAsync(object):
|
||||
|
||||
# Step 2: check if LLM wanted to call a function
|
||||
if response_message.get("function_call"):
|
||||
|
||||
# The content if then internal monologue, not chat
|
||||
await self.interface.internal_monologue(response_message.content)
|
||||
messages.append(response_message) # extend conversation with assistant's reply
|
||||
@@ -348,7 +369,7 @@ class AgentAsync(object):
|
||||
try:
|
||||
function_to_call = available_functions[function_name]
|
||||
except KeyError as e:
|
||||
error_msg = f'No function named {function_name}'
|
||||
error_msg = f"No function named {function_name}"
|
||||
function_response = package_function_response(False, error_msg)
|
||||
messages.append(
|
||||
{
|
||||
@@ -357,7 +378,7 @@ class AgentAsync(object):
|
||||
"content": function_response,
|
||||
}
|
||||
) # extend conversation with function response
|
||||
await self.interface.function_message(f'Error: {error_msg}')
|
||||
await self.interface.function_message(f"Error: {error_msg}")
|
||||
return messages, None, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# Failure case 2: function name is OK, but function args are bad JSON
|
||||
@@ -374,18 +395,20 @@ class AgentAsync(object):
|
||||
"content": function_response,
|
||||
}
|
||||
) # extend conversation with function response
|
||||
await self.interface.function_message(f'Error: {error_msg}')
|
||||
await self.interface.function_message(f"Error: {error_msg}")
|
||||
return messages, None, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# (Still parsing function args)
|
||||
# Handle requests for immediate heartbeat
|
||||
heartbeat_request = function_args.pop('request_heartbeat', None)
|
||||
heartbeat_request = function_args.pop("request_heartbeat", None)
|
||||
if not (isinstance(heartbeat_request, bool) or heartbeat_request is None):
|
||||
printd(f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}")
|
||||
printd(
|
||||
f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
|
||||
)
|
||||
heartbeat_request = None
|
||||
|
||||
# Failure case 3: function failed during execution
|
||||
await self.interface.function_message(f'Running {function_name}({function_args})')
|
||||
await self.interface.function_message(f"Running {function_name}({function_args})")
|
||||
try:
|
||||
function_response_string = await function_to_call(**function_args)
|
||||
function_response = package_function_response(True, function_response_string)
|
||||
@@ -401,12 +424,12 @@ class AgentAsync(object):
|
||||
"content": function_response,
|
||||
}
|
||||
) # extend conversation with function response
|
||||
await self.interface.function_message(f'Error: {error_msg}')
|
||||
await self.interface.function_message(f"Error: {error_msg}")
|
||||
return messages, None, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# If no failures happened along the way: ...
|
||||
# Step 4: send the info on the function call and function response to GPT
|
||||
await self.interface.function_message(f'Success: {function_response_string}')
|
||||
await self.interface.function_message(f"Success: {function_response_string}")
|
||||
messages.append(
|
||||
{
|
||||
"role": "function",
|
||||
@@ -434,25 +457,29 @@ class AgentAsync(object):
|
||||
return False
|
||||
|
||||
function_name = response_message["function_call"]["name"]
|
||||
if require_send_message and function_name != 'send_message':
|
||||
if require_send_message and function_name != "send_message":
|
||||
printd(f"First message function call wasn't send_message: {response_message}")
|
||||
return False
|
||||
|
||||
if require_monologue and (not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""):
|
||||
if require_monologue and (
|
||||
not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""
|
||||
):
|
||||
printd(f"First message missing internal monologue: {response_message}")
|
||||
return False
|
||||
|
||||
if response_message.get("content"):
|
||||
### Extras
|
||||
monologue = response_message.get("content")
|
||||
|
||||
def contains_special_characters(s):
|
||||
special_characters = '(){}[]"'
|
||||
return any(char in s for char in special_characters)
|
||||
|
||||
if contains_special_characters(monologue):
|
||||
printd(f"First message internal monologue contained special characters: {response_message}")
|
||||
return False
|
||||
# if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower():
|
||||
if 'functions' in monologue or 'send_message' in monologue:
|
||||
if "functions" in monologue or "send_message" in monologue:
|
||||
# Sometimes the syntax won't be correct and internal syntax will leak into message.context
|
||||
printd(f"First message internal monologue contained reserved words: {response_message}")
|
||||
return False
|
||||
@@ -466,12 +493,12 @@ class AgentAsync(object):
|
||||
# Step 0: add user message
|
||||
if user_message is not None:
|
||||
await self.interface.user_message(user_message)
|
||||
packed_user_message = {'role': 'user', 'content': user_message}
|
||||
packed_user_message = {"role": "user", "content": user_message}
|
||||
input_message_sequence = self.messages + [packed_user_message]
|
||||
else:
|
||||
input_message_sequence = self.messages
|
||||
|
||||
if len(input_message_sequence) > 1 and input_message_sequence[-1]['role'] != 'user':
|
||||
if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
|
||||
printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
|
||||
|
||||
# Step 1: send the conversation and available functions to GPT
|
||||
@@ -479,14 +506,13 @@ class AgentAsync(object):
|
||||
printd(f"This is the first message. Running extra verifier on AI response.")
|
||||
counter = 0
|
||||
while True:
|
||||
|
||||
response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
|
||||
if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
|
||||
break
|
||||
|
||||
counter += 1
|
||||
if counter > first_message_retry_limit:
|
||||
raise Exception(f'Hit first message retry limit ({first_message_retry_limit})')
|
||||
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
|
||||
|
||||
else:
|
||||
response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
|
||||
@@ -500,13 +526,13 @@ class AgentAsync(object):
|
||||
|
||||
# Add the extra metadata to the assistant response
|
||||
# (e.g. enough metadata to enable recreating the API call)
|
||||
assert 'api_response' not in all_response_messages[0]
|
||||
all_response_messages[0]['api_response'] = response_message_copy
|
||||
assert 'api_args' not in all_response_messages[0]
|
||||
all_response_messages[0]['api_args'] = {
|
||||
'model': self.model,
|
||||
'messages': input_message_sequence,
|
||||
'functions': self.functions,
|
||||
assert "api_response" not in all_response_messages[0]
|
||||
all_response_messages[0]["api_response"] = response_message_copy
|
||||
assert "api_args" not in all_response_messages[0]
|
||||
all_response_messages[0]["api_args"] = {
|
||||
"model": self.model,
|
||||
"messages": input_message_sequence,
|
||||
"functions": self.functions,
|
||||
}
|
||||
|
||||
# Step 4: extend the message history
|
||||
@@ -516,7 +542,7 @@ class AgentAsync(object):
|
||||
all_new_messages = all_response_messages
|
||||
|
||||
# Check the memory pressure and potentially issue a memory pressure warning
|
||||
current_total_tokens = response['usage']['total_tokens']
|
||||
current_total_tokens = response["usage"]["total_tokens"]
|
||||
active_memory_warning = False
|
||||
if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
|
||||
printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
|
||||
@@ -534,7 +560,7 @@ class AgentAsync(object):
|
||||
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
|
||||
|
||||
# If we got a context alert, try trimming the messages length, then try again
|
||||
if 'maximum context length' in str(e):
|
||||
if "maximum context length" in str(e):
|
||||
# A separate API call to run a summarizer
|
||||
await self.summarize_messages_inplace()
|
||||
|
||||
@@ -546,21 +572,21 @@ class AgentAsync(object):
|
||||
|
||||
async def summarize_messages_inplace(self, cutoff=None):
|
||||
if cutoff is None:
|
||||
tokens_so_far = 0 # Smart cutoff -- just below the max.
|
||||
tokens_so_far = 0 # Smart cutoff -- just below the max.
|
||||
cutoff = len(self.messages) - 1
|
||||
for m in reversed(self.messages):
|
||||
tokens_so_far += count_tokens(str(m), self.model)
|
||||
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS*0.2:
|
||||
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS * 0.2:
|
||||
break
|
||||
cutoff -= 1
|
||||
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
|
||||
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
|
||||
|
||||
# Try to make an assistant message come after the cutoff
|
||||
try:
|
||||
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
|
||||
if self.messages[cutoff]['role'] == 'user':
|
||||
if self.messages[cutoff]["role"] == "user":
|
||||
new_cutoff = cutoff + 1
|
||||
if self.messages[new_cutoff]['role'] == 'user':
|
||||
if self.messages[new_cutoff]["role"] == "user":
|
||||
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
|
||||
cutoff = new_cutoff
|
||||
except IndexError:
|
||||
@@ -600,11 +626,11 @@ class AgentAsync(object):
|
||||
|
||||
while limit is None or step_count < limit:
|
||||
if function_failed:
|
||||
user_message = get_heartbeat('Function call failed')
|
||||
user_message = get_heartbeat("Function call failed")
|
||||
new_messages, heartbeat_request, function_failed = await self.step(user_message)
|
||||
step_count += 1
|
||||
elif heartbeat_request:
|
||||
user_message = get_heartbeat('AI requested')
|
||||
user_message = get_heartbeat("AI requested")
|
||||
new_messages, heartbeat_request, function_failed = await self.step(user_message)
|
||||
step_count += 1
|
||||
else:
|
||||
@@ -638,7 +664,7 @@ class AgentAsync(object):
|
||||
return None
|
||||
|
||||
async def recall_memory_search(self, query, count=5, page=0):
|
||||
results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count)
|
||||
results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
@@ -649,7 +675,7 @@ class AgentAsync(object):
|
||||
return results_str
|
||||
|
||||
async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
|
||||
results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count)
|
||||
results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
@@ -664,7 +690,7 @@ class AgentAsync(object):
|
||||
return None
|
||||
|
||||
async def archival_memory_search(self, query, count=5, page=0):
|
||||
results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count)
|
||||
results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
@@ -683,7 +709,7 @@ class AgentAsync(object):
|
||||
# And record how long the pause should go for
|
||||
self.pause_heartbeats_minutes = int(minutes)
|
||||
|
||||
return f'Pausing timed heartbeats for {minutes} min'
|
||||
return f"Pausing timed heartbeats for {minutes} min"
|
||||
|
||||
def heartbeat_is_paused(self):
|
||||
"""Check if there's a requested pause on timed heartbeats"""
|
||||
@@ -700,8 +726,8 @@ class AgentAsync(object):
|
||||
"""Base call to GPT API w/ functions"""
|
||||
|
||||
message_sequence = [
|
||||
{'role': 'system', 'content': MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
|
||||
{'role': 'user', 'content': str(message)},
|
||||
{"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
|
||||
{"role": "user", "content": str(message)},
|
||||
]
|
||||
response = await acreate(
|
||||
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
|
||||
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class AgentAsyncBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def step(self, user_message):
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -68,41 +68,25 @@ class AutoGenInterface(object):
|
||||
print(f"inner thoughts :: {msg}")
|
||||
if not self.show_inner_thoughts:
|
||||
return
|
||||
message = (
|
||||
f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[inner thoughts] {msg}"
|
||||
)
|
||||
message = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" if self.fancy else f"[inner thoughts] {msg}"
|
||||
self.message_list.append(message)
|
||||
|
||||
async def assistant_message(self, msg):
|
||||
if self.debug:
|
||||
print(f"assistant :: {msg}")
|
||||
message = (
|
||||
f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else msg
|
||||
)
|
||||
message = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" if self.fancy else msg
|
||||
self.message_list.append(message)
|
||||
|
||||
async def memory_message(self, msg):
|
||||
if self.debug:
|
||||
print(f"memory :: {msg}")
|
||||
message = (
|
||||
f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[memory] {msg}"
|
||||
)
|
||||
message = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}" if self.fancy else f"[memory] {msg}"
|
||||
self.message_list.append(message)
|
||||
|
||||
async def system_message(self, msg):
|
||||
if self.debug:
|
||||
print(f"system :: {msg}")
|
||||
message = (
|
||||
f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[system] {msg}"
|
||||
)
|
||||
message = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}"
|
||||
self.message_list.append(message)
|
||||
|
||||
async def user_message(self, msg, raw=False):
|
||||
@@ -113,11 +97,7 @@ class AutoGenInterface(object):
|
||||
|
||||
if isinstance(msg, str):
|
||||
if raw:
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[user] {msg}"
|
||||
)
|
||||
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
|
||||
self.message_list.append(message)
|
||||
return
|
||||
else:
|
||||
@@ -125,42 +105,24 @@ class AutoGenInterface(object):
|
||||
msg_json = json.loads(msg)
|
||||
except:
|
||||
print(f"Warning: failed to parse user message into json")
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[user] {msg}"
|
||||
)
|
||||
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
|
||||
self.message_list.append(message)
|
||||
return
|
||||
|
||||
if msg_json["type"] == "user_message":
|
||||
msg_json.pop("type")
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[user] {msg}"
|
||||
)
|
||||
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
|
||||
elif msg_json["type"] == "heartbeat":
|
||||
if True or DEBUG:
|
||||
msg_json.pop("type")
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[system heartbeat] {msg}"
|
||||
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system heartbeat] {msg}"
|
||||
)
|
||||
elif msg_json["type"] == "system_message":
|
||||
msg_json.pop("type")
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[system] {msg}"
|
||||
)
|
||||
message = f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[system] {msg}"
|
||||
else:
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[user] {msg}"
|
||||
)
|
||||
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
|
||||
|
||||
self.message_list.append(message)
|
||||
|
||||
@@ -171,31 +133,19 @@ class AutoGenInterface(object):
|
||||
return
|
||||
|
||||
if isinstance(msg, dict):
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
self.message_list.append(message)
|
||||
return
|
||||
|
||||
if msg.startswith("Success: "):
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function - OK] {msg}"
|
||||
)
|
||||
message = f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - OK] {msg}"
|
||||
elif msg.startswith("Error: "):
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function - error] {msg}"
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function - error] {msg}"
|
||||
)
|
||||
elif msg.startswith("Running "):
|
||||
if DEBUG:
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function] {msg}"
|
||||
)
|
||||
message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
|
||||
else:
|
||||
if "memory" in msg:
|
||||
match = re.search(r"Running (\w+)\((.*)\)", msg)
|
||||
@@ -227,35 +177,25 @@ class AutoGenInterface(object):
|
||||
else:
|
||||
print(f"Warning: did not recognize function message")
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function] {msg}"
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
|
||||
)
|
||||
elif "send_message" in msg:
|
||||
# ignore in debug mode
|
||||
message = None
|
||||
else:
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function] {msg}"
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
msg_dict = json.loads(msg)
|
||||
if "status" in msg_dict and msg_dict["status"] == "OK":
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function] {msg}"
|
||||
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
|
||||
)
|
||||
except Exception:
|
||||
print(f"Warning: did not recognize function message {type(msg)} {msg}")
|
||||
message = (
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
if self.fancy
|
||||
else f"[function] {msg}"
|
||||
)
|
||||
message = f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
|
||||
|
||||
if message:
|
||||
self.message_list.append(message)
|
||||
|
||||
@@ -55,11 +55,7 @@ def create_autogen_memgpt_agent(
|
||||
```
|
||||
"""
|
||||
interface = AutoGenInterface(**interface_kwargs) if interface is None else interface
|
||||
persistence_manager = (
|
||||
InMemoryStateManager(**persistence_manager_kwargs)
|
||||
if persistence_manager is None
|
||||
else persistence_manager
|
||||
)
|
||||
persistence_manager = InMemoryStateManager(**persistence_manager_kwargs) if persistence_manager is None else persistence_manager
|
||||
|
||||
memgpt_agent = presets.use_preset(
|
||||
preset,
|
||||
@@ -89,9 +85,7 @@ class MemGPTAgent(ConversableAgent):
|
||||
self.agent = agent
|
||||
self.skip_verify = skip_verify
|
||||
self.concat_other_agent_messages = concat_other_agent_messages
|
||||
self.register_reply(
|
||||
[Agent, None], MemGPTAgent._a_generate_reply_for_user_message
|
||||
)
|
||||
self.register_reply([Agent, None], MemGPTAgent._a_generate_reply_for_user_message)
|
||||
self.register_reply([Agent, None], MemGPTAgent._generate_reply_for_user_message)
|
||||
self.messages_processed_up_to_idx = 0
|
||||
|
||||
@@ -119,11 +113,7 @@ class MemGPTAgent(ConversableAgent):
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[Any] = None,
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
return asyncio.run(
|
||||
self._a_generate_reply_for_user_message(
|
||||
messages=messages, sender=sender, config=config
|
||||
)
|
||||
)
|
||||
return asyncio.run(self._a_generate_reply_for_user_message(messages=messages, sender=sender, config=config))
|
||||
|
||||
async def _a_generate_reply_for_user_message(
|
||||
self,
|
||||
@@ -137,9 +127,7 @@ class MemGPTAgent(ConversableAgent):
|
||||
if len(new_messages) > 1:
|
||||
if self.concat_other_agent_messages:
|
||||
# Combine all the other messages into one message
|
||||
user_message = "\n".join(
|
||||
[self.format_other_agent_message(m) for m in new_messages]
|
||||
)
|
||||
user_message = "\n".join([self.format_other_agent_message(m) for m in new_messages])
|
||||
else:
|
||||
# Extend the MemGPT message list with multiple 'user' messages, then push the last one with agent.step()
|
||||
self.agent.messages.extend(new_messages[:-1])
|
||||
@@ -157,16 +145,12 @@ class MemGPTAgent(ConversableAgent):
|
||||
heartbeat_request,
|
||||
function_failed,
|
||||
token_warning,
|
||||
) = await self.agent.step(
|
||||
user_message, first_message=False, skip_verify=self.skip_verify
|
||||
)
|
||||
) = await self.agent.step(user_message, first_message=False, skip_verify=self.skip_verify)
|
||||
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
|
||||
if token_warning:
|
||||
user_message = system.get_token_limit_warning()
|
||||
elif function_failed:
|
||||
user_message = system.get_heartbeat(
|
||||
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
|
||||
)
|
||||
user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
|
||||
elif heartbeat_request:
|
||||
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
|
||||
else:
|
||||
|
||||
@@ -79,12 +79,8 @@ class Config:
|
||||
cfg = Config.get_most_recent_config()
|
||||
use_cfg = False
|
||||
if cfg:
|
||||
print(
|
||||
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
|
||||
)
|
||||
use_cfg = await questionary.confirm(
|
||||
f"Use most recent config file '{cfg}'?"
|
||||
).ask_async()
|
||||
print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}")
|
||||
use_cfg = await questionary.confirm(f"Use most recent config file '{cfg}'?").ask_async()
|
||||
if use_cfg:
|
||||
self.config_file = cfg
|
||||
|
||||
@@ -105,9 +101,7 @@ class Config:
|
||||
return self
|
||||
|
||||
# print("No settings file found, configuring MemGPT...")
|
||||
print(
|
||||
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
|
||||
)
|
||||
print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}")
|
||||
|
||||
self.model = await questionary.select(
|
||||
"Which model would you like to use?",
|
||||
@@ -127,9 +121,7 @@ class Config:
|
||||
).ask_async()
|
||||
|
||||
self.archival_storage_index = None
|
||||
self.preload_archival = await questionary.confirm(
|
||||
"Would you like to preload anything into MemGPT's archival memory?"
|
||||
).ask_async()
|
||||
self.preload_archival = await questionary.confirm("Would you like to preload anything into MemGPT's archival memory?").ask_async()
|
||||
if self.preload_archival:
|
||||
self.load_type = await questionary.select(
|
||||
"What would you like to load?",
|
||||
@@ -140,19 +132,13 @@ class Config:
|
||||
],
|
||||
).ask_async()
|
||||
if self.load_type == "folder" or self.load_type == "sql":
|
||||
archival_storage_path = await questionary.path(
|
||||
"Please enter the folder or file (tab for autocomplete):"
|
||||
).ask_async()
|
||||
archival_storage_path = await questionary.path("Please enter the folder or file (tab for autocomplete):").ask_async()
|
||||
if os.path.isdir(archival_storage_path):
|
||||
self.archival_storage_files = os.path.join(
|
||||
archival_storage_path, "*"
|
||||
)
|
||||
self.archival_storage_files = os.path.join(archival_storage_path, "*")
|
||||
else:
|
||||
self.archival_storage_files = archival_storage_path
|
||||
else:
|
||||
self.archival_storage_files = await questionary.path(
|
||||
"Please enter the glob pattern (tab for autocomplete):"
|
||||
).ask_async()
|
||||
self.archival_storage_files = await questionary.path("Please enter the glob pattern (tab for autocomplete):").ask_async()
|
||||
self.compute_embeddings = await questionary.confirm(
|
||||
"Would you like to compute embeddings over these files to enable embeddings search?"
|
||||
).ask_async()
|
||||
@@ -168,19 +154,11 @@ class Config:
|
||||
"⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search."
|
||||
)
|
||||
else:
|
||||
self.archival_storage_index = (
|
||||
await utils.prepare_archival_index_from_files_compute_embeddings(
|
||||
self.archival_storage_files
|
||||
)
|
||||
)
|
||||
self.archival_storage_index = await utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files)
|
||||
if self.compute_embeddings and self.archival_storage_index:
|
||||
self.index, self.archival_database = utils.prepare_archival_index(
|
||||
self.archival_storage_index
|
||||
)
|
||||
self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index)
|
||||
else:
|
||||
self.archival_database = utils.prepare_archival_index_from_files(
|
||||
self.archival_storage_files
|
||||
)
|
||||
self.archival_database = utils.prepare_archival_index_from_files(self.archival_storage_files)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
@@ -217,15 +195,11 @@ class Config:
|
||||
configs_dir = Config.configs_dir
|
||||
os.makedirs(configs_dir, exist_ok=True)
|
||||
if self.config_file is None:
|
||||
filename = os.path.join(
|
||||
configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
|
||||
)
|
||||
filename = os.path.join(configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_"))
|
||||
self.config_file = f"{filename}.json"
|
||||
with open(self.config_file, "wt") as f:
|
||||
json.dump(self.to_dict(), f, indent=4)
|
||||
print(
|
||||
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
|
||||
)
|
||||
print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}")
|
||||
|
||||
@staticmethod
|
||||
def is_valid_config_file(file: str):
|
||||
@@ -234,9 +208,7 @@ class Config:
|
||||
cfg.load_config(file)
|
||||
except Exception:
|
||||
return False
|
||||
return (
|
||||
cfg.memgpt_persona is not None and cfg.human_persona is not None
|
||||
) # TODO: more validation for configs
|
||||
return cfg.memgpt_persona is not None and cfg.human_persona is not None # TODO: more validation for configs
|
||||
|
||||
@staticmethod
|
||||
def get_memgpt_personas():
|
||||
@@ -331,8 +303,7 @@ class Config:
|
||||
files = [
|
||||
os.path.join(configs_dir, f)
|
||||
for f in os.listdir(configs_dir)
|
||||
if os.path.isfile(os.path.join(configs_dir, f))
|
||||
and Config.is_valid_config_file(os.path.join(configs_dir, f))
|
||||
if os.path.isfile(os.path.join(configs_dir, f)) and Config.is_valid_config_file(os.path.join(configs_dir, f))
|
||||
]
|
||||
# Return the file with the most recent modification time
|
||||
if len(files) == 0:
|
||||
|
||||
@@ -7,9 +7,7 @@ DEFAULT_MEMGPT_MODEL = "gpt-4"
|
||||
FIRST_MESSAGE_ATTEMPTS = 10
|
||||
|
||||
INITIAL_BOOT_MESSAGE = "Boot sequence complete. Persona activated."
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = (
|
||||
"Bootup sequence complete. Persona activated. Testing messaging functionality."
|
||||
)
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT = "Bootup sequence complete. Persona activated. Testing messaging functionality."
|
||||
STARTUP_QUOTES = [
|
||||
"I think, therefore I am.",
|
||||
"All those moments will be lost in time, like tears in rain.",
|
||||
@@ -28,9 +26,7 @@ CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000
|
||||
MAX_PAUSE_HEARTBEATS = 360 # in min
|
||||
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo"
|
||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = (
|
||||
"You are a helpful assistant. Keep your responses short and concise."
|
||||
)
|
||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep your responses short and concise."
|
||||
|
||||
#### Functions related
|
||||
|
||||
|
||||
@@ -29,15 +29,11 @@ async def assistant_message(msg):
|
||||
|
||||
|
||||
async def memory_message(msg):
|
||||
print(
|
||||
f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
print(f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}")
|
||||
|
||||
|
||||
async def system_message(msg):
|
||||
printd(
|
||||
f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}")
|
||||
|
||||
|
||||
async def user_message(msg, raw=False):
|
||||
@@ -50,9 +46,7 @@ async def user_message(msg, raw=False):
|
||||
msg_json = json.loads(msg)
|
||||
except:
|
||||
printd(f"Warning: failed to parse user message into json")
|
||||
printd(
|
||||
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}")
|
||||
return
|
||||
|
||||
if msg_json["type"] == "user_message":
|
||||
@@ -61,9 +55,7 @@ async def user_message(msg, raw=False):
|
||||
elif msg_json["type"] == "heartbeat":
|
||||
if DEBUG:
|
||||
msg_json.pop("type")
|
||||
printd(
|
||||
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
|
||||
elif msg_json["type"] == "system_message":
|
||||
msg_json.pop("type")
|
||||
printd(f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
|
||||
@@ -77,33 +69,23 @@ async def function_message(msg):
|
||||
return
|
||||
|
||||
if msg.startswith("Success: "):
|
||||
printd(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}")
|
||||
elif msg.startswith("Error: "):
|
||||
printd(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}")
|
||||
elif msg.startswith("Running "):
|
||||
if DEBUG:
|
||||
printd(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
|
||||
else:
|
||||
if "memory" in msg:
|
||||
match = re.search(r"Running (\w+)\((.*)\)", msg)
|
||||
if match:
|
||||
function_name = match.group(1)
|
||||
function_args = match.group(2)
|
||||
print(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:"
|
||||
)
|
||||
print(f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:")
|
||||
try:
|
||||
msg_dict = eval(function_args)
|
||||
if function_name == "archival_memory_search":
|
||||
print(
|
||||
f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
|
||||
)
|
||||
print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}')
|
||||
else:
|
||||
print(
|
||||
f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}'
|
||||
@@ -114,28 +96,20 @@ async def function_message(msg):
|
||||
pass
|
||||
else:
|
||||
printd(f"Warning: did not recognize function message")
|
||||
printd(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
|
||||
elif "send_message" in msg:
|
||||
# ignore in debug mode
|
||||
pass
|
||||
else:
|
||||
printd(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
|
||||
else:
|
||||
try:
|
||||
msg_dict = json.loads(msg)
|
||||
if "status" in msg_dict and msg_dict["status"] == "OK":
|
||||
printd(
|
||||
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}")
|
||||
except Exception:
|
||||
printd(f"Warning: did not recognize function message {type(msg)} {msg}")
|
||||
printd(
|
||||
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
|
||||
)
|
||||
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
|
||||
|
||||
|
||||
async def print_messages(message_sequence):
|
||||
|
||||
@@ -190,9 +190,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
if self.clean_func_args:
|
||||
function_name, function_parameters = self.clean_function_args(
|
||||
function_name, function_parameters
|
||||
)
|
||||
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
|
||||
|
||||
message = {
|
||||
"role": "assistant",
|
||||
@@ -275,9 +273,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
func_str += f"\n description: {schema['description']}"
|
||||
func_str += f"\n params:"
|
||||
if add_inner_thoughts:
|
||||
func_str += (
|
||||
f"\n inner_thoughts: Deep inner monologue private to you only."
|
||||
)
|
||||
func_str += f"\n inner_thoughts: Deep inner monologue private to you only."
|
||||
for param_k, param_v in schema["parameters"]["properties"].items():
|
||||
# TODO we're ignoring type
|
||||
func_str += f"\n {param_k}: {param_v['description']}"
|
||||
|
||||
@@ -152,9 +152,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
try:
|
||||
content_json = json.loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += (
|
||||
f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
|
||||
)
|
||||
prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
prompt += f"\n{IM_START_TOKEN}user\n{message['content']}{IM_END_TOKEN}"
|
||||
@@ -227,9 +225,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
if self.clean_func_args:
|
||||
function_name, function_parameters = self.clean_function_args(
|
||||
function_name, function_parameters
|
||||
)
|
||||
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
|
||||
|
||||
message = {
|
||||
"role": "assistant",
|
||||
|
||||
@@ -83,12 +83,8 @@ def load(memgpt_agent, filename):
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
else:
|
||||
# Load the latest file
|
||||
print(
|
||||
f"/load warning: no checkpoint specified, loading most recent checkpoint instead"
|
||||
)
|
||||
json_files = glob.glob(
|
||||
"saved_state/*.json"
|
||||
) # This will list all .json files in the current directory.
|
||||
print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
|
||||
json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
|
||||
|
||||
# Check if there are any json files.
|
||||
if not json_files:
|
||||
@@ -110,27 +106,17 @@ def load(memgpt_agent, filename):
|
||||
) # TODO(fixme):for different types of persistence managers that require different load/save methods
|
||||
print(f"Loaded persistence manager from {filename}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"/load warning: loading persistence manager from {filename} failed with: {e}"
|
||||
)
|
||||
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
persona: str = typer.Option(None, help="Specify persona"),
|
||||
human: str = typer.Option(None, help="Specify human"),
|
||||
model: str = typer.Option(
|
||||
constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"
|
||||
),
|
||||
first: bool = typer.Option(
|
||||
False, "--first", help="Use --first to send the first message in the sequence"
|
||||
),
|
||||
debug: bool = typer.Option(
|
||||
False, "--debug", help="Use --debug to enable debugging output"
|
||||
),
|
||||
no_verify: bool = typer.Option(
|
||||
False, "--no_verify", help="Bypass message verification"
|
||||
),
|
||||
model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"),
|
||||
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
|
||||
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
|
||||
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
|
||||
archival_storage_faiss_path: str = typer.Option(
|
||||
"",
|
||||
"--archival_storage_faiss_path",
|
||||
@@ -200,9 +186,7 @@ async def main(
|
||||
else:
|
||||
azure_vars = get_set_azure_env_vars()
|
||||
if len(azure_vars) > 0:
|
||||
print(
|
||||
f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
|
||||
)
|
||||
print(f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False")
|
||||
return
|
||||
|
||||
if any(
|
||||
@@ -295,23 +279,17 @@ async def main(
|
||||
else:
|
||||
cfg = await Config.config_init()
|
||||
|
||||
memgpt.interface.important_message(
|
||||
"Running... [exit by typing '/exit', list available commands with '/help']"
|
||||
)
|
||||
memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']")
|
||||
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
|
||||
memgpt.interface.warning_message(
|
||||
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
|
||||
)
|
||||
|
||||
if cfg.index:
|
||||
persistence_manager = InMemoryStateManagerWithFaiss(
|
||||
cfg.index, cfg.archival_database
|
||||
)
|
||||
persistence_manager = InMemoryStateManagerWithFaiss(cfg.index, cfg.archival_database)
|
||||
elif cfg.archival_storage_files:
|
||||
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
|
||||
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
|
||||
cfg.archival_database
|
||||
)
|
||||
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(cfg.archival_database)
|
||||
else:
|
||||
persistence_manager = InMemoryStateManager()
|
||||
|
||||
@@ -355,9 +333,7 @@ async def main(
|
||||
print(f"Database loaded into archival memory.")
|
||||
|
||||
if cfg.agent_save_file:
|
||||
load_save_file = await questionary.confirm(
|
||||
f"Load in saved agent '{cfg.agent_save_file}'?"
|
||||
).ask_async()
|
||||
load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async()
|
||||
if load_save_file:
|
||||
load(memgpt_agent, cfg.agent_save_file)
|
||||
|
||||
@@ -366,9 +342,7 @@ async def main(
|
||||
return
|
||||
|
||||
if not USER_GOES_FIRST:
|
||||
console.input(
|
||||
"[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]"
|
||||
)
|
||||
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]")
|
||||
clear_line()
|
||||
print()
|
||||
|
||||
@@ -404,9 +378,7 @@ async def main(
|
||||
break
|
||||
|
||||
elif user_input.lower() == "/savechat":
|
||||
filename = (
|
||||
utils.get_local_time().replace(" ", "_").replace(":", "_")
|
||||
)
|
||||
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
|
||||
filename = f"{filename}.pkl"
|
||||
directory = os.path.join(MEMGPT_DIR, "saved_chats")
|
||||
try:
|
||||
@@ -423,9 +395,7 @@ async def main(
|
||||
save(memgpt_agent=memgpt_agent, cfg=cfg)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/load" or user_input.lower().startswith(
|
||||
"/load "
|
||||
):
|
||||
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
|
||||
command = user_input.strip().split()
|
||||
filename = command[1] if len(command) > 1 else None
|
||||
load(memgpt_agent=memgpt_agent, filename=filename)
|
||||
@@ -458,16 +428,10 @@ async def main(
|
||||
print(f"Updated model to:\n{str(memgpt_agent.model)}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/pop" or user_input.lower().startswith(
|
||||
"/pop "
|
||||
):
|
||||
elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
|
||||
# Check if there's an additional argument that's an integer
|
||||
command = user_input.strip().split()
|
||||
amount = (
|
||||
int(command[1])
|
||||
if len(command) > 1 and command[1].isdigit()
|
||||
else 2
|
||||
)
|
||||
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 2
|
||||
print(f"Popping last {amount} messages from stack")
|
||||
for _ in range(min(amount, len(memgpt_agent.messages))):
|
||||
memgpt_agent.messages.pop()
|
||||
@@ -512,18 +476,14 @@ async def main(
|
||||
heartbeat_request,
|
||||
function_failed,
|
||||
token_warning,
|
||||
) = await memgpt_agent.step(
|
||||
user_message, first_message=False, skip_verify=no_verify
|
||||
)
|
||||
) = await memgpt_agent.step(user_message, first_message=False, skip_verify=no_verify)
|
||||
|
||||
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
|
||||
if token_warning:
|
||||
user_message = system.get_token_limit_warning()
|
||||
skip_next_user_input = True
|
||||
elif function_failed:
|
||||
user_message = system.get_heartbeat(
|
||||
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
|
||||
)
|
||||
user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
|
||||
skip_next_user_input = True
|
||||
elif heartbeat_request:
|
||||
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
|
||||
|
||||
208
memgpt/memory.py
208
memgpt/memory.py
@@ -28,20 +28,17 @@ class CoreMemory(object):
|
||||
self.archival_memory_exists = archival_memory_exists
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return \
|
||||
f"\n### CORE MEMORY ###" + \
|
||||
f"\n=== Persona ===\n{self.persona}" + \
|
||||
f"\n\n=== Human ===\n{self.human}"
|
||||
return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'persona': self.persona,
|
||||
'human': self.human,
|
||||
"persona": self.persona,
|
||||
"human": self.human,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def load(cls, state):
|
||||
return cls(state['persona'], state['human'])
|
||||
return cls(state["persona"], state["human"])
|
||||
|
||||
def edit_persona(self, new_persona):
|
||||
if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
|
||||
@@ -64,53 +61,55 @@ class CoreMemory(object):
|
||||
return len(self.human)
|
||||
|
||||
def edit(self, field, content):
|
||||
if field == 'persona':
|
||||
if field == "persona":
|
||||
return self.edit_persona(content)
|
||||
elif field == 'human':
|
||||
elif field == "human":
|
||||
return self.edit_human(content)
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
def edit_append(self, field, content, sep='\n'):
|
||||
if field == 'persona':
|
||||
def edit_append(self, field, content, sep="\n"):
|
||||
if field == "persona":
|
||||
new_content = self.persona + sep + content
|
||||
return self.edit_persona(new_content)
|
||||
elif field == 'human':
|
||||
elif field == "human":
|
||||
new_content = self.human + sep + content
|
||||
return self.edit_human(new_content)
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
def edit_replace(self, field, old_content, new_content):
|
||||
if field == 'persona':
|
||||
if field == "persona":
|
||||
if old_content in self.persona:
|
||||
new_persona = self.persona.replace(old_content, new_content)
|
||||
return self.edit_persona(new_persona)
|
||||
else:
|
||||
raise ValueError('Content not found in persona (make sure to use exact string)')
|
||||
elif field == 'human':
|
||||
raise ValueError("Content not found in persona (make sure to use exact string)")
|
||||
elif field == "human":
|
||||
if old_content in self.human:
|
||||
new_human = self.human.replace(old_content, new_content)
|
||||
return self.edit_human(new_human)
|
||||
else:
|
||||
raise ValueError('Content not found in human (make sure to use exact string)')
|
||||
raise ValueError("Content not found in human (make sure to use exact string)")
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
|
||||
async def summarize_messages(
|
||||
model,
|
||||
message_sequence_to_summarize,
|
||||
):
|
||||
model,
|
||||
message_sequence_to_summarize,
|
||||
):
|
||||
"""Summarize a message sequence using GPT"""
|
||||
|
||||
summary_prompt = SUMMARY_PROMPT_SYSTEM
|
||||
summary_input = str(message_sequence_to_summarize)
|
||||
summary_input_tkns = count_tokens(summary_input, model)
|
||||
if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS:
|
||||
trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
|
||||
trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
|
||||
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
|
||||
summary_input = str([await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:])
|
||||
summary_input = str(
|
||||
[await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]
|
||||
)
|
||||
message_sequence = [
|
||||
{"role": "system", "content": summary_prompt},
|
||||
{"role": "user", "content": summary_input},
|
||||
@@ -127,7 +126,6 @@ async def summarize_messages(
|
||||
|
||||
|
||||
class ArchivalMemory(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, memory_string):
|
||||
pass
|
||||
@@ -150,7 +148,7 @@ class DummyArchivalMemory(ArchivalMemory):
|
||||
"""
|
||||
|
||||
def __init__(self, archival_memory_database=None):
|
||||
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
|
||||
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
|
||||
|
||||
def __len__(self):
|
||||
return len(self._archive)
|
||||
@@ -159,31 +157,33 @@ class DummyArchivalMemory(ArchivalMemory):
|
||||
if len(self._archive) == 0:
|
||||
memory_str = "<empty>"
|
||||
else:
|
||||
memory_str = "\n".join([d['content'] for d in self._archive])
|
||||
return \
|
||||
f"\n### ARCHIVAL MEMORY ###" + \
|
||||
f"\n{memory_str}"
|
||||
memory_str = "\n".join([d["content"] for d in self._archive])
|
||||
return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"
|
||||
|
||||
async def insert(self, memory_string, embedding=None):
|
||||
if embedding is not None:
|
||||
raise ValueError('Basic text-based archival memory does not support embeddings')
|
||||
self._archive.append({
|
||||
# can eventually upgrade to adding semantic tags, etc
|
||||
'timestamp': get_local_time(),
|
||||
'content': memory_string,
|
||||
})
|
||||
raise ValueError("Basic text-based archival memory does not support embeddings")
|
||||
self._archive.append(
|
||||
{
|
||||
# can eventually upgrade to adding semantic tags, etc
|
||||
"timestamp": get_local_time(),
|
||||
"content": memory_string,
|
||||
}
|
||||
)
|
||||
|
||||
async def search(self, query_string, count=None, start=None):
|
||||
"""Simple text-based search"""
|
||||
# in the dummy version, run an (inefficient) case-insensitive match search
|
||||
# printd(f"query_string: {query_string}")
|
||||
matches = [s for s in self._archive if query_string.lower() in s['content'].lower()]
|
||||
matches = [s for s in self._archive if query_string.lower() in s["content"].lower()]
|
||||
# printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[str(d['content']) d in matches[:5]]}")
|
||||
printd(f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}")
|
||||
printd(
|
||||
f"archive_memory.search (text-based): search for query '{query_string}' returned the following results (limit 5):\n{[matches[start:count]]}"
|
||||
)
|
||||
|
||||
# start/count support paging through results
|
||||
if start is not None and count is not None:
|
||||
return matches[start:start+count], len(matches)
|
||||
return matches[start : start + count], len(matches)
|
||||
elif start is None and count is not None:
|
||||
return matches[:count], len(matches)
|
||||
elif start is not None and count is None:
|
||||
@@ -195,8 +195,8 @@ class DummyArchivalMemory(ArchivalMemory):
|
||||
class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
|
||||
"""Same as dummy in-memory archival memory, but with bare-bones embedding support"""
|
||||
|
||||
def __init__(self, archival_memory_database=None, embedding_model='text-embedding-ada-002'):
|
||||
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
|
||||
def __init__(self, archival_memory_database=None, embedding_model="text-embedding-ada-002"):
|
||||
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
def __len__(self):
|
||||
@@ -206,15 +206,17 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
|
||||
# Get the embedding
|
||||
if embedding is None:
|
||||
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
|
||||
embedding_meta = {'model': self.embedding_model}
|
||||
embedding_meta = {"model": self.embedding_model}
|
||||
printd(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
|
||||
|
||||
self._archive.append({
|
||||
'timestamp': get_local_time(),
|
||||
'content': memory_string,
|
||||
'embedding': embedding,
|
||||
'embedding_metadata': embedding_meta,
|
||||
})
|
||||
self._archive.append(
|
||||
{
|
||||
"timestamp": get_local_time(),
|
||||
"content": memory_string,
|
||||
"embedding": embedding,
|
||||
"embedding_metadata": embedding_meta,
|
||||
}
|
||||
)
|
||||
|
||||
async def search(self, query_string, count=None, start=None):
|
||||
"""Simple embedding-based search (inefficient, no caching)"""
|
||||
@@ -223,22 +225,24 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
|
||||
# query_embedding = get_embedding(query_string, model=self.embedding_model)
|
||||
# our wrapped version supports backoff/rate-limits
|
||||
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
|
||||
similarity_scores = [cosine_similarity(memory['embedding'], query_embedding) for memory in self._archive]
|
||||
similarity_scores = [cosine_similarity(memory["embedding"], query_embedding) for memory in self._archive]
|
||||
|
||||
# Sort the archive based on similarity scores
|
||||
sorted_archive_with_scores = sorted(
|
||||
zip(self._archive, similarity_scores),
|
||||
key=lambda pair: pair[1], # Sort by the similarity score
|
||||
reverse=True # We want the highest similarity first
|
||||
reverse=True, # We want the highest similarity first
|
||||
)
|
||||
printd(
|
||||
f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
|
||||
)
|
||||
printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}")
|
||||
|
||||
# Extract the sorted archive without the scores
|
||||
matches = [item[0] for item in sorted_archive_with_scores]
|
||||
|
||||
# start/count support paging through results
|
||||
if start is not None and count is not None:
|
||||
return matches[start:start+count], len(matches)
|
||||
return matches[start : start + count], len(matches)
|
||||
elif start is None and count is not None:
|
||||
return matches[:count], len(matches)
|
||||
elif start is not None and count is None:
|
||||
@@ -259,13 +263,13 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
is essential enough not to be left only to the recall memory.
|
||||
"""
|
||||
|
||||
def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100):
|
||||
def __init__(self, index=None, archival_memory_database=None, embedding_model="text-embedding-ada-002", k=100):
|
||||
if index is None:
|
||||
self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
|
||||
self.index = faiss.IndexFlatL2(1536) # openai embedding vector size.
|
||||
else:
|
||||
self.index = index
|
||||
self.k = k
|
||||
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
|
||||
self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts
|
||||
self.embedding_model = embedding_model
|
||||
self.embeddings_dict = {}
|
||||
self.search_results = {}
|
||||
@@ -279,12 +283,14 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
|
||||
print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")
|
||||
|
||||
self._archive.append({
|
||||
# can eventually upgrade to adding semantic tags, etc
|
||||
'timestamp': get_local_time(),
|
||||
'content': memory_string,
|
||||
})
|
||||
embedding = np.array([embedding]).astype('float32')
|
||||
self._archive.append(
|
||||
{
|
||||
# can eventually upgrade to adding semantic tags, etc
|
||||
"timestamp": get_local_time(),
|
||||
"content": memory_string,
|
||||
}
|
||||
)
|
||||
embedding = np.array([embedding]).astype("float32")
|
||||
self.index.add(embedding)
|
||||
|
||||
async def search(self, query_string, count=None, start=None):
|
||||
@@ -304,20 +310,22 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
self.search_results[query_string] = search_result
|
||||
|
||||
if start is not None and count is not None:
|
||||
toprint = search_result[start:start+count]
|
||||
toprint = search_result[start : start + count]
|
||||
else:
|
||||
if len(search_result) >= 5:
|
||||
toprint = search_result[:5]
|
||||
else:
|
||||
toprint = search_result
|
||||
printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}")
|
||||
printd(
|
||||
f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}"
|
||||
)
|
||||
|
||||
# Extract the sorted archive without the scores
|
||||
matches = search_result
|
||||
|
||||
# start/count support paging through results
|
||||
if start is not None and count is not None:
|
||||
return matches[start:start+count], len(matches)
|
||||
return matches[start : start + count], len(matches)
|
||||
elif start is None and count is not None:
|
||||
return matches[:count], len(matches)
|
||||
elif start is not None and count is None:
|
||||
@@ -327,7 +335,6 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
|
||||
|
||||
class RecallMemory(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
pass
|
||||
@@ -365,42 +372,46 @@ class DummyRecallMemory(RecallMemory):
|
||||
# don't dump all the conversations, just statistics
|
||||
system_count = user_count = assistant_count = function_count = other_count = 0
|
||||
for msg in self._message_logs:
|
||||
role = msg['message']['role']
|
||||
if role == 'system':
|
||||
role = msg["message"]["role"]
|
||||
if role == "system":
|
||||
system_count += 1
|
||||
elif role == 'user':
|
||||
elif role == "user":
|
||||
user_count += 1
|
||||
elif role == 'assistant':
|
||||
elif role == "assistant":
|
||||
assistant_count += 1
|
||||
elif role == 'function':
|
||||
elif role == "function":
|
||||
function_count += 1
|
||||
else:
|
||||
other_count += 1
|
||||
memory_str = f"Statistics:" + \
|
||||
f"\n{len(self._message_logs)} total messages" + \
|
||||
f"\n{system_count} system" + \
|
||||
f"\n{user_count} user" + \
|
||||
f"\n{assistant_count} assistant" + \
|
||||
f"\n{function_count} function" + \
|
||||
f"\n{other_count} other"
|
||||
return \
|
||||
f"\n### RECALL MEMORY ###" + \
|
||||
f"\n{memory_str}"
|
||||
memory_str = (
|
||||
f"Statistics:"
|
||||
+ f"\n{len(self._message_logs)} total messages"
|
||||
+ f"\n{system_count} system"
|
||||
+ f"\n{user_count} user"
|
||||
+ f"\n{assistant_count} assistant"
|
||||
+ f"\n{function_count} function"
|
||||
+ f"\n{other_count} other"
|
||||
)
|
||||
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
|
||||
|
||||
async def insert(self, message):
|
||||
raise NotImplementedError('This should be handled by the PersistenceManager, recall memory is just a search layer on top')
|
||||
raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")
|
||||
|
||||
async def text_search(self, query_string, count=None, start=None):
|
||||
# in the dummy version, run an (inefficient) case-insensitive match search
|
||||
message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
|
||||
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
|
||||
|
||||
printd(f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages")
|
||||
matches = [d for d in message_pool if d['message']['content'] is not None and query_string.lower() in d['message']['content'].lower()]
|
||||
printd(
|
||||
f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages"
|
||||
)
|
||||
matches = [
|
||||
d for d in message_pool if d["message"]["content"] is not None and query_string.lower() in d["message"]["content"].lower()
|
||||
]
|
||||
printd(f"recall_memory - matches:\n{matches[start:start+count]}")
|
||||
|
||||
# start/count support paging through results
|
||||
if start is not None and count is not None:
|
||||
return matches[start:start+count], len(matches)
|
||||
return matches[start : start + count], len(matches)
|
||||
elif start is None and count is not None:
|
||||
return matches[:count], len(matches)
|
||||
elif start is not None and count is None:
|
||||
@@ -411,7 +422,7 @@ class DummyRecallMemory(RecallMemory):
|
||||
def _validate_date_format(self, date_str):
|
||||
"""Validate the given date string in the format 'YYYY-MM-DD'."""
|
||||
try:
|
||||
datetime.datetime.strptime(date_str, '%Y-%m-%d')
|
||||
datetime.datetime.strptime(date_str, "%Y-%m-%d")
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
@@ -423,25 +434,26 @@ class DummyRecallMemory(RecallMemory):
|
||||
return match.group(1) if match else None
|
||||
|
||||
async def date_search(self, start_date, end_date, count=None, start=None):
|
||||
message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
|
||||
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
|
||||
|
||||
# First, validate the start_date and end_date format
|
||||
if not self._validate_date_format(start_date) or not self._validate_date_format(end_date):
|
||||
raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
|
||||
|
||||
# Convert dates to datetime objects for comparison
|
||||
start_date_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
|
||||
end_date_dt = datetime.datetime.strptime(end_date, '%Y-%m-%d')
|
||||
start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
# Next, match items inside self._message_logs
|
||||
matches = [
|
||||
d for d in message_pool
|
||||
if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d['timestamp']), '%Y-%m-%d') <= end_date_dt
|
||||
d
|
||||
for d in message_pool
|
||||
if start_date_dt <= datetime.datetime.strptime(self._extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
|
||||
]
|
||||
|
||||
# start/count support paging through results
|
||||
if start is not None and count is not None:
|
||||
return matches[start:start+count], len(matches)
|
||||
return matches[start : start + count], len(matches)
|
||||
elif start is None and count is not None:
|
||||
return matches[:count], len(matches)
|
||||
elif start is not None and count is None:
|
||||
@@ -456,17 +468,17 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.embeddings = dict()
|
||||
self.embedding_model = 'text-embedding-ada-002'
|
||||
self.embedding_model = "text-embedding-ada-002"
|
||||
self.only_use_preloaded_embeddings = False
|
||||
|
||||
async def text_search(self, query_string, count=None, start=None):
|
||||
# in the dummy version, run an (inefficient) case-insensitive match search
|
||||
message_pool = [d for d in self._message_logs if d['message']['role'] not in ['system', 'function']]
|
||||
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
|
||||
|
||||
# first, go through and make sure we have all the embeddings we need
|
||||
message_pool_filtered = []
|
||||
for d in message_pool:
|
||||
message_str = d['message']['content']
|
||||
message_str = d["message"]["content"]
|
||||
if self.only_use_preloaded_embeddings:
|
||||
if message_str not in self.embeddings:
|
||||
printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, skipping.")
|
||||
@@ -477,24 +489,26 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
|
||||
self.embeddings[message_str] = await async_get_embedding_with_backoff(message_str, model=self.embedding_model)
|
||||
message_pool_filtered.append(d)
|
||||
|
||||
# our wrapped version supports backoff/rate-limits
|
||||
# our wrapped version supports backoff/rate-limits
|
||||
query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
|
||||
similarity_scores = [cosine_similarity(self.embeddings[d['message']['content']], query_embedding) for d in message_pool_filtered]
|
||||
similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered]
|
||||
|
||||
# Sort the archive based on similarity scores
|
||||
sorted_archive_with_scores = sorted(
|
||||
zip(message_pool_filtered, similarity_scores),
|
||||
key=lambda pair: pair[1], # Sort by the similarity score
|
||||
reverse=True # We want the highest similarity first
|
||||
reverse=True, # We want the highest similarity first
|
||||
)
|
||||
printd(
|
||||
f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}"
|
||||
)
|
||||
printd(f"recall_memory.text_search (vector-based): search for query '{query_string}' returned the following results (limit 5) and scores:\n{str([str(t[0]['message']['content']) + '- score ' + str(t[1]) for t in sorted_archive_with_scores[:5]])}")
|
||||
|
||||
# Extract the sorted archive without the scores
|
||||
matches = [item[0] for item in sorted_archive_with_scores]
|
||||
|
||||
# start/count support paging through results
|
||||
if start is not None and count is not None:
|
||||
return matches[start:start+count], len(matches)
|
||||
return matches[start : start + count], len(matches)
|
||||
elif start is None and count is not None:
|
||||
return matches[:count], len(matches)
|
||||
elif start is not None and count is None:
|
||||
|
||||
@@ -41,9 +41,7 @@ def retry_with_exponential_backoff(
|
||||
|
||||
# Check if max retries has been reached
|
||||
if num_retries > max_retries:
|
||||
raise Exception(
|
||||
f"Maximum number of retries ({max_retries}) exceeded."
|
||||
)
|
||||
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
|
||||
|
||||
# Increment the delay
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
@@ -91,9 +89,7 @@ def aretry_with_exponential_backoff(
|
||||
|
||||
# Check if max retries has been reached
|
||||
if num_retries > max_retries:
|
||||
raise Exception(
|
||||
f"Maximum number of retries ({max_retries}) exceeded."
|
||||
)
|
||||
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
|
||||
|
||||
# Increment the delay
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
@@ -184,9 +180,7 @@ def configure_azure_support():
|
||||
azure_openai_endpoint,
|
||||
azure_openai_version,
|
||||
]:
|
||||
print(
|
||||
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
|
||||
)
|
||||
print(f"Error: missing Azure OpenAI environment variables. Please see README section on Azure.")
|
||||
return
|
||||
|
||||
openai.api_type = "azure"
|
||||
@@ -199,10 +193,7 @@ def configure_azure_support():
|
||||
def check_azure_embeddings():
|
||||
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
azure_openai_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
|
||||
if (
|
||||
azure_openai_deployment is not None
|
||||
and azure_openai_embedding_deployment is None
|
||||
):
|
||||
if azure_openai_deployment is not None and azure_openai_embedding_deployment is None:
|
||||
raise ValueError(
|
||||
f"Error: It looks like you are using Azure deployment ids and computing embeddings, make sure you are setting one for embeddings as well. Please see README section on Azure"
|
||||
)
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import pickle
|
||||
|
||||
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
|
||||
from .memory import (
|
||||
DummyRecallMemory,
|
||||
DummyRecallMemoryWithEmbeddings,
|
||||
DummyArchivalMemory,
|
||||
DummyArchivalMemoryWithEmbeddings,
|
||||
DummyArchivalMemoryWithFaiss,
|
||||
)
|
||||
from .utils import get_local_time, printd
|
||||
|
||||
|
||||
class PersistenceManager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def trim_messages(self, num):
|
||||
pass
|
||||
@@ -42,17 +47,17 @@ class InMemoryStateManager(PersistenceManager):
|
||||
|
||||
@staticmethod
|
||||
def load(filename):
|
||||
with open(filename, 'rb') as f:
|
||||
with open(filename, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, 'wb') as fh:
|
||||
with open(filename, "wb") as fh:
|
||||
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def init(self, agent):
|
||||
printd(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.memory = agent.memory
|
||||
printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
|
||||
printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")
|
||||
@@ -68,7 +73,7 @@ class InMemoryStateManager(PersistenceManager):
|
||||
|
||||
def prepend_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
|
||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"InMemoryStateManager.prepend_to_message")
|
||||
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
||||
@@ -76,7 +81,7 @@ class InMemoryStateManager(PersistenceManager):
|
||||
|
||||
def append_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
added_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in added_messages]
|
||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"InMemoryStateManager.append_to_messages")
|
||||
self.messages = self.messages + added_messages
|
||||
@@ -84,7 +89,7 @@ class InMemoryStateManager(PersistenceManager):
|
||||
|
||||
def swap_system_message(self, new_system_message):
|
||||
# first tag with timestamps
|
||||
new_system_message = {'timestamp': get_local_time(), 'message': new_system_message}
|
||||
new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
|
||||
|
||||
printd(f"InMemoryStateManager.swap_system_message")
|
||||
self.messages[0] = new_system_message
|
||||
@@ -104,8 +109,8 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
|
||||
|
||||
def init(self, agent):
|
||||
print(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.memory = agent.memory
|
||||
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
|
||||
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
|
||||
@@ -133,12 +138,14 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager):
|
||||
|
||||
def init(self, agent):
|
||||
print(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.memory = agent.memory
|
||||
print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
|
||||
print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
|
||||
|
||||
# Persistence manager also handles DB-related state
|
||||
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
||||
self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k)
|
||||
self.archival_memory = self.archival_memory_cls(
|
||||
index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k
|
||||
)
|
||||
|
||||
@@ -5,15 +5,14 @@ import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
|
||||
def build_index(embedding_files: str,
|
||||
index_name: str):
|
||||
|
||||
def build_index(embedding_files: str, index_name: str):
|
||||
index = faiss.IndexFlatL2(1536)
|
||||
file_list = sorted(glob(embedding_files))
|
||||
|
||||
for embedding_file in file_list:
|
||||
print(embedding_file)
|
||||
with open(embedding_file, 'rt', encoding='utf-8') as file:
|
||||
with open(embedding_file, "rt", encoding="utf-8") as file:
|
||||
embeddings = []
|
||||
l = 0
|
||||
for line in tqdm(file):
|
||||
@@ -21,7 +20,7 @@ def build_index(embedding_files: str,
|
||||
data = json.loads(line)
|
||||
embeddings.append(data)
|
||||
l += 1
|
||||
data = np.array(embeddings).astype('float32')
|
||||
data = np.array(embeddings).astype("float32")
|
||||
print(data.shape)
|
||||
try:
|
||||
index.add(data)
|
||||
@@ -32,14 +31,11 @@ def build_index(embedding_files: str,
|
||||
faiss.write_index(index, index_name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--embedding_files', type=str, help='embedding_filepaths glob expression')
|
||||
parser.add_argument('--output_index_file', type=str, help='output filepath')
|
||||
args = parser.parse_args()
|
||||
parser.add_argument("--embedding_files", type=str, help="embedding_filepaths glob expression")
|
||||
parser.add_argument("--output_index_file", type=str, help="output filepath")
|
||||
args = parser.parse_args()
|
||||
|
||||
build_index(
|
||||
embedding_files=args.embedding_files,
|
||||
index_name=args.output_index_file
|
||||
)
|
||||
build_index(embedding_files=args.embedding_files, index_name=args.output_index_file)
|
||||
|
||||
@@ -7,12 +7,14 @@ import argparse
|
||||
|
||||
from tqdm import tqdm
|
||||
import openai
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
openai.api_key = os.getenv('OPENAI_API_KEY')
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
sys.path.append("../../../")
|
||||
from openai_tools import async_get_embedding_with_backoff
|
||||
@@ -24,8 +26,8 @@ from openai_parallel_request_processor import process_api_requests_from_file
|
||||
TPM_LIMIT = 1000000
|
||||
RPM_LIMIT = 3000
|
||||
|
||||
DEFAULT_FILE = 'iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz'
|
||||
EMBEDDING_MODEL = 'text-embedding-ada-002'
|
||||
DEFAULT_FILE = "iclr/data/qa_data/30_total_documents/nq-open-30_total_documents_gold_at_0.jsonl.gz"
|
||||
EMBEDDING_MODEL = "text-embedding-ada-002"
|
||||
|
||||
|
||||
async def generate_requests_file(filename):
|
||||
@@ -33,36 +35,33 @@ async def generate_requests_file(filename):
|
||||
base_name = os.path.splitext(filename)[0]
|
||||
requests_filename = f"{base_name}_embedding_requests.jsonl"
|
||||
|
||||
with open(filename, 'r') as f:
|
||||
with open(filename, "r") as f:
|
||||
all_data = [json.loads(line) for line in f]
|
||||
|
||||
with open(requests_filename, 'w') as f:
|
||||
with open(requests_filename, "w") as f:
|
||||
for data in all_data:
|
||||
documents = data
|
||||
for idx, doc in enumerate(documents):
|
||||
title = doc["title"]
|
||||
text = doc["text"]
|
||||
document_string = f"Document [{idx+1}] (Title: {title}) {text}"
|
||||
request = {
|
||||
"model": EMBEDDING_MODEL,
|
||||
"input": document_string
|
||||
}
|
||||
request = {"model": EMBEDDING_MODEL, "input": document_string}
|
||||
json_string = json.dumps(request)
|
||||
f.write(json_string + "\n")
|
||||
|
||||
# Run your parallel processing function
|
||||
input(f"Generated requests file ({requests_filename}), continue with embedding batch requests? (hit enter)")
|
||||
await process_api_requests_from_file(
|
||||
requests_filepath=requests_filename,
|
||||
save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary
|
||||
request_url="https://api.openai.com/v1/embeddings",
|
||||
api_key=os.getenv('OPENAI_API_KEY'),
|
||||
max_requests_per_minute=RPM_LIMIT,
|
||||
max_tokens_per_minute=TPM_LIMIT,
|
||||
token_encoding_name=EMBEDDING_MODEL,
|
||||
max_attempts=5,
|
||||
logging_level=logging.INFO,
|
||||
)
|
||||
requests_filepath=requests_filename,
|
||||
save_filepath=f"{base_name}.embeddings.jsonl.gz", # Adjust as necessary
|
||||
request_url="https://api.openai.com/v1/embeddings",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
max_requests_per_minute=RPM_LIMIT,
|
||||
max_tokens_per_minute=TPM_LIMIT,
|
||||
token_encoding_name=EMBEDDING_MODEL,
|
||||
max_attempts=5,
|
||||
logging_level=logging.INFO,
|
||||
)
|
||||
|
||||
|
||||
async def generate_embedding_file(filename, parallel_mode=False):
|
||||
@@ -72,7 +71,7 @@ async def generate_embedding_file(filename, parallel_mode=False):
|
||||
|
||||
# Derive the sister filename
|
||||
# base_name = os.path.splitext(filename)[0]
|
||||
base_name = filename.rsplit('.jsonl', 1)[0]
|
||||
base_name = filename.rsplit(".jsonl", 1)[0]
|
||||
sister_filename = f"{base_name}.embeddings.jsonl"
|
||||
|
||||
# Check if the sister file already exists
|
||||
@@ -80,7 +79,7 @@ async def generate_embedding_file(filename, parallel_mode=False):
|
||||
print(f"{sister_filename} already exists. Skipping embedding generation.")
|
||||
return
|
||||
|
||||
with open(filename, 'rt') as f:
|
||||
with open(filename, "rt") as f:
|
||||
all_data = [json.loads(line) for line in f]
|
||||
|
||||
embedding_data = []
|
||||
@@ -90,7 +89,9 @@ async def generate_embedding_file(filename, parallel_mode=False):
|
||||
for i, data in enumerate(tqdm(all_data, desc="Processing data", total=len(all_data))):
|
||||
documents = data
|
||||
# Inner loop progress bar
|
||||
for idx, doc in enumerate(tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)):
|
||||
for idx, doc in enumerate(
|
||||
tqdm(documents, desc=f"Embedding documents for data {i+1}/{len(all_data)}", total=len(documents), leave=False)
|
||||
):
|
||||
title = doc["title"]
|
||||
text = doc["text"]
|
||||
document_string = f"[Title: {title}] {text}"
|
||||
@@ -103,10 +104,10 @@ async def generate_embedding_file(filename, parallel_mode=False):
|
||||
|
||||
# Save the embeddings to the sister file
|
||||
# with gzip.open(sister_filename, 'wt') as f:
|
||||
with open(sister_filename, 'wb') as f:
|
||||
with open(sister_filename, "wb") as f:
|
||||
for embedding in embedding_data:
|
||||
# f.write(json.dumps(embedding) + '\n')
|
||||
f.write((json.dumps(embedding) + '\n').encode('utf-8'))
|
||||
f.write((json.dumps(embedding) + "\n").encode("utf-8"))
|
||||
|
||||
print(f"Embeddings saved to {sister_filename}")
|
||||
|
||||
@@ -118,6 +119,7 @@ async def main():
|
||||
filename = DEFAULT_FILE
|
||||
await generate_embedding_file(filename)
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("filename", nargs="?", default=DEFAULT_FILE, help="Path to the input file")
|
||||
@@ -129,4 +131,4 @@ async def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
loop.run_until_complete(main())
|
||||
|
||||
@@ -121,9 +121,7 @@ async def process_api_requests_from_file(
|
||||
"""Processes API requests in parallel, throttling to stay under rate limits."""
|
||||
# constants
|
||||
seconds_to_pause_after_rate_limit_error = 15
|
||||
seconds_to_sleep_each_loop = (
|
||||
0.001 # 1 ms limits max throughput to 1,000 requests per second
|
||||
)
|
||||
seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second
|
||||
|
||||
# initialize logging
|
||||
logging.basicConfig(level=logging_level)
|
||||
@@ -135,12 +133,8 @@ async def process_api_requests_from_file(
|
||||
|
||||
# initialize trackers
|
||||
queue_of_requests_to_retry = asyncio.Queue()
|
||||
task_id_generator = (
|
||||
task_id_generator_function()
|
||||
) # generates integer IDs of 1, 2, 3, ...
|
||||
status_tracker = (
|
||||
StatusTracker()
|
||||
) # single instance to track a collection of variables
|
||||
task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ...
|
||||
status_tracker = StatusTracker() # single instance to track a collection of variables
|
||||
next_request = None # variable to hold the next request to call
|
||||
|
||||
# initialize available capacity counts
|
||||
@@ -163,9 +157,7 @@ async def process_api_requests_from_file(
|
||||
if next_request is None:
|
||||
if not queue_of_requests_to_retry.empty():
|
||||
next_request = queue_of_requests_to_retry.get_nowait()
|
||||
logging.debug(
|
||||
f"Retrying request {next_request.task_id}: {next_request}"
|
||||
)
|
||||
logging.debug(f"Retrying request {next_request.task_id}: {next_request}")
|
||||
elif file_not_finished:
|
||||
try:
|
||||
# get new request
|
||||
@@ -173,17 +165,13 @@ async def process_api_requests_from_file(
|
||||
next_request = APIRequest(
|
||||
task_id=next(task_id_generator),
|
||||
request_json=request_json,
|
||||
token_consumption=num_tokens_consumed_from_request(
|
||||
request_json, api_endpoint, token_encoding_name
|
||||
),
|
||||
token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name),
|
||||
attempts_left=max_attempts,
|
||||
metadata=request_json.pop("metadata", None),
|
||||
)
|
||||
status_tracker.num_tasks_started += 1
|
||||
status_tracker.num_tasks_in_progress += 1
|
||||
logging.debug(
|
||||
f"Reading request {next_request.task_id}: {next_request}"
|
||||
)
|
||||
logging.debug(f"Reading request {next_request.task_id}: {next_request}")
|
||||
except StopIteration:
|
||||
# if file runs out, set flag to stop reading it
|
||||
logging.debug("Read file exhausted")
|
||||
@@ -193,13 +181,11 @@ async def process_api_requests_from_file(
|
||||
current_time = time.time()
|
||||
seconds_since_update = current_time - last_update_time
|
||||
available_request_capacity = min(
|
||||
available_request_capacity
|
||||
+ max_requests_per_minute * seconds_since_update / 60.0,
|
||||
available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0,
|
||||
max_requests_per_minute,
|
||||
)
|
||||
available_token_capacity = min(
|
||||
available_token_capacity
|
||||
+ max_tokens_per_minute * seconds_since_update / 60.0,
|
||||
available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0,
|
||||
max_tokens_per_minute,
|
||||
)
|
||||
last_update_time = current_time
|
||||
@@ -207,10 +193,7 @@ async def process_api_requests_from_file(
|
||||
# if enough capacity available, call API
|
||||
if next_request:
|
||||
next_request_tokens = next_request.token_consumption
|
||||
if (
|
||||
available_request_capacity >= 1
|
||||
and available_token_capacity >= next_request_tokens
|
||||
):
|
||||
if available_request_capacity >= 1 and available_token_capacity >= next_request_tokens:
|
||||
# update counters
|
||||
available_request_capacity -= 1
|
||||
available_token_capacity -= next_request_tokens
|
||||
@@ -237,17 +220,9 @@ async def process_api_requests_from_file(
|
||||
await asyncio.sleep(seconds_to_sleep_each_loop)
|
||||
|
||||
# if a rate limit error was hit recently, pause to cool down
|
||||
seconds_since_rate_limit_error = (
|
||||
time.time() - status_tracker.time_of_last_rate_limit_error
|
||||
)
|
||||
if (
|
||||
seconds_since_rate_limit_error
|
||||
< seconds_to_pause_after_rate_limit_error
|
||||
):
|
||||
remaining_seconds_to_pause = (
|
||||
seconds_to_pause_after_rate_limit_error
|
||||
- seconds_since_rate_limit_error
|
||||
)
|
||||
seconds_since_rate_limit_error = time.time() - status_tracker.time_of_last_rate_limit_error
|
||||
if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
|
||||
remaining_seconds_to_pause = seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
|
||||
await asyncio.sleep(remaining_seconds_to_pause)
|
||||
# ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
|
||||
logging.warn(
|
||||
@@ -255,17 +230,13 @@ async def process_api_requests_from_file(
|
||||
)
|
||||
|
||||
# after finishing, log final status
|
||||
logging.info(
|
||||
f"""Parallel processing complete. Results saved to {save_filepath}"""
|
||||
)
|
||||
logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""")
|
||||
if status_tracker.num_tasks_failed > 0:
|
||||
logging.warning(
|
||||
f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}."
|
||||
)
|
||||
if status_tracker.num_rate_limit_errors > 0:
|
||||
logging.warning(
|
||||
f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
|
||||
)
|
||||
logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.")
|
||||
|
||||
|
||||
# dataclasses
|
||||
@@ -309,26 +280,18 @@ class APIRequest:
|
||||
logging.info(f"Starting request #{self.task_id}")
|
||||
error = None
|
||||
try:
|
||||
async with session.post(
|
||||
url=request_url, headers=request_header, json=self.request_json
|
||||
) as response:
|
||||
async with session.post(url=request_url, headers=request_header, json=self.request_json) as response:
|
||||
response = await response.json()
|
||||
if "error" in response:
|
||||
logging.warning(
|
||||
f"Request {self.task_id} failed with error {response['error']}"
|
||||
)
|
||||
logging.warning(f"Request {self.task_id} failed with error {response['error']}")
|
||||
status_tracker.num_api_errors += 1
|
||||
error = response
|
||||
if "Rate limit" in response["error"].get("message", ""):
|
||||
status_tracker.time_of_last_rate_limit_error = time.time()
|
||||
status_tracker.num_rate_limit_errors += 1
|
||||
status_tracker.num_api_errors -= (
|
||||
1 # rate limit errors are counted separately
|
||||
)
|
||||
status_tracker.num_api_errors -= 1 # rate limit errors are counted separately
|
||||
|
||||
except (
|
||||
Exception
|
||||
) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
|
||||
except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them
|
||||
logging.warning(f"Request {self.task_id} failed with Exception {e}")
|
||||
status_tracker.num_other_errors += 1
|
||||
error = e
|
||||
@@ -337,9 +300,7 @@ class APIRequest:
|
||||
if self.attempts_left:
|
||||
retry_queue.put_nowait(self)
|
||||
else:
|
||||
logging.error(
|
||||
f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}"
|
||||
)
|
||||
logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}")
|
||||
data = (
|
||||
[self.request_json, [str(e) for e in self.result], self.metadata]
|
||||
if self.metadata
|
||||
@@ -349,11 +310,7 @@ class APIRequest:
|
||||
status_tracker.num_tasks_in_progress -= 1
|
||||
status_tracker.num_tasks_failed += 1
|
||||
else:
|
||||
data = (
|
||||
[self.request_json, response, self.metadata]
|
||||
if self.metadata
|
||||
else [self.request_json, response]
|
||||
)
|
||||
data = [self.request_json, response, self.metadata] if self.metadata else [self.request_json, response]
|
||||
append_to_jsonl(data, save_filepath)
|
||||
status_tracker.num_tasks_in_progress -= 1
|
||||
status_tracker.num_tasks_succeeded += 1
|
||||
@@ -382,8 +339,8 @@ def num_tokens_consumed_from_request(
|
||||
token_encoding_name: str,
|
||||
):
|
||||
"""Count the number of tokens in the request. Only supports completion and embedding requests."""
|
||||
if token_encoding_name == 'text-embedding-ada-002':
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
if token_encoding_name == "text-embedding-ada-002":
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
else:
|
||||
encoding = tiktoken.get_encoding(token_encoding_name)
|
||||
# if completions request, tokens = prompt + n * max_tokens
|
||||
@@ -415,9 +372,7 @@ def num_tokens_consumed_from_request(
|
||||
num_tokens = prompt_tokens + completion_tokens * len(prompt)
|
||||
return num_tokens
|
||||
else:
|
||||
raise TypeError(
|
||||
'Expecting either string or list of strings for "prompt" field in completion request'
|
||||
)
|
||||
raise TypeError('Expecting either string or list of strings for "prompt" field in completion request')
|
||||
# if embeddings request, tokens = input tokens
|
||||
elif api_endpoint == "embeddings":
|
||||
input = request_json["input"]
|
||||
@@ -428,14 +383,10 @@ def num_tokens_consumed_from_request(
|
||||
num_tokens = sum([len(encoding.encode(i)) for i in input])
|
||||
return num_tokens
|
||||
else:
|
||||
raise TypeError(
|
||||
'Expecting either string or list of strings for "inputs" field in embedding request'
|
||||
)
|
||||
raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request')
|
||||
# more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'API endpoint "{api_endpoint}" not implemented in this script'
|
||||
)
|
||||
raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script')
|
||||
|
||||
|
||||
def task_id_generator_function():
|
||||
@@ -502,4 +453,4 @@ with open(filename, "w") as f:
|
||||
```
|
||||
|
||||
As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically).
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -4,69 +4,65 @@ import tiktoken
|
||||
import json
|
||||
|
||||
# Define the directory where the documentation resides
|
||||
docs_dir = 'text'
|
||||
docs_dir = "text"
|
||||
|
||||
encoding = tiktoken.encoding_for_model("gpt-4")
|
||||
PASSAGE_TOKEN_LEN = 800
|
||||
|
||||
|
||||
def extract_text_from_sphinx_txt(file_path):
|
||||
lines = []
|
||||
title = ""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
if not title:
|
||||
title = line.strip()
|
||||
continue
|
||||
if line and re.match(r'^.*\S.*$', line) and not re.match(r'^[-=*]+$', line):
|
||||
if line and re.match(r"^.*\S.*$", line) and not re.match(r"^[-=*]+$", line):
|
||||
lines.append(line)
|
||||
passages = []
|
||||
passages = []
|
||||
curr_passage = []
|
||||
curr_token_ct = 0
|
||||
for line in lines:
|
||||
try:
|
||||
line_token_ct = len(encoding.encode(line, allowed_special={'<|endoftext|>'}))
|
||||
line_token_ct = len(encoding.encode(line, allowed_special={"<|endoftext|>"}))
|
||||
except Exception as e:
|
||||
print("line", line)
|
||||
raise e
|
||||
if line_token_ct > PASSAGE_TOKEN_LEN:
|
||||
passages.append({
|
||||
'title': title,
|
||||
'text': line[:3200],
|
||||
'num_tokens': curr_token_ct,
|
||||
})
|
||||
passages.append(
|
||||
{
|
||||
"title": title,
|
||||
"text": line[:3200],
|
||||
"num_tokens": curr_token_ct,
|
||||
}
|
||||
)
|
||||
continue
|
||||
curr_token_ct += line_token_ct
|
||||
curr_passage.append(line)
|
||||
if curr_token_ct > PASSAGE_TOKEN_LEN:
|
||||
passages.append({
|
||||
'title': title,
|
||||
'text': ''.join(curr_passage),
|
||||
'num_tokens': curr_token_ct
|
||||
})
|
||||
passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
|
||||
curr_passage = []
|
||||
curr_token_ct = 0
|
||||
|
||||
if len(curr_passage) > 0:
|
||||
passages.append({
|
||||
'title': title,
|
||||
'text': ''.join(curr_passage),
|
||||
'num_tokens': curr_token_ct
|
||||
})
|
||||
passages.append({"title": title, "text": "".join(curr_passage), "num_tokens": curr_token_ct})
|
||||
return passages
|
||||
|
||||
|
||||
# Iterate over all files in the directory and its subdirectories
|
||||
passages = []
|
||||
total_files = 0
|
||||
for subdir, _, files in os.walk(docs_dir):
|
||||
for file in files:
|
||||
if file.endswith('.txt'):
|
||||
if file.endswith(".txt"):
|
||||
file_path = os.path.join(subdir, file)
|
||||
passages.append(extract_text_from_sphinx_txt(file_path))
|
||||
total_files += 1
|
||||
print("total .txt files:", total_files)
|
||||
|
||||
# Save to a new text file or process as needed
|
||||
with open('all_docs.jsonl', 'w', encoding='utf-8') as file:
|
||||
with open("all_docs.jsonl", "w", encoding="utf-8") as file:
|
||||
for p in passages:
|
||||
file.write(json.dumps(p))
|
||||
file.write('\n')
|
||||
file.write("\n")
|
||||
|
||||
@@ -1,30 +1,33 @@
|
||||
|
||||
from .prompts import gpt_functions
|
||||
from .prompts import gpt_system
|
||||
from .agent import AgentAsync
|
||||
from .utils import printd
|
||||
|
||||
|
||||
DEFAULT = 'memgpt_chat'
|
||||
DEFAULT = "memgpt_chat"
|
||||
|
||||
|
||||
def use_preset(preset_name, model, persona, human, interface, persistence_manager):
|
||||
"""Storing combinations of SYSTEM + FUNCTION prompts"""
|
||||
|
||||
if preset_name == 'memgpt_chat':
|
||||
|
||||
if preset_name == "memgpt_chat":
|
||||
functions = [
|
||||
'send_message', 'pause_heartbeats',
|
||||
'core_memory_append', 'core_memory_replace',
|
||||
'conversation_search', 'conversation_search_date',
|
||||
'archival_memory_insert', 'archival_memory_search',
|
||||
"send_message",
|
||||
"pause_heartbeats",
|
||||
"core_memory_append",
|
||||
"core_memory_replace",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
"archival_memory_search",
|
||||
]
|
||||
available_functions = [v for k,v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
|
||||
printd(f"Available functions:\n", [x['name'] for x in available_functions])
|
||||
available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
|
||||
printd(f"Available functions:\n", [x["name"] for x in available_functions])
|
||||
assert len(functions) == len(available_functions)
|
||||
|
||||
if 'gpt-3.5' in model:
|
||||
if "gpt-3.5" in model:
|
||||
# use a different system message for gpt-3.5
|
||||
preset_name = 'memgpt_gpt35_extralong'
|
||||
preset_name = "memgpt_gpt35_extralong"
|
||||
|
||||
return AgentAsync(
|
||||
model=model,
|
||||
@@ -35,8 +38,8 @@ def use_preset(preset_name, model, persona, human, interface, persistence_manage
|
||||
persona_notes=persona,
|
||||
human_notes=human,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if 'gpt-4' in model else False,
|
||||
first_message_verify_mono=True if "gpt-4" in model else False,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(preset_name)
|
||||
raise ValueError(preset_name)
|
||||
|
||||
@@ -2,9 +2,7 @@ from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
|
||||
|
||||
# FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1]
|
||||
FUNCTIONS_CHAINING = {
|
||||
|
||||
'send_message':
|
||||
{
|
||||
"send_message": {
|
||||
"name": "send_message",
|
||||
"description": "Sends a message to the human user",
|
||||
"parameters": {
|
||||
@@ -17,11 +15,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'pause_heartbeats':
|
||||
{
|
||||
"pause_heartbeats": {
|
||||
"name": "pause_heartbeats",
|
||||
"description": "Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.",
|
||||
"parameters": {
|
||||
@@ -34,11 +30,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["minutes"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'message_chatgpt':
|
||||
{
|
||||
"message_chatgpt": {
|
||||
"name": "message_chatgpt",
|
||||
"description": "Send a message to a more basic AI, ChatGPT. A useful resource for asking questions. ChatGPT does not retain memory of previous interactions.",
|
||||
"parameters": {
|
||||
@@ -55,11 +49,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["message", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'core_memory_append':
|
||||
{
|
||||
"core_memory_append": {
|
||||
"name": "core_memory_append",
|
||||
"description": "Append to the contents of core memory.",
|
||||
"parameters": {
|
||||
@@ -79,11 +71,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "content", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'core_memory_replace':
|
||||
{
|
||||
"core_memory_replace": {
|
||||
"name": "core_memory_replace",
|
||||
"description": "Replace to the contents of core memory. To delete memories, use an empty string for new_content.",
|
||||
"parameters": {
|
||||
@@ -107,11 +97,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "old_content", "new_content", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'recall_memory_search':
|
||||
{
|
||||
"recall_memory_search": {
|
||||
"name": "recall_memory_search",
|
||||
"description": "Search prior conversation history using a string.",
|
||||
"parameters": {
|
||||
@@ -131,11 +119,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "page", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'conversation_search':
|
||||
{
|
||||
"conversation_search": {
|
||||
"name": "conversation_search",
|
||||
"description": "Search prior conversation history using case-insensitive string matching.",
|
||||
"parameters": {
|
||||
@@ -155,11 +141,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "page", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'recall_memory_search_date':
|
||||
{
|
||||
"recall_memory_search_date": {
|
||||
"name": "recall_memory_search_date",
|
||||
"description": "Search prior conversation history using a date range.",
|
||||
"parameters": {
|
||||
@@ -183,11 +167,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "page", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'conversation_search_date':
|
||||
{
|
||||
"conversation_search_date": {
|
||||
"name": "conversation_search_date",
|
||||
"description": "Search prior conversation history using a date range.",
|
||||
"parameters": {
|
||||
@@ -211,11 +193,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "page", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'archival_memory_insert':
|
||||
{
|
||||
"archival_memory_insert": {
|
||||
"name": "archival_memory_insert",
|
||||
"description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.",
|
||||
"parameters": {
|
||||
@@ -231,11 +211,9 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "content", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
'archival_memory_search':
|
||||
{
|
||||
"archival_memory_search": {
|
||||
"name": "archival_memory_search",
|
||||
"description": "Search archival memory using semantic (embedding-based) search.",
|
||||
"parameters": {
|
||||
@@ -255,7 +233,6 @@ FUNCTIONS_CHAINING = {
|
||||
},
|
||||
},
|
||||
"required": ["name", "query", "page", "request_heartbeat"],
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
WORD_LIMIT = 100
|
||||
SYSTEM = \
|
||||
f"""
|
||||
SYSTEM = f"""
|
||||
Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
|
||||
The conversation you are given is a from a fixed context window and may not be complete.
|
||||
Messages sent by the AI are marked with the 'assistant' role.
|
||||
@@ -12,4 +11,4 @@ The 'user' role is also used for important system events, such as login events a
|
||||
Summarize what happened in the conversation from the perspective of the AI (use the first person).
|
||||
Keep your summary less than {WORD_LIMIT} words, do NOT exceed this word limit.
|
||||
Only output the summary, do NOT include anything else in your output.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -2,11 +2,11 @@ import os
|
||||
|
||||
|
||||
def get_system_text(key):
|
||||
filename = f'{key}.txt'
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'system', filename)
|
||||
filename = f"{key}.txt"
|
||||
file_path = os.path.join(os.path.dirname(__file__), "system", filename)
|
||||
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
return file.read().strip()
|
||||
else:
|
||||
raise FileNotFoundError(f"No file found for key {key}, path={file_path}")
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import json
|
||||
|
||||
from .utils import get_local_time
|
||||
from .constants import INITIAL_BOOT_MESSAGE, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG, MESSAGE_SUMMARY_WARNING_STR
|
||||
from .constants import (
|
||||
INITIAL_BOOT_MESSAGE,
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG,
|
||||
MESSAGE_SUMMARY_WARNING_STR,
|
||||
)
|
||||
|
||||
|
||||
def get_initial_boot_messages(version='startup'):
|
||||
|
||||
if version == 'startup':
|
||||
def get_initial_boot_messages(version="startup"):
|
||||
if version == "startup":
|
||||
initial_boot_message = INITIAL_BOOT_MESSAGE
|
||||
messages = [
|
||||
{"role": "assistant", "content": initial_boot_message},
|
||||
]
|
||||
|
||||
elif version == 'startup_with_send_message':
|
||||
elif version == "startup_with_send_message":
|
||||
messages = [
|
||||
# first message includes both inner monologue and function call to send_message
|
||||
{
|
||||
@@ -20,34 +24,23 @@ def get_initial_boot_messages(version='startup'):
|
||||
"content": INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
|
||||
"function_call": {
|
||||
"name": "send_message",
|
||||
"arguments": "{\n \"message\": \"" + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + "\"\n}"
|
||||
}
|
||||
"arguments": '{\n "message": "' + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + '"\n}',
|
||||
},
|
||||
},
|
||||
# obligatory function return message
|
||||
{
|
||||
"role": "function",
|
||||
"name": "send_message",
|
||||
"content": package_function_response(True, None)
|
||||
}
|
||||
{"role": "function", "name": "send_message", "content": package_function_response(True, None)},
|
||||
]
|
||||
|
||||
elif version == 'startup_with_send_message_gpt35':
|
||||
elif version == "startup_with_send_message_gpt35":
|
||||
messages = [
|
||||
# first message includes both inner monologue and function call to send_message
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "*inner thoughts* Still waiting on the user. Sending a message with function.",
|
||||
"function_call": {
|
||||
"name": "send_message",
|
||||
"arguments": "{\n \"message\": \"" + f"Hi, is anyone there?" + "\"\n}"
|
||||
}
|
||||
"function_call": {"name": "send_message", "arguments": '{\n "message": "' + f"Hi, is anyone there?" + '"\n}'},
|
||||
},
|
||||
# obligatory function return message
|
||||
{
|
||||
"role": "function",
|
||||
"name": "send_message",
|
||||
"content": package_function_response(True, None)
|
||||
}
|
||||
{"role": "function", "name": "send_message", "content": package_function_response(True, None)},
|
||||
]
|
||||
|
||||
else:
|
||||
@@ -56,12 +49,11 @@ def get_initial_boot_messages(version='startup'):
|
||||
return messages
|
||||
|
||||
|
||||
def get_heartbeat(reason='Automated timer', include_location=False, location_name='San Francisco, CA, USA'):
|
||||
|
||||
def get_heartbeat(reason="Automated timer", include_location=False, location_name="San Francisco, CA, USA"):
|
||||
# Package the message with time and location
|
||||
formatted_time = get_local_time()
|
||||
packaged_message = {
|
||||
"type": 'heartbeat',
|
||||
"type": "heartbeat",
|
||||
"reason": reason,
|
||||
"time": formatted_time,
|
||||
}
|
||||
@@ -72,12 +64,11 @@ def get_heartbeat(reason='Automated timer', include_location=False, location_nam
|
||||
return json.dumps(packaged_message)
|
||||
|
||||
|
||||
def get_login_event(last_login='Never (first login)', include_location=False, location_name='San Francisco, CA, USA'):
|
||||
|
||||
def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"):
|
||||
# Package the message with time and location
|
||||
formatted_time = get_local_time()
|
||||
packaged_message = {
|
||||
"type": 'login',
|
||||
"type": "login",
|
||||
"last_login": last_login,
|
||||
"time": formatted_time,
|
||||
}
|
||||
@@ -88,12 +79,11 @@ def get_login_event(last_login='Never (first login)', include_location=False, lo
|
||||
return json.dumps(packaged_message)
|
||||
|
||||
|
||||
def package_user_message(user_message, time=None, include_location=False, location_name='San Francisco, CA, USA'):
|
||||
|
||||
def package_user_message(user_message, time=None, include_location=False, location_name="San Francisco, CA, USA"):
|
||||
# Package the message with time and location
|
||||
formatted_time = time if time else get_local_time()
|
||||
packaged_message = {
|
||||
"type": 'user_message',
|
||||
"type": "user_message",
|
||||
"message": user_message,
|
||||
"time": formatted_time,
|
||||
}
|
||||
@@ -103,11 +93,11 @@ def package_user_message(user_message, time=None, include_location=False, locati
|
||||
|
||||
return json.dumps(packaged_message)
|
||||
|
||||
def package_function_response(was_success, response_string, timestamp=None):
|
||||
|
||||
def package_function_response(was_success, response_string, timestamp=None):
|
||||
formatted_time = get_local_time() if timestamp is None else timestamp
|
||||
packaged_message = {
|
||||
"status": 'OK' if was_success else 'Failed',
|
||||
"status": "OK" if was_success else "Failed",
|
||||
"message": response_string,
|
||||
"time": formatted_time,
|
||||
}
|
||||
@@ -116,14 +106,14 @@ def package_function_response(was_success, response_string, timestamp=None):
|
||||
|
||||
|
||||
def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
|
||||
|
||||
context_message = \
|
||||
f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n" \
|
||||
context_message = (
|
||||
f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"
|
||||
+ f"The following is a summary of the previous {summary_length} messages:\n {summary}"
|
||||
)
|
||||
|
||||
formatted_time = get_local_time() if timestamp is None else timestamp
|
||||
packaged_message = {
|
||||
"type": 'system_alert',
|
||||
"type": "system_alert",
|
||||
"message": context_message,
|
||||
"time": formatted_time,
|
||||
}
|
||||
@@ -136,10 +126,13 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
|
||||
|
||||
# Package the message with time and location
|
||||
formatted_time = get_local_time() if timestamp is None else timestamp
|
||||
context_message = message if message else \
|
||||
f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions."
|
||||
context_message = (
|
||||
message
|
||||
if message
|
||||
else f"Note: {hidden_message_count} prior messages with the user have been hidden from view due to conversation memory constraints. Older messages are stored in Recall Memory and can be viewed using functions."
|
||||
)
|
||||
packaged_message = {
|
||||
"type": 'system_alert',
|
||||
"type": "system_alert",
|
||||
"message": context_message,
|
||||
"time": formatted_time,
|
||||
}
|
||||
@@ -148,12 +141,11 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
|
||||
|
||||
|
||||
def get_token_limit_warning():
|
||||
|
||||
formatted_time = get_local_time()
|
||||
packaged_message = {
|
||||
"type": 'system_alert',
|
||||
"type": "system_alert",
|
||||
"message": MESSAGE_SUMMARY_WARNING_STR,
|
||||
"time": formatted_time,
|
||||
}
|
||||
|
||||
return json.dumps(packaged_message)
|
||||
return json.dumps(packaged_message)
|
||||
|
||||
@@ -168,9 +168,7 @@ def chunk_file(file, tkns_per_chunk=300, model="gpt-4"):
|
||||
line_token_ct = len(encoding.encode(line))
|
||||
except Exception as e:
|
||||
line_token_ct = len(line.split(" ")) / 0.75
|
||||
print(
|
||||
f"Could not encode line {i}, estimating it to be {line_token_ct} tokens"
|
||||
)
|
||||
print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
|
||||
print(e)
|
||||
if line_token_ct > tkns_per_chunk:
|
||||
if len(curr_chunk) > 0:
|
||||
@@ -194,9 +192,7 @@ def chunk_files(files, tkns_per_chunk=300, model="gpt-4"):
|
||||
archival_database = []
|
||||
for file in files:
|
||||
timestamp = os.path.getmtime(file)
|
||||
formatted_time = datetime.fromtimestamp(timestamp).strftime(
|
||||
"%Y-%m-%d %I:%M:%S %p %Z%z"
|
||||
)
|
||||
formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
file_stem = file.split("/")[-1]
|
||||
chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
|
||||
for i, chunk in enumerate(chunks):
|
||||
@@ -243,9 +239,7 @@ async def process_concurrently(archival_database, model, concurrency=10):
|
||||
|
||||
# Create a list of tasks for chunks
|
||||
embedding_data = [0 for _ in archival_database]
|
||||
tasks = [
|
||||
bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)
|
||||
]
|
||||
tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]
|
||||
|
||||
for future in tqdm(
|
||||
asyncio.as_completed(tasks),
|
||||
@@ -267,15 +261,12 @@ async def prepare_archival_index_from_files_compute_embeddings(
|
||||
files = sorted(glob.glob(glob_pattern))
|
||||
save_dir = os.path.join(
|
||||
MEMGPT_DIR,
|
||||
"archival_index_from_files_"
|
||||
+ get_local_time().replace(" ", "_").replace(":", "_"),
|
||||
"archival_index_from_files_" + get_local_time().replace(" ", "_").replace(":", "_"),
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
total_tokens = total_bytes(glob_pattern) / 3
|
||||
price_estimate = total_tokens / 1000 * 0.0001
|
||||
confirm = input(
|
||||
f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] "
|
||||
)
|
||||
confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
|
||||
if confirm != "y":
|
||||
raise Exception("embeddings were not computed")
|
||||
|
||||
@@ -291,9 +282,7 @@ async def prepare_archival_index_from_files_compute_embeddings(
|
||||
archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
|
||||
chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
|
||||
with open(archival_storage_file, "w") as f:
|
||||
print(
|
||||
f"Saving archival storage with preloaded files to {archival_storage_file}"
|
||||
)
|
||||
print(f"Saving archival storage with preloaded files to {archival_storage_file}")
|
||||
for c in chunks_by_file:
|
||||
json.dump(c, f)
|
||||
f.write("\n")
|
||||
|
||||
147
poetry.lock
generated
147
poetry.lock
generated
@@ -1,9 +1,10 @@
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
version = "3.8.6"
|
||||
description = "Async http client/server framework (asyncio)"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
@@ -112,6 +113,7 @@ speedups = ["Brotli", "aiodns", "cchardet"]
|
||||
name = "aiosignal"
|
||||
version = "1.3.1"
|
||||
description = "aiosignal: a list of registered asynchronous callbacks"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -126,6 +128,7 @@ frozenlist = ">=1.1.0"
|
||||
name = "async-timeout"
|
||||
version = "4.0.3"
|
||||
description = "Timeout context manager for asyncio programs"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -137,6 +140,7 @@ files = [
|
||||
name = "attrs"
|
||||
version = "23.1.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -151,10 +155,54 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
|
||||
tests = ["attrs[tests-no-zope]", "zope-interface"]
|
||||
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "23.10.1"
|
||||
description = "The uncompromising code formatter."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"},
|
||||
{file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"},
|
||||
{file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"},
|
||||
{file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"},
|
||||
{file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"},
|
||||
{file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"},
|
||||
{file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"},
|
||||
{file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"},
|
||||
{file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"},
|
||||
{file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"},
|
||||
{file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"},
|
||||
{file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"},
|
||||
{file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"},
|
||||
{file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"},
|
||||
{file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"},
|
||||
{file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"},
|
||||
{file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"},
|
||||
{file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=8.0.0"
|
||||
mypy-extensions = ">=0.4.3"
|
||||
packaging = ">=22.0"
|
||||
pathspec = ">=0.9.0"
|
||||
platformdirs = ">=2"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
colorama = ["colorama (>=0.4.3)"]
|
||||
d = ["aiohttp (>=3.7.4)"]
|
||||
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
||||
uvloop = ["uvloop (>=0.15.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2023.7.22"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
@@ -166,6 +214,7 @@ files = [
|
||||
name = "charset-normalizer"
|
||||
version = "3.3.1"
|
||||
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
@@ -265,6 +314,7 @@ files = [
|
||||
name = "click"
|
||||
version = "8.1.7"
|
||||
description = "Composable command line interface toolkit"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -279,6 +329,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
files = [
|
||||
@@ -290,6 +341,7 @@ files = [
|
||||
name = "demjson3"
|
||||
version = "3.0.6"
|
||||
description = "encoder, decoder, and lint/validator for JSON (JavaScript Object Notation) compliant with RFC 7159"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
@@ -300,6 +352,7 @@ files = [
|
||||
name = "faiss-cpu"
|
||||
version = "1.7.4"
|
||||
description = "A library for efficient similarity search and clustering of dense vectors."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
@@ -334,6 +387,7 @@ files = [
|
||||
name = "frozenlist"
|
||||
version = "1.4.0"
|
||||
description = "A list-like structure which implements collections.abc.MutableSequence"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -404,6 +458,7 @@ files = [
|
||||
name = "idna"
|
||||
version = "3.4"
|
||||
description = "Internationalized Domain Names in Applications (IDNA)"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
@@ -415,6 +470,7 @@ files = [
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
description = "Python port of markdown-it. Markdown parsing, done right!"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -439,6 +495,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
description = "Markdown URL utilities"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -450,6 +507,7 @@ files = [
|
||||
name = "multidict"
|
||||
version = "6.0.4"
|
||||
description = "multidict implementation"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -529,10 +587,23 @@ files = [
|
||||
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mypy-extensions"
|
||||
version = "1.0.0"
|
||||
description = "Type system extensions for programs checked with the mypy type checker."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
|
||||
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.1"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "<3.13,>=3.9"
|
||||
files = [
|
||||
@@ -574,6 +645,7 @@ files = [
|
||||
name = "openai"
|
||||
version = "0.28.1"
|
||||
description = "Python client library for the OpenAI API"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
@@ -588,14 +660,55 @@ tqdm = "*"
|
||||
|
||||
[package.extras]
|
||||
datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"]
|
||||
dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"]
|
||||
embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"]
|
||||
wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "23.2"
|
||||
description = "Core utilities for Python packages"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
|
||||
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pathspec"
|
||||
version = "0.11.2"
|
||||
description = "Utility library for gitignore style pattern matching of file paths."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"},
|
||||
{file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "platformdirs"
|
||||
version = "3.11.0"
|
||||
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"},
|
||||
{file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"]
|
||||
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "prompt-toolkit"
|
||||
version = "3.0.36"
|
||||
description = "Library for building powerful interactive command lines in Python"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6.2"
|
||||
files = [
|
||||
@@ -610,6 +723,7 @@ wcwidth = "*"
|
||||
name = "pygments"
|
||||
version = "2.16.1"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -624,6 +738,7 @@ plugins = ["importlib-metadata"]
|
||||
name = "pymupdf"
|
||||
version = "1.23.5"
|
||||
description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -667,6 +782,7 @@ PyMuPDFb = "1.23.5"
|
||||
name = "pymupdfb"
|
||||
version = "1.23.5"
|
||||
description = "MuPDF shared libraries for PyMuPDF."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -682,6 +798,7 @@ files = [
|
||||
name = "pytz"
|
||||
version = "2023.3.post1"
|
||||
description = "World timezone definitions, modern and historical"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
@@ -693,6 +810,7 @@ files = [
|
||||
name = "questionary"
|
||||
version = "2.0.1"
|
||||
description = "Python library to build pretty command line user prompts ⭐️"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -707,6 +825,7 @@ prompt_toolkit = ">=2.0,<=3.0.36"
|
||||
name = "regex"
|
||||
version = "2023.10.3"
|
||||
description = "Alternative regular expression module, to replace re."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -804,6 +923,7 @@ files = [
|
||||
name = "requests"
|
||||
version = "2.31.0"
|
||||
description = "Python HTTP for Humans."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -825,6 +945,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
name = "rich"
|
||||
version = "13.6.0"
|
||||
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
@@ -843,6 +964,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||
name = "shellingham"
|
||||
version = "1.5.3"
|
||||
description = "Tool to Detect Surrounding Shell"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -854,6 +976,7 @@ files = [
|
||||
name = "tiktoken"
|
||||
version = "0.5.1"
|
||||
description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -895,10 +1018,23 @@ requests = ">=2.26.0"
|
||||
[package.extras]
|
||||
blobfile = ["blobfile (>=2)"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
description = "A lil' TOML parser"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.1"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -919,6 +1055,7 @@ telegram = ["requests"]
|
||||
name = "typer"
|
||||
version = "0.9.0"
|
||||
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
@@ -943,6 +1080,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.
|
||||
name = "typing-extensions"
|
||||
version = "4.8.0"
|
||||
description = "Backported and Experimental Type Hints for Python 3.8+"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
@@ -954,6 +1092,7 @@ files = [
|
||||
name = "urllib3"
|
||||
version = "2.0.7"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -971,6 +1110,7 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||
name = "wcwidth"
|
||||
version = "0.2.8"
|
||||
description = "Measures the displayed width of unicode strings in a terminal"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
@@ -982,6 +1122,7 @@ files = [
|
||||
name = "yarl"
|
||||
version = "1.9.2"
|
||||
description = "Yet another URL library"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
@@ -1068,4 +1209,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.9"
|
||||
content-hash = "d5c3f3a0c8149f3ba08cca638c1a5154f2d7d42ddeb4a7ed2d87bc9545841f23"
|
||||
content-hash = "9c19a9cd0487a85fa947ec3f53e765b47a03b2a1c6ae1a46de95b25a893690b2"
|
||||
|
||||
@@ -30,6 +30,7 @@ tiktoken = "^0.5.1"
|
||||
pymupdf = "^1.23.5"
|
||||
tqdm = "^4.66.1"
|
||||
openai = "^0.28.1"
|
||||
black = "^23.10.1"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
Reference in New Issue
Block a user