Remove AsyncAgent and async from cli (#400)

* Remove AsyncAgent and async from cli
  - Refactor agent.py, memory.py
  - Refactor interface.py
  - Refactor main.py
  - Refactor openai_tools.py
  - Refactor cli/cli.py
  - stray asyncs
  - save
  - make legacy embeddings not use async
  - Refactor presets
  - Remove deleted function from import *
  - remove stray prints
* typo
* another stray print
* patch test

---------

Co-authored-by: cpacker <packercharles@gmail.com>
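The change is mechanical throughout the diff below: `async def` becomes `def`, call sites drop `await`, questionary prompts switch from `.ask_async()` to `.ask()`, and the asyncio event-loop plumbing in the CLI goes away. A toy before/after sketch of the pattern (illustrative names only, not code lifted from this diff):

    import asyncio

    # Before: each agent step is a coroutine, so the caller needs an event loop.
    async def step_async(message: str) -> str:
        return f"echo: {message}"

    def run_before() -> str:
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(step_async("hi"))

    # After: the same logic as a plain function; callers just call it.
    def step(message: str) -> str:
        return f"echo: {message}"

    def run_after() -> str:
        return step("hi")

    if __name__ == "__main__":
        print(run_before())  # echo: hi
        print(run_after())   # echo: hi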
memgpt/agent.py (406 changed lines)
@@ -1,4 +1,3 @@
-import asyncio
 import inspect
 import datetime
 import glob
@@ -14,8 +13,8 @@ import openai
 from memgpt.persistence_manager import LocalStateManager
 from memgpt.config import AgentConfig
 from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
-from .memory import CoreMemory as Memory, summarize_messages, a_summarize_messages
-from .openai_tools import acompletions_with_backoff as acreate, completions_with_backoff as create
+from .memory import CoreMemory as Memory, summarize_messages
+from .openai_tools import completions_with_backoff as create
 from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
 from .constants import (
     MEMGPT_DIR,
@@ -133,45 +132,6 @@ def get_ai_reply(
         raise e


-async def get_ai_reply_async(
-    model,
-    message_sequence,
-    functions,
-    function_call="auto",
-):
-    """Base call to GPT API w/ functions"""
-
-    try:
-        response = await acreate(
-            model=model,
-            messages=message_sequence,
-            functions=functions,
-            function_call=function_call,
-        )
-
-        # special case for 'length'
-        if response.choices[0].finish_reason == "length":
-            raise Exception("Finish reason was length (maximum context length)")
-
-        # catches for soft errors
-        if response.choices[0].finish_reason not in ["stop", "function_call"]:
-            raise Exception(f"API call finish with bad finish reason: {response}")
-
-        # unpack with response.choices[0].message.content
-        return response
-
-    except Exception as e:
-        raise e
-
-
-# Assuming function_to_call is either sync or async
-async def call_function(function_to_call, **function_args):
-    if inspect.iscoroutinefunction(function_to_call):
-        return await function_to_call(**function_args)
-    else:
-        return function_to_call(**function_args)
-
-
 class Agent(object):
     def __init__(
         self,
@@ -207,7 +167,7 @@ class Agent(object):
         self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1)  # (-system)
         # self.messages_total_init = self.messages_total
         self.messages_total_init = len(self._messages) - 1
-        printd(f"AgentAsync initialized, self.messages_total={self.messages_total}")
+        printd(f"Agent initialized, self.messages_total={self.messages_total}")

         # Interface must implement:
         # - internal_monologue
@@ -922,363 +882,3 @@ class Agent(object):
         # Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
         elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
         return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60
-
-
-class AgentAsync(Agent):
-    """Core logic for an async MemGPT agent"""
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.init_avail_functions()
-
-    async def handle_ai_response(self, response_message):
-        """Handles parsing and function execution"""
-        messages = []  # append these to the history when done
-
-        # Step 2: check if LLM wanted to call a function
-        if response_message.get("function_call"):
-            # The content if then internal monologue, not chat
-            await self.interface.internal_monologue(response_message.content)
-            messages.append(response_message)  # extend conversation with assistant's reply
-
-            # Step 3: call the function
-            # Note: the JSON response may not always be valid; be sure to handle errors
-
-            # Failure case 1: function name is wrong
-            function_name = response_message["function_call"]["name"]
-            try:
-                function_to_call = self.available_functions[function_name]
-            except KeyError as e:
-                error_msg = f"No function named {function_name}"
-                function_response = package_function_response(False, error_msg)
-                messages.append(
-                    {
-                        "role": "function",
-                        "name": function_name,
-                        "content": function_response,
-                    }
-                )  # extend conversation with function response
-                await self.interface.function_message(f"Error: {error_msg}")
-                return messages, None, True  # force a heartbeat to allow agent to handle error
-
-            # Failure case 2: function name is OK, but function args are bad JSON
-            try:
-                raw_function_args = response_message["function_call"]["arguments"]
-                function_args = parse_json(raw_function_args)
-            except Exception as e:
-                error_msg = f"Error parsing JSON for function '{function_name}' arguments: {raw_function_args}"
-                function_response = package_function_response(False, error_msg)
-                messages.append(
-                    {
-                        "role": "function",
-                        "name": function_name,
-                        "content": function_response,
-                    }
-                )  # extend conversation with function response
-                await self.interface.function_message(f"Error: {error_msg}")
-                return messages, None, True  # force a heartbeat to allow agent to handle error
-
-            # (Still parsing function args)
-            # Handle requests for immediate heartbeat
-            heartbeat_request = function_args.pop("request_heartbeat", None)
-            if not (isinstance(heartbeat_request, bool) or heartbeat_request is None):
-                printd(
-                    f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
-                )
-                heartbeat_request = None
-
-            # Failure case 3: function failed during execution
-            await self.interface.function_message(f"Running {function_name}({function_args})")
-            try:
-                function_response_string = await call_function(function_to_call, **function_args)
-                function_response = package_function_response(True, function_response_string)
-                function_failed = False
-            except Exception as e:
-                error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
-                error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
-                printd(error_msg_user)
-                function_response = package_function_response(False, error_msg)
-                messages.append(
-                    {
-                        "role": "function",
-                        "name": function_name,
-                        "content": function_response,
-                    }
-                )  # extend conversation with function response
-                await self.interface.function_message(f"Error: {error_msg}")
-                return messages, None, True  # force a heartbeat to allow agent to handle error
-
-            # If no failures happened along the way: ...
-            # Step 4: send the info on the function call and function response to GPT
-            if function_response_string:
-                await self.interface.function_message(f"Success: {function_response_string}")
-            else:
-                await self.interface.function_message(f"Success")
-            messages.append(
-                {
-                    "role": "function",
-                    "name": function_name,
-                    "content": function_response,
-                }
-            )  # extend conversation with function response
-
-        else:
-            # Standard non-function reply
-            await self.interface.internal_monologue(response_message.content)
-            messages.append(response_message)  # extend conversation with assistant's reply
-            heartbeat_request = None
-            function_failed = None
-
-        return messages, heartbeat_request, function_failed
-
-    async def step(self, user_message, first_message=False, first_message_retry_limit=FIRST_MESSAGE_ATTEMPTS, skip_verify=False):
-        """Top-level event message handler for the MemGPT agent"""
-
-        try:
-            # Step 0: add user message
-            if user_message is not None:
-                await self.interface.user_message(user_message)
-                packed_user_message = {"role": "user", "content": user_message}
-                input_message_sequence = self.messages + [packed_user_message]
-            else:
-                input_message_sequence = self.messages
-
-            if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
-                printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
-                from pprint import pprint
-
-                pprint(input_message_sequence[-1])
-
-            # Step 1: send the conversation and available functions to GPT
-            if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
-                printd(f"This is the first message. Running extra verifier on AI response.")
-                counter = 0
-                while True:
-                    response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
-                    if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
-                        break
-
-                    counter += 1
-                    if counter > first_message_retry_limit:
-                        raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
-
-            else:
-                response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
-
-            # Step 2: check if LLM wanted to call a function
-            # (if yes) Step 3: call the function
-            # (if yes) Step 4: send the info on the function call and function response to LLM
-            response_message = response.choices[0].message
-            response_message_copy = response_message.copy()
-            all_response_messages, heartbeat_request, function_failed = await self.handle_ai_response(response_message)
-
-            # Add the extra metadata to the assistant response
-            # (e.g. enough metadata to enable recreating the API call)
-            assert "api_response" not in all_response_messages[0], f"api_response already in {all_response_messages[0]}"
-            all_response_messages[0]["api_response"] = response_message_copy
-            assert "api_args" not in all_response_messages[0], f"api_args already in {all_response_messages[0]}"
-            all_response_messages[0]["api_args"] = {
-                "model": self.model,
-                "messages": input_message_sequence,
-                "functions": self.functions,
-            }
-
-            # Step 4: extend the message history
-            if user_message is not None:
-                all_new_messages = [packed_user_message] + all_response_messages
-            else:
-                all_new_messages = all_response_messages
-
-            # Check the memory pressure and potentially issue a memory pressure warning
-            current_total_tokens = response["usage"]["total_tokens"]
-            active_memory_warning = False
-            if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
-                printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
-                # Only deliver the alert if we haven't already (this period)
-                if not self.agent_alerted_about_memory_pressure:
-                    active_memory_warning = True
-                    self.agent_alerted_about_memory_pressure = True  # it's up to the outer loop to handle this
-            else:
-                printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_TOKENS}")
-
-            self.append_to_messages(all_new_messages)
-            return all_new_messages, heartbeat_request, function_failed, active_memory_warning
-
-        except Exception as e:
-            printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
-            print(f"step() failed\nuser_message = {user_message}\nerror = {e}")
-
-            # If we got a context alert, try trimming the messages length, then try again
-            if "maximum context length" in str(e):
-                # A separate API call to run a summarizer
-                await self.summarize_messages_inplace()
-
-                # Try step again
-                return await self.step(user_message, first_message=first_message)
-            else:
-                printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'")
-                print(e)
-                raise e
-
-    async def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
-        assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
-
-        # Start at index 1 (past the system message),
-        # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
-        # Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
-        token_counts = [count_tokens(str(msg)) for msg in self.messages]
-        message_buffer_token_count = sum(token_counts[1:])  # no system message
-        token_counts = token_counts[1:]
-        desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
-        candidate_messages_to_summarize = self.messages[1:]
-        if preserve_last_N_messages:
-            candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
-            token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
-        printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
-        printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
-        printd(f"token_counts={token_counts}")
-        printd(f"message_buffer_token_count={message_buffer_token_count}")
-        printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
-        printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}")
-
-        # If at this point there's nothing to summarize, throw an error
-        if len(candidate_messages_to_summarize) == 0:
-            raise LLMError(
-                f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
-            )
-
-        # Walk down the message buffer (front-to-back) until we hit the target token count
-        tokens_so_far = 0
-        cutoff = 0
-        for i, msg in enumerate(candidate_messages_to_summarize):
-            cutoff = i
-            tokens_so_far += token_counts[i]
-            if tokens_so_far > desired_token_count_to_summarize:
-                break
-        # Account for system message
-        cutoff += 1
-
-        # Try to make an assistant message come after the cutoff
-        try:
-            printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
-            if self.messages[cutoff]["role"] == "user":
-                new_cutoff = cutoff + 1
-                if self.messages[new_cutoff]["role"] == "user":
-                    printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
-                cutoff = new_cutoff
-        except IndexError:
-            pass
-
-        message_sequence_to_summarize = self.messages[1:cutoff]  # do NOT get rid of the system message
-        if len(message_sequence_to_summarize) == 0:
-            printd(f"message_sequence_to_summarize is len 0, skipping summarize")
-            raise LLMError(
-                f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, cutoff={cutoff}]"
-            )
-
-        printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
-        summary = await a_summarize_messages(self.model, message_sequence_to_summarize)
-        printd(f"Got summary: {summary}")
-
-        # Metadata that's useful for the agent to see
-        all_time_message_count = self.messages_total
-        remaining_message_count = len(self.messages[cutoff:])
-        hidden_message_count = all_time_message_count - remaining_message_count
-        summary_message_count = len(message_sequence_to_summarize)
-        summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
-        printd(f"Packaged into message: {summary_message}")
-
-        prior_len = len(self.messages)
-        self.trim_messages(cutoff)
-        packed_summary_message = {"role": "user", "content": summary_message}
-        self.prepend_to_messages([packed_summary_message])
-
-        # reset alert
-        self.agent_alerted_about_memory_pressure = False
-
-        printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
-
-    async def free_step(self, user_message, limit=None):
-        """Allow agent to manage its own control flow (past a single LLM call).
-        Not currently used, instead this is handled in the CLI main.py logic
-        """
-
-        new_messages, heartbeat_request, function_failed = self.step(user_message)
-        step_count = 1
-
-        while limit is None or step_count < limit:
-            if function_failed:
-                user_message = get_heartbeat("Function call failed")
-                new_messages, heartbeat_request, function_failed = await self.step(user_message)
-                step_count += 1
-            elif heartbeat_request:
-                user_message = get_heartbeat("AI requested")
-                new_messages, heartbeat_request, function_failed = await self.step(user_message)
-                step_count += 1
-            else:
-                break
-
-        return new_messages, heartbeat_request, function_failed
-
-    ### Functions / tools the agent can use
-    # All functions should return a response string (or None)
-    # If the function fails, throw an exception
-
-    async def send_ai_message(self, message):
-        """AI wanted to send a message"""
-        await self.interface.assistant_message(message)
-        return None
-
-    async def recall_memory_search(self, query, count=5, page=0):
-        results, total = await self.persistence_manager.recall_memory.a_text_search(query, count=count, start=page * count)
-        num_pages = math.ceil(total / count) - 1  # 0 index
-        if len(results) == 0:
-            results_str = f"No results found."
-        else:
-            results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
-            results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
-            results_str = f"{results_pref} {json.dumps(results_formatted)}"
-        return results_str
-
-    async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
-        results, total = await self.persistence_manager.recall_memory.a_date_search(start_date, end_date, count=count, start=page * count)
-        num_pages = math.ceil(total / count) - 1  # 0 index
-        if len(results) == 0:
-            results_str = f"No results found."
-        else:
-            results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
-            results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
-            results_str = f"{results_pref} {json.dumps(results_formatted)}"
-        return results_str
-
-    async def archival_memory_insert(self, content):
-        await self.persistence_manager.archival_memory.a_insert(content)
-        return None
-
-    async def archival_memory_search(self, query, count=5, page=0):
-        results, total = await self.persistence_manager.archival_memory.a_search(query, count=count, start=page * count)
-        num_pages = math.ceil(total / count) - 1  # 0 index
-        if len(results) == 0:
-            results_str = f"No results found."
-        else:
-            results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
-            results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
-            results_str = f"{results_pref} {json.dumps(results_formatted)}"
-        return results_str
-
-    async def message_chatgpt(self, message):
-        """Base call to GPT API w/ functions"""
-
-        message_sequence = [
-            {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
-            {"role": "user", "content": str(message)},
-        ]
-        response = await acreate(
-            model=MESSAGE_CHATGPT_FUNCTION_MODEL,
-            messages=message_sequence,
-            # functions=functions,
-            # function_call=function_call,
-        )
-
-        reply = response.choices[0].message.content
-        return reply
(deleted file)

@@ -1,7 +0,0 @@
-from abc import ABC, abstractmethod
-
-
-class AgentAsyncBase(ABC):
-    @abstractmethod
-    async def step(self, user_message):
-        pass
memgpt/autogen/memgpt_agent.py

@@ -46,7 +46,7 @@ def create_memgpt_autogen_agent_from_config(

     autogen_memgpt_agent = create_autogen_memgpt_agent(
         name,
-        preset=presets.SYNC_CHAT,
+        preset=presets.DEFAULT_PRESET,
         model=model,
         persona_description=persona_desc,
         user_description=user_desc,
@@ -57,7 +57,7 @@ def create_memgpt_autogen_agent_from_config(
     if human_input_mode != "ALWAYS":
         coop_agent1 = create_autogen_memgpt_agent(
             name,
-            preset=presets.SYNC_CHAT,
+            preset=presets.DEFAULT_PRESET,
             model=model,
             persona_description=persona_desc,
             user_description=user_desc,
@@ -73,7 +73,7 @@ def create_memgpt_autogen_agent_from_config(
         else:
             coop_agent2 = create_autogen_memgpt_agent(
                 name,
-                preset=presets.SYNC_CHAT,
+                preset=presets.DEFAULT_PRESET,
                 model=model,
                 persona_description=persona_desc,
                 user_description=user_desc,
@@ -95,7 +95,7 @@ def create_memgpt_autogen_agent_from_config(

 def create_autogen_memgpt_agent(
     autogen_name,
-    preset=presets.SYNC_CHAT,
+    preset=presets.DEFAULT_PRESET,
     model=constants.DEFAULT_MEMGPT_MODEL,
     persona_description=personas.DEFAULT,
     user_description=humans.DEFAULT,
@@ -126,7 +126,7 @@ def create_autogen_memgpt_agent(
         persona=persona_description,
         human=user_description,
         model=model,
-        preset=presets.SYNC_CHAT,
+        preset=presets.DEFAULT_PRESET,
     )

     interface = AutoGenInterface(**interface_kwargs) if interface is None else interface
memgpt/cli/cli.py

@@ -2,7 +2,6 @@ import typer
 import sys
 import io
 import logging
-import asyncio
 import os
 from prettytable import PrettyTable
 import questionary
@@ -24,7 +23,7 @@ from memgpt.utils import printd
 from memgpt.persistence_manager import LocalStateManager
 from memgpt.config import MemGPTConfig, AgentConfig
 from memgpt.constants import MEMGPT_DIR
-from memgpt.agent import AgentAsync
+from memgpt.agent import Agent
 from memgpt.embeddings import embedding_model
 from memgpt.openai_tools import (
     configure_azure_support,
@@ -121,7 +120,7 @@ def run(
         agent_config.save()

         # load existing agent
-        memgpt_agent = AgentAsync.load_agent(memgpt.interface, agent_config)
+        memgpt_agent = Agent.load_agent(memgpt.interface, agent_config)
     else:  # create new agent
         # create new agent config: override defaults with args if provided
         typer.secho("Creating new agent...", fg=typer.colors.GREEN)
@@ -162,8 +161,7 @@ def run(
     if config.model_endpoint == "azure":
         configure_azure_support()

-    loop = asyncio.get_event_loop()
-    loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config))  # TODO: add back no_verify
+    run_agent_loop(memgpt_agent, first, no_verify, config)  # TODO: add back no_verify


 def attach(
memgpt/cli/cli_config.py

@@ -30,7 +30,7 @@ def configure():
    config = MemGPTConfig.load()

    # openai credentials
-    use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?", default=True).ask()
+    use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask()
    if use_openai:
        # search for key in enviornment
        openai_key = os.getenv("OPENAI_API_KEY")
@@ -119,10 +119,10 @@ def configure():

    # defaults
    personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
-    print(personas)
+    # print(personas)
    default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
    humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
-    print(humans)
+    # print(humans)
    default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()

    # TODO: figure out if we should set a default agent or not
memgpt/config.py

@@ -203,7 +203,7 @@ class MemGPTConfig:

        # archival storage
        config.add_section("archival_storage")
-        print("archival storage", self.archival_storage_type)
+        # print("archival storage", self.archival_storage_type)
        config.set("archival_storage", "type", self.archival_storage_type)
        if self.archival_storage_path:
            config.set("archival_storage", "path", self.archival_storage_path)
@@ -350,7 +350,7 @@ class Config:
        self.preload_archival = False

    @classmethod
-    async def legacy_flags_init(
+    def legacy_flags_init(
        cls: Type["Config"],
        model: str,
        memgpt_persona: str,
@@ -372,11 +372,11 @@ class Config:
        if self.archival_storage_index:
            recompute_embeddings = False  # TODO Legacy support -- can't recompute embeddings on a path that's not specified.
        if self.archival_storage_files:
-            await self.configure_archival_storage(recompute_embeddings)
+            self.configure_archival_storage(recompute_embeddings)
        return self

    @classmethod
-    async def config_init(cls: Type["Config"], config_file: str = None):
+    def config_init(cls: Type["Config"], config_file: str = None):
        self = cls()
        self.config_file = config_file
        if self.config_file is None:
@@ -384,7 +384,7 @@ class Config:
        use_cfg = False
        if cfg:
            print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}")
-            use_cfg = await questionary.confirm(f"Use most recent config file '{cfg}'?").ask_async()
+            use_cfg = questionary.confirm(f"Use most recent config file '{cfg}'?").ask()
        if use_cfg:
            self.config_file = cfg

@@ -393,74 +393,74 @@ class Config:
            recompute_embeddings = False
            if self.compute_embeddings:
                if self.archival_storage_index:
-                    recompute_embeddings = await questionary.confirm(
+                    recompute_embeddings = questionary.confirm(
                        f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}",
                        default=False,
-                    ).ask_async()
+                    ).ask()
                else:
                    recompute_embeddings = True
            if self.load_type:
-                await self.configure_archival_storage(recompute_embeddings)
+                self.configure_archival_storage(recompute_embeddings)
            self.write_config()
            return self

        # print("No settings file found, configuring MemGPT...")
        print(f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}")

-        self.model = await questionary.select(
+        self.model = questionary.select(
            "Which model would you like to use?",
            model_choices,
            default=model_choices[0],
-        ).ask_async()
+        ).ask()

-        self.memgpt_persona = await questionary.select(
+        self.memgpt_persona = questionary.select(
            "Which persona would you like MemGPT to use?",
            Config.get_memgpt_personas(),
-        ).ask_async()
+        ).ask()
        print(self.memgpt_persona)

-        self.human_persona = await questionary.select(
+        self.human_persona = questionary.select(
            "Which user would you like to use?",
            Config.get_user_personas(),
-        ).ask_async()
+        ).ask()

        self.archival_storage_index = None
-        self.preload_archival = await questionary.confirm(
+        self.preload_archival = questionary.confirm(
            "Would you like to preload anything into MemGPT's archival memory?", default=False
-        ).ask_async()
+        ).ask()
        if self.preload_archival:
-            self.load_type = await questionary.select(
+            self.load_type = questionary.select(
                "What would you like to load?",
                choices=[
                    questionary.Choice("A folder or file", value="folder"),
                    questionary.Choice("A SQL database", value="sql"),
                    questionary.Choice("A glob pattern", value="glob"),
                ],
-            ).ask_async()
+            ).ask()
            if self.load_type == "folder" or self.load_type == "sql":
-                archival_storage_path = await questionary.path("Please enter the folder or file (tab for autocomplete):").ask_async()
+                archival_storage_path = questionary.path("Please enter the folder or file (tab for autocomplete):").ask()
                if os.path.isdir(archival_storage_path):
                    self.archival_storage_files = os.path.join(archival_storage_path, "*")
                else:
                    self.archival_storage_files = archival_storage_path
            else:
-                self.archival_storage_files = await questionary.path("Please enter the glob pattern (tab for autocomplete):").ask_async()
-            self.compute_embeddings = await questionary.confirm(
+                self.archival_storage_files = questionary.path("Please enter the glob pattern (tab for autocomplete):").ask()
+            self.compute_embeddings = questionary.confirm(
                "Would you like to compute embeddings over these files to enable embeddings search?"
-            ).ask_async()
-            await self.configure_archival_storage(self.compute_embeddings)
+            ).ask()
+            self.configure_archival_storage(self.compute_embeddings)

        self.write_config()
        return self

-    async def configure_archival_storage(self, recompute_embeddings):
+    def configure_archival_storage(self, recompute_embeddings):
        if recompute_embeddings:
            if self.host:
                interface.warning_message(
                    "⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search."
                )
            else:
-                self.archival_storage_index = await utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files)
+                self.archival_storage_index = utils.prepare_archival_index_from_files_compute_embeddings(self.archival_storage_files)
        if self.compute_embeddings and self.archival_storage_index:
            self.index, self.archival_database = utils.prepare_archival_index(self.archival_storage_index)
        else:
memgpt/interface.py

@@ -28,7 +28,7 @@ def warning_message(msg):
     print(fstr.format(msg=msg))


-async def internal_monologue(msg):
+def internal_monologue(msg):
     # ANSI escape code for italic is '\x1B[3m'
     fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {{msg}}{Style.RESET_ALL}"
     if STRIP_UI:
@@ -36,28 +36,28 @@ async def internal_monologue(msg):
     print(fstr.format(msg=msg))


-async def assistant_message(msg):
+def assistant_message(msg):
     fstr = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{{msg}}{Style.RESET_ALL}"
     if STRIP_UI:
         fstr = "{msg}"
     print(fstr.format(msg=msg))


-async def memory_message(msg):
+def memory_message(msg):
     fstr = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{{msg}}{Style.RESET_ALL}"
     if STRIP_UI:
         fstr = "{msg}"
     print(fstr.format(msg=msg))


-async def system_message(msg):
+def system_message(msg):
     fstr = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
     if STRIP_UI:
         fstr = "{msg}"
     print(fstr.format(msg=msg))


-async def user_message(msg, raw=False, dump=False, debug=DEBUG):
+def user_message(msg, raw=False, dump=False, debug=DEBUG):
     def print_user_message(icon, msg, printf=print):
         if STRIP_UI:
             printf(f"{icon} {msg}")
@@ -103,7 +103,7 @@ async def user_message(msg, raw=False, dump=False, debug=DEBUG):
     printd_user_message("🧑", msg_json)


-async def function_message(msg, debug=DEBUG):
+def function_message(msg, debug=DEBUG):
     def print_function_message(icon, msg, color=Fore.RED, printf=print):
         if STRIP_UI:
             printf(f"⚡{icon} [function] {msg}")
@@ -171,7 +171,7 @@ async def function_message(msg, debug=DEBUG):
     printd_function_message("", msg)


-async def print_messages(message_sequence, dump=False):
+def print_messages(message_sequence, dump=False):
     idx = len(message_sequence)
     for msg in message_sequence:
         if dump:
@@ -181,42 +181,42 @@ async def print_messages(message_sequence, dump=False):
         content = msg["content"]

         if role == "system":
-            await system_message(content)
+            system_message(content)
         elif role == "assistant":
             # Differentiate between internal monologue, function calls, and messages
             if msg.get("function_call"):
                 if content is not None:
-                    await internal_monologue(content)
+                    internal_monologue(content)
                 # I think the next one is not up to date
-                # await function_message(msg["function_call"])
+                # function_message(msg["function_call"])
                 args = json.loads(msg["function_call"].get("arguments"))
-                await assistant_message(args.get("message"))
+                assistant_message(args.get("message"))
                 # assistant_message(content)
             else:
-                await internal_monologue(content)
+                internal_monologue(content)
         elif role == "user":
-            await user_message(content, dump=dump)
+            user_message(content, dump=dump)
         elif role == "function":
-            await function_message(content, debug=dump)
+            function_message(content, debug=dump)
         else:
             print(f"Unknown role: {content}")


-async def print_messages_simple(message_sequence):
+def print_messages_simple(message_sequence):
     for msg in message_sequence:
         role = msg["role"]
         content = msg["content"]

         if role == "system":
-            await system_message(content)
+            system_message(content)
         elif role == "assistant":
-            await assistant_message(content)
+            assistant_message(content)
         elif role == "user":
-            await user_message(content, raw=True)
+            user_message(content, raw=True)
         else:
             print(f"Unknown role: {content}")


-async def print_messages_raw(message_sequence):
+def print_messages_raw(message_sequence):
     for msg in message_sequence:
         print(msg)
memgpt/main.py

@@ -1,4 +1,3 @@
-import asyncio
 import shutil
 import configparser
 import uuid
@@ -38,14 +37,13 @@ from memgpt.cli.cli_config import configure, list, add
 from memgpt.cli.cli_load import app as load_app
 from memgpt.config import Config, MemGPTConfig, AgentConfig
 from memgpt.constants import MEMGPT_DIR
-from memgpt.agent import AgentAsync
+from memgpt.agent import Agent
 from memgpt.openai_tools import (
     configure_azure_support,
     check_azure_embeddings,
     get_set_azure_env_vars,
 )
 from memgpt.connectors.storage import StorageConnector
-import asyncio

 app = typer.Typer(pretty_exceptions_enable=False)
 app.command(name="run")(run)
@@ -180,26 +178,23 @@ def legacy_run(
     if not questionary.confirm("Continue with legacy CLI?", default=False).ask():
         return

-    loop = asyncio.get_event_loop()
-    loop.run_until_complete(
-        main(
-            persona,
-            human,
-            model,
-            first,
-            debug,
-            no_verify,
-            archival_storage_faiss_path,
-            archival_storage_files,
-            archival_storage_files_compute_embeddings,
-            archival_storage_sqldb,
-            use_azure_openai,
-            strip_ui,
-        )
+    main(
+        persona,
+        human,
+        model,
+        first,
+        debug,
+        no_verify,
+        archival_storage_faiss_path,
+        archival_storage_files,
+        archival_storage_files_compute_embeddings,
+        archival_storage_sqldb,
+        use_azure_openai,
+        strip_ui,
     )


-async def main(
+def main(
     persona,
     human,
     model,
@@ -271,7 +266,7 @@ async def main(

         print(persona, model, memgpt_persona)
         if archival_storage_files:
-            cfg = await Config.legacy_flags_init(
+            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
@@ -280,7 +275,7 @@ async def main(
                compute_embeddings=False,
            )
        elif archival_storage_faiss_path:
-            cfg = await Config.legacy_flags_init(
+            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
@@ -293,7 +288,7 @@ async def main(
            print(model)
            print(memgpt_persona)
            print(human_persona)
-            cfg = await Config.legacy_flags_init(
+            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
@@ -302,7 +297,7 @@ async def main(
                compute_embeddings=True,
            )
        elif archival_storage_sqldb:
-            cfg = await Config.legacy_flags_init(
+            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
@@ -311,13 +306,13 @@ async def main(
                compute_embeddings=False,
            )
        else:
-            cfg = await Config.legacy_flags_init(
+            cfg = Config.legacy_flags_init(
                model,
                memgpt_persona,
                human_persona,
            )
    else:
-        cfg = await Config.config_init()
+        cfg = Config.config_init()

    memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']")
    if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
@@ -352,7 +347,7 @@ async def main(
        persistence_manager,
    )
    print_messages = memgpt.interface.print_messages
-    await print_messages(memgpt_agent.messages)
+    print_messages(memgpt_agent.messages)

    if cfg.load_type == "sql":  # TODO: move this into config.py in a clean manner
        if not os.path.exists(cfg.archival_storage_files):
@@ -364,19 +359,19 @@ async def main(
        data_list = utils.read_database_as_list(cfg.archival_storage_files)
        user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. Remember to refer to this first while answering any user questions!"
        for row in data_list:
-            await memgpt_agent.persistence_manager.archival_memory.insert(row)
+            memgpt_agent.persistence_manager.archival_memory.insert(row)
        print(f"Database loaded into archival memory.")

    if cfg.agent_save_file:
-        load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async()
+        load_save_file = questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask()
        if load_save_file:
            load(memgpt_agent, cfg.agent_save_file)

    # run agent loop
-    await run_agent_loop(memgpt_agent, first, no_verify, cfg, strip_ui, legacy=True)
+    run_agent_loop(memgpt_agent, first, no_verify, cfg, strip_ui, legacy=True)


-async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False, legacy=False):
+def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False, legacy=False):
    counter = 0
    user_input = None
    skip_next_user_input = False
@@ -392,11 +387,11 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u
    while True:
        if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
            # Ask for user input
-            user_input = await questionary.text(
+            user_input = questionary.text(
                "Enter your message:",
                multiline=multiline_input,
                qmark=">",
-            ).ask_async()
+            ).ask()
            clear_line(strip_ui)

            # Gracefully exit on Ctrl-C/D
@@ -462,7 +457,7 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u

            # TODO: check if agent already has it
            data_source_options = StorageConnector.list_loaded_data()
-            data_source = await questionary.select("Select data source", choices=data_source_options).ask_async()
+            data_source = questionary.select("Select data source", choices=data_source_options).ask()

            # attach new data
            attach(memgpt_agent.config.name, data_source)
@@ -482,13 +477,13 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u
            command = user_input.strip().split()
            amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
            if amount == 0:
-                await memgpt.interface.print_messages(memgpt_agent.messages, dump=True)
+                memgpt.interface.print_messages(memgpt_agent.messages, dump=True)
            else:
-                await memgpt.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
+                memgpt.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
            continue

        elif user_input.lower() == "/dumpraw":
-            await memgpt.interface.print_messages_raw(memgpt_agent.messages)
+            memgpt.interface.print_messages_raw(memgpt_agent.messages)
            continue

        elif user_input.lower() == "/memory":
@@ -554,7 +549,7 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u

        # No skip options
        elif user_input.lower() == "/wipe":
-            memgpt_agent = agent.AgentAsync(memgpt.interface)
+            memgpt_agent = agent.Agent(memgpt.interface)
            user_message = None

        elif user_input.lower() == "/heartbeat":
@@ -585,8 +580,8 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u

    skip_next_user_input = False

-    async def process_agent_step(user_message, no_verify):
-        new_messages, heartbeat_request, function_failed, token_warning = await memgpt_agent.step(
+    def process_agent_step(user_message, no_verify):
+        new_messages, heartbeat_request, function_failed, token_warning = memgpt_agent.step(
            user_message, first_message=False, skip_verify=no_verify
        )
@@ -606,16 +601,16 @@ async def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_u
    while True:
        try:
            if strip_ui:
-                new_messages, user_message, skip_next_user_input = await process_agent_step(user_message, no_verify)
+                new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
                break
            else:
                with console.status("[bold cyan]Thinking...") as status:
-                    new_messages, user_message, skip_next_user_input = await process_agent_step(user_message, no_verify)
+                    new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
                    break
        except Exception as e:
            print("An exception ocurred when running agent.step(): ")
            traceback.print_exc()
-            retry = await questionary.confirm("Retry agent.step()?").ask_async()
+            retry = questionary.confirm("Retry agent.step()?").ask()
            if not retry:
                break
@@ -639,13 +634,3 @@ USER_COMMANDS = [
    ("/memorywarning", "send a memory warning system message to the agent"),
    ("/attach", "attach data source to agent"),
 ]
-# if __name__ == "__main__":
-#
-#    app()
-#    #typer.run(run)
-#
-#    #def run(argv):
-#    #    loop = asyncio.get_event_loop()
-#    #    loop.run_until_complete(main())
-#
-#    #app.run(run)
memgpt/memory.py (116 changed lines)
@@ -9,8 +9,6 @@ from .utils import cosine_similarity, get_local_time, printd, count_tokens
 from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
 from memgpt import utils
 from .openai_tools import (
-    acompletions_with_backoff as acreate,
-    async_get_embedding_with_backoff,
     get_embedding_with_backoff,
     completions_with_backoff as create,
 )
@@ -148,36 +146,6 @@ def summarize_messages(
     return reply


-async def a_summarize_messages(
-    model,
-    message_sequence_to_summarize,
-):
-    """Summarize a message sequence using GPT"""
-
-    summary_prompt = SUMMARY_PROMPT_SYSTEM
-    summary_input = str(message_sequence_to_summarize)
-    summary_input_tkns = count_tokens(summary_input)
-    if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS:
-        trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8  # For good measure...
-        cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
-        summary_input = str(
-            [await a_summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]
-        )
-    message_sequence = [
-        {"role": "system", "content": summary_prompt},
-        {"role": "user", "content": summary_input},
-    ]
-
-    response = await acreate(
-        model=model,
-        messages=message_sequence,
-    )
-
-    printd(f"summarize_messages gpt reply: {response.choices[0]}")
-    reply = response.choices[0].message.content
-    return reply
-
-
 class ArchivalMemory(ABC):
     @abstractmethod
     def insert(self, memory_string):
@@ -238,9 +206,6 @@ class DummyArchivalMemory(ArchivalMemory):
             }
         )

-    async def a_insert(self, memory_string):
-        return self.insert(memory_string)
-
     def search(self, query_string, count=None, start=None):
         """Simple text-based search"""
         # in the dummy version, run an (inefficient) case-insensitive match search
@@ -261,9 +226,6 @@ class DummyArchivalMemory(ArchivalMemory):
         else:
             return matches, len(matches)

-    async def a_search(self, query_string, count=None, start=None):
-        return self.search(query_string, count=None, start=None)
-

 class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
     """Same as dummy in-memory archival memory, but with bare-bones embedding support"""
@@ -293,13 +255,10 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
         embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
         return self._insert(memory_string, embedding)

-    async def a_insert(self, memory_string):
-        embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
-        return self._insert(memory_string, embedding)
-
-    def _search(self, query_embedding, query_string, count, start):
+    def search(self, query_string, count, start):
         """Simple embedding-based search (inefficient, no caching)"""
         # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
+        query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)

         # query_embedding = get_embedding(query_string, model=self.embedding_model)
         # our wrapped version supports backoff/rate-limits
@@ -328,14 +287,6 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
         else:
             return matches, len(matches)

-    def search(self, query_string, count=None, start=None):
-        query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
-        return self._search(self, query_embedding, query_string, count, start)
-
-    async def a_search(self, query_string, count=None, start=None):
-        query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
-        return await self._search(self, query_embedding, query_string, count, start)
-

 class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
     """Dummy in-memory version of an archival memory database, using a FAISS
@@ -365,9 +316,12 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
     def __len__(self):
         return len(self._archive)

-    def _insert(self, memory_string, embedding):
+    def insert(self, memory_string):
         import numpy as np

+        # Get the embedding
+        embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
+
         print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}")

         self._archive.append(
@@ -380,17 +334,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
         embedding = np.array([embedding]).astype("float32")
         self.index.add(embedding)

-    def insert(self, memory_string):
-        # Get the embedding
-        embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
-        return self._insert(memory_string, embedding)
-
-    async def a_insert(self, memory_string):
-        # Get the embedding
-        embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
-        return self._insert(memory_string, embedding)
-
-    def _search(self, query_embedding, query_string, count=None, start=None):
+    def search(self, query_string, count=None, start=None):
         """Simple embedding-based search (inefficient, no caching)"""
         # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
@@ -401,6 +345,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
         if query_string in self.embeddings_dict:
             search_result = self.search_results[query_string]
         else:
+            query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
             _, indices = self.index.search(np.array([np.array(query_embedding, dtype=np.float32)]), self.k)
             search_result = [self._archive[idx] if idx < len(self._archive) else "" for idx in indices[0]]
             self.embeddings_dict[query_string] = query_embedding
@@ -430,38 +375,16 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
         else:
             return matches, len(matches)

-    def search(self, query_string, count=None, start=None):
-        if query_string in self.embeddings_dict:
-            query_embedding = self.embeddings_dict[query_string]
-        else:
-            query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
-        return self._search(query_embedding, query_string, count, start)
-
-    async def a_search(self, query_string, count=None, start=None):
-        if query_string in self.embeddings_dict:
-            query_embedding = self.embeddings_dict[query_string]
-        else:
-            query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model)
-        return self._search(query_embedding, query_string, count, start)
-

 class RecallMemory(ABC):
     @abstractmethod
     def text_search(self, query_string, count=None, start=None):
         pass

-    @abstractmethod
-    async def a_text_search(self, query_string, count=None, start=None):
-        pass
-
     @abstractmethod
     def date_search(self, query_string, count=None, start=None):
         pass

-    @abstractmethod
-    async def a_date_search(self, query_string, count=None, start=None):
-        pass
-
     @abstractmethod
     def __repr__(self) -> str:
         pass
@@ -513,7 +436,7 @@ class DummyRecallMemory(RecallMemory):
         )
         return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"

-    async def insert(self, message):
+    def insert(self, message):
         raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")

     def text_search(self, query_string, count=None, start=None):
@@ -538,9 +461,6 @@ class DummyRecallMemory(RecallMemory):
         else:
             return matches, len(matches)

-    async def a_text_search(self, query_string, count=None, start=None):
-        return self.text_search(query_string, count, start)
-
     def _validate_date_format(self, date_str):
         """Validate the given date string in the format 'YYYY-MM-DD'."""
         try:
@@ -583,9 +503,6 @@ class DummyRecallMemory(RecallMemory):
         else:
             return matches, len(matches)

-    async def a_date_search(self, start_date, end_date, count=None, start=None):
-        return self.date_search(start_date, end_date, count, start)
-

 class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
     """Lazily manage embeddings by keeping a string->embed dict"""
@@ -641,9 +558,6 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
         else:
             return matches, len(matches)

-    async def a_text_search(self, query_string, count=None, start=None):
-        return self.text_search(query_string, count, start)
-

 class LocalArchivalMemory(ArchivalMemory):
     """Archival memory built on top of Llama Index"""
@@ -707,9 +621,6 @@ class LocalArchivalMemory(ArchivalMemory):
             similarity_top_k=self.top_k,
         )

-    async def a_insert(self, memory_string):
-        return self.insert(memory_string)
-
     def search(self, query_string, count=None, start=None):
         print("searching with local")
         if self.retriever is None:
@@ -729,9 +640,6 @@ class LocalArchivalMemory(ArchivalMemory):
         # pprint(results)
         return results, len(results)

-    async def a_search(self, query_string, count=None, start=None):
-        return self.search(query_string, count, start)
-
     def __repr__(self) -> str:
         if isinstance(self.index, EmptyIndex):
             memory_str = "<empty>"
@@ -809,12 +717,6 @@ class EmbeddingArchivalMemory(ArchivalMemory):
             print("Archival search error", e)
             raise e

-    async def a_search(self, query_string, count=None, start=None):
-        return self.search(query_string, count, start)
-
-    async def a_insert(self, memory_string):
-        return self.insert(memory_string)
-
     def __repr__(self) -> str:
         limit = 10
         passages = []
@@ -1,4 +1,3 @@
import asyncio
import random
import os
import time
@@ -74,89 +73,6 @@ def completions_with_backoff(**kwargs):
    return openai.ChatCompletion.create(**kwargs)


def aretry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 20,
    errors: tuple = (openai.error.RateLimitError,),
):
    """Retry a function with exponential backoff."""

    async def wrapper(*args, **kwargs):
        # Initialize variables
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception is raised
        while True:
            try:
                return await func(*args, **kwargs)

            # Retry on specified errors
            except errors as e:
                print(f"acreate (backoff): caught error: {e}")
                # Increment retries
                num_retries += 1

                # Check if max retries has been reached
                if num_retries > max_retries:
                    raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")

                # Increment the delay
                delay *= exponential_base * (1 + jitter * random.random())

                # Sleep for the delay
                await asyncio.sleep(delay)

            # Raise exceptions for any errors not specified
            except Exception as e:
                raise e

    return wrapper


@aretry_with_exponential_backoff
async def acompletions_with_backoff(**kwargs):
    # Local model
    if HOST_TYPE is not None:
        return get_chat_completion(**kwargs)

    # OpenAI / Azure model
    else:
        if using_azure():
            azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
            if azure_openai_deployment is not None:
                kwargs["deployment_id"] = azure_openai_deployment
            else:
                kwargs["engine"] = MODEL_TO_AZURE_ENGINE[kwargs["model"]]
                kwargs.pop("model")
        return await openai.ChatCompletion.acreate(**kwargs)


@aretry_with_exponential_backoff
async def acreate_embedding_with_backoff(**kwargs):
    """Wrapper around Embedding.acreate w/ backoff"""
    if using_azure():
        azure_openai_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
        if azure_openai_deployment is not None:
            kwargs["deployment_id"] = azure_openai_deployment
        else:
            kwargs["engine"] = kwargs["model"]
            kwargs.pop("model")
    return await openai.Embedding.acreate(**kwargs)


async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002"):
    """To get text embeddings, import/call this function
    It specifies defaults + handles rate-limiting + is async"""
    text = text.replace("\n", " ")
    response = await acreate_embedding_with_backoff(input=[text], model=model)
    embedding = response["data"][0]["embedding"]
    return embedding


@retry_with_exponential_backoff
def create_embedding_with_backoff(**kwargs):
    if using_azure():
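The deleted `aretry_with_exponential_backoff` has a synchronous sibling, `retry_with_exponential_backoff`, which still decorates `create_embedding_with_backoff` above. A sketch of what the surviving sync decorator plausibly looks like, assuming it mirrors the deleted coroutine line for line with `time.sleep` in place of `asyncio.sleep` (reconstructed, not copied from this diff):

```python
import random
import time

def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 20,
    errors: tuple = (Exception,),  # the real module narrows this to openai.error.RateLimitError
):
    """Retry a function with exponential backoff (synchronous)."""

    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay
        while True:
            try:
                return func(*args, **kwargs)
            except errors as e:
                print(f"create (backoff): caught error: {e}")
                num_retries += 1
                if num_retries > max_retries:
                    raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
                # Grow the delay; the jitter term (0..1) desynchronizes concurrent retriers.
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper
```

Used as a bare decorator (`@retry_with_exponential_backoff`), `func` is the decorated function and the remaining keyword arguments keep their defaults.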
@@ -4,13 +4,11 @@ from .prompts import gpt_system
DEFAULT_PRESET = "memgpt_chat"
preset_options = [DEFAULT_PRESET]

SYNC_CHAT = "memgpt_chat_sync"  # TODO: remove me after we move the CLI to AgentSync


def use_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
    """Storing combinations of SYSTEM + FUNCTION prompts"""

    from memgpt.agent import AgentAsync, Agent
    from memgpt.agent import Agent
    from memgpt.utils import printd

    if preset_name == DEFAULT_PRESET:
@@ -28,38 +26,6 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
        printd(f"Available functions:\n", [x["name"] for x in available_functions])
        assert len(functions) == len(available_functions)

        if "gpt-3.5" in model:
            # use a different system message for gpt-3.5
            preset_name = "memgpt_gpt35_extralong"

        return AgentAsync(
            config=agent_config,
            model=model,
            system=gpt_system.get_system_text(preset_name),
            functions=available_functions,
            interface=interface,
            persistence_manager=persistence_manager,
            persona_notes=persona,
            human_notes=human,
            # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
            first_message_verify_mono=True if "gpt-4" in model else False,
        )

    elif preset_name == "memgpt_chat_sync":  # TODO: remove me after we move the CLI to AgentSync
        functions = [
            "send_message",
            "pause_heartbeats",
            "core_memory_append",
            "core_memory_replace",
            "conversation_search",
            "conversation_search_date",
            "archival_memory_insert",
            "archival_memory_search",
        ]
        available_functions = [v for k, v in gpt_functions.FUNCTIONS_CHAINING.items() if k in functions]
        printd(f"Available functions:\n", [x["name"] for x in available_functions])
        assert len(functions) == len(available_functions)

        if "gpt-3.5" in model:
            # use a different system message for gpt-3.5
            preset_name = "memgpt_gpt35_extralong"
@@ -67,7 +33,7 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
        return Agent(
            config=agent_config,
            model=model,
            system=gpt_system.get_system_text(DEFAULT_PRESET),
            system=gpt_system.get_system_text(preset_name),
            functions=available_functions,
            interface=interface,
            persistence_manager=persistence_manager,
@@ -101,7 +67,7 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
            # use a different system message for gpt-3.5
            preset_name = "memgpt_gpt35_extralong"

        return AgentAsync(
        return Agent(
            model=model,
            system=gpt_system.get_system_text("memgpt_chat"),
            functions=available_functions,
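With every preset branch now constructing the synchronous `Agent`, callers no longer need an event loop to build or drive an agent. A hypothetical wiring sketch, assuming the module is importable as `memgpt.presets`; the config, persona/human text, interface, and persistence-manager objects are placeholders for whatever the CLI already constructs:

```python
from memgpt.presets import use_preset, DEFAULT_PRESET

agent = use_preset(
    DEFAULT_PRESET,
    agent_config=config,                      # placeholder AgentConfig
    model="gpt-4",
    persona=persona_text,                     # placeholder persona notes
    human=human_text,                         # placeholder human notes
    interface=interface,                      # placeholder CLI interface
    persistence_manager=persistence_manager,  # placeholder state manager
)
# agent is a plain Agent; its methods are called directly, with no await.
```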
@@ -1,5 +1,4 @@
from datetime import datetime
import asyncio
import csv
import difflib
import demjson3 as demjson
@@ -14,11 +13,13 @@ import fitz
from tqdm import tqdm
import typer
import memgpt
from memgpt.openai_tools import async_get_embedding_with_backoff
from memgpt.openai_tools import get_embedding_with_backoff
from memgpt.constants import MEMGPT_DIR
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
from llama_index.embeddings import OpenAIEmbedding

from concurrent.futures import ThreadPoolExecutor, as_completed


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
@@ -242,38 +243,28 @@ def chunk_files_for_jsonl(files, tkns_per_chunk=300, model="gpt-4"):
    return ret


async def process_chunk(i, chunk, model):
def process_chunk(i, chunk, model):
    try:
        return i, await async_get_embedding_with_backoff(chunk["content"], model=model)
        return i, get_embedding_with_backoff(chunk["content"], model=model)
    except Exception as e:
        print(chunk)
        raise e


async def process_concurrently(archival_database, model, concurrency=10):
    # Create a semaphore to limit the number of concurrent tasks
    semaphore = asyncio.Semaphore(concurrency)

    async def bounded_process_chunk(i, chunk):
        async with semaphore:
            return await process_chunk(i, chunk, model)

    # Create a list of tasks for chunks
def process_concurrently(archival_database, model, concurrency=10):
    embedding_data = [0 for _ in archival_database]
    tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]

    for future in tqdm(
        asyncio.as_completed(tasks),
        total=len(archival_database),
        desc="Processing file chunks",
    ):
        i, result = await future
        embedding_data[i] = result
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Submit tasks to the executor
        future_to_chunk = {executor.submit(process_chunk, i, chunk, model): i for i, chunk in enumerate(archival_database)}

        # As each task completes, process the results
        for future in tqdm(as_completed(future_to_chunk), total=len(archival_database), desc="Processing file chunks"):
            i, result = future.result()
            embedding_data[i] = result
    return embedding_data
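The rewrite trades cooperative concurrency (asyncio tasks gated by a semaphore) for a thread pool: `ThreadPoolExecutor(max_workers=concurrency)` enforces the same cap of ten in-flight embedding calls, and because each worker returns its own index `i`, `embedding_data` stays aligned with `archival_database` even though futures complete out of order. A self-contained sketch of that pattern with a toy worker (not code from this commit):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def work(i: int, item: str) -> tuple:
    # Stand-in for process_chunk: return the index alongside the result
    # so the caller can slot results back into submission order.
    return i, item.upper()

items = ["alpha", "beta", "gamma"]
results = [None] * len(items)

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(work, i, item) for i, item in enumerate(items)]
    for future in as_completed(futures):
        i, value = future.result()  # re-raises any exception from the worker
        results[i] = value

print(results)  # ['ALPHA', 'BETA', 'GAMMA'], regardless of completion order
```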
async def prepare_archival_index_from_files_compute_embeddings(
def prepare_archival_index_from_files_compute_embeddings(
    glob_pattern,
    tkns_per_chunk=300,
    model="gpt-4",
@@ -293,7 +284,7 @@ async def prepare_archival_index_from_files_compute_embeddings(

    # chunk the files, make embeddings
    archival_database = chunk_files(files, tkns_per_chunk, model)
    embedding_data = await process_concurrently(archival_database, embeddings_model)
    embedding_data = process_concurrently(archival_database, embeddings_model)
    embeddings_file = os.path.join(save_dir, "embeddings.json")
    with open(embeddings_file, "w") as f:
        print(f"Saving embeddings to {embeddings_file}")
@@ -6,7 +6,7 @@ from .constants import TIMEOUT
def configure_memgpt(enable_openai=True, enable_azure=False):
    child = pexpect.spawn("memgpt configure")

    child.expect("Do you want to enable MemGPT with Open AI?", timeout=TIMEOUT)
    child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT)
    if enable_openai:
        child.sendline("y")
    else:
@@ -27,6 +27,9 @@ def configure_memgpt(enable_openai=True, enable_azure=False):
    child.expect("Select default preset:", timeout=TIMEOUT)
    child.sendline()

    child.expect("Select default model", timeout=TIMEOUT)
    child.sendline()

    child.expect("Select default persona:", timeout=TIMEOUT)
    child.sendline()
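The test tracks the CLI copy exactly: `expect()` matches the prompt text (hence the "Open AI" → "OpenAI" fix), and a bare `sendline()` sends just a newline, accepting the default choice for each of the newly expected prompts. A standalone sketch of the same pexpect pattern (the 30-second timeout is an arbitrary stand-in for `TIMEOUT`):

```python
import pexpect

child = pexpect.spawn("memgpt configure")
# expect() compiles string patterns as regexes, so the match text must track the CLI's wording.
child.expect("Do you want to enable MemGPT with OpenAI?", timeout=30)
child.sendline("y")  # answer the yes/no prompt
child.expect("Select default model", timeout=30)
child.sendline()     # bare newline accepts the highlighted default
```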