Files
letta-server/memgpt/agent.py
2023-11-09 09:09:57 -08:00

1285 lines
59 KiB
Python

import asyncio
import inspect
import datetime
import glob
import pickle
import math
import os
import requests
import json
import threading
import traceback
import openai
from memgpt.persistence_manager import LocalStateManager
from memgpt.config import AgentConfig
from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages, a_summarize_messages
from .openai_tools import acompletions_with_backoff as acreate, completions_with_backoff as create
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
from .constants import (
MEMGPT_DIR,
FIRST_MESSAGE_ATTEMPTS,
MAX_PAUSE_HEARTBEATS,
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
MESSAGE_SUMMARY_WARNING_TOKENS,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
)
from .errors import LLMError
def initialize_memory(ai_notes, human_notes):
    """Create the agent's core memory and seed it with the persona and human notes.

    Args:
        ai_notes: Initial persona text, written via ``memory.edit_persona``.
        human_notes: Initial human text, written via ``memory.edit_human``.

    Returns:
        A ``Memory`` object constructed with the configured human/persona
        character limits and populated with both notes.

    Raises:
        ValueError: If either argument is ``None``.
    """
    # Name the offending argument: raising ValueError(None) gives an empty, useless message.
    if ai_notes is None:
        raise ValueError("ai_notes must not be None")
    if human_notes is None:
        raise ValueError("human_notes must not be None")

    memory = Memory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
    memory.edit_persona(ai_notes)
    memory.edit_human(human_notes)
    return memory
def construct_system_with_memory(system, memory, memory_edit_timestamp, archival_memory=None, recall_memory=None):
    """Render the full system message: base prompt, memory statistics, then core memory.

    Args:
        system: The base system prompt text.
        memory: Object exposing ``persona`` and ``human`` attributes.
        memory_edit_timestamp: Value shown as the memory's last-modified time.
        archival_memory: Optional sized collection; its length is reported (0 if falsy).
        recall_memory: Optional sized collection; its length is reported (0 if falsy).

    Returns:
        A single string with all sections joined by newlines.
    """
    recall_count = len(recall_memory) if recall_memory else 0
    archival_count = len(archival_memory) if archival_memory else 0

    # Header: prompt, spacer, and the memory statistics shown to the model.
    header_lines = [
        system,
        "\n",
        f"### Memory [last modified: {memory_edit_timestamp}]",
        f"{recall_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
        f"{archival_count} total memories you created are stored in archival memory (use functions to access them)",
        "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
    ]
    # Core memory: persona and human sections wrapped in pseudo-XML tags.
    core_lines = [
        "<persona>",
        memory.persona,
        "</persona>",
        "<human>",
        memory.human,
        "</human>",
    ]
    return "\n".join(header_lines + core_lines)
def initialize_message_sequence(
    model,
    system,
    memory,
    archival_memory=None,
    recall_memory=None,
    memory_edit_timestamp=None,
    include_initial_boot_message=True,
):
    """Build the starting message list for an agent.

    The list always begins with the rendered system message and ends with the
    login event as a user message. When ``include_initial_boot_message`` is
    true, the boot messages for the model family are placed in between.

    Args:
        model: Model name; a name containing "gpt-3.5" selects the gpt35 boot messages.
        system: Base system prompt text.
        memory: Core memory object passed to ``construct_system_with_memory``.
        archival_memory: Optional archival store, used for the count in the system message.
        recall_memory: Optional recall store, used for the count in the system message.
        memory_edit_timestamp: Last-modified stamp; defaults to ``get_local_time()``.
        include_initial_boot_message: Whether to insert the boot messages.

    Returns:
        A list of message dicts.
    """
    if memory_edit_timestamp is None:
        memory_edit_timestamp = get_local_time()

    system_entry = {
        "role": "system",
        "content": construct_system_with_memory(
            system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
        ),
    }
    # event letting MemGPT know the user just logged in
    login_entry = {"role": "user", "content": get_login_event()}

    if not include_initial_boot_message:
        return [system_entry, login_entry]

    boot_key = "startup_with_send_message_gpt35" if "gpt-3.5" in model else "startup_with_send_message"
    return [system_entry] + get_initial_boot_messages(boot_key) + [login_entry]
def get_ai_reply(
    model,
    message_sequence,
    functions,
    function_call="auto",
):
    """Base (synchronous) call to the GPT API with functions.

    Args:
        model: Model name forwarded to the completion call.
        message_sequence: List of message dicts to send.
        functions: Function specifications exposed to the model.
        function_call: Function-call mode forwarded to the API (default "auto").

    Returns:
        The raw API response; unpack with ``response.choices[0].message.content``.

    Raises:
        Exception: If the finish reason is "length" (the message contains
            "maximum context length", which ``Agent.step`` matches on), or is
            anything other than "stop" / "function_call". Errors raised by the
            completion call itself propagate unchanged.
    """
    response = create(
        model=model,
        messages=message_sequence,
        functions=functions,
        function_call=function_call,
    )
    finish_reason = response.choices[0].finish_reason

    # special case for 'length'
    if finish_reason == "length":
        raise Exception("Finish reason was length (maximum context length)")

    # catches for soft errors
    if finish_reason not in ["stop", "function_call"]:
        raise Exception(f"API call finish with bad finish reason: {response}")

    return response
async def get_ai_reply_async(
    model,
    message_sequence,
    functions,
    function_call="auto",
):
    """Base (asynchronous) call to the GPT API with functions.

    Args:
        model: Model name forwarded to the completion call.
        message_sequence: List of message dicts to send.
        functions: Function specifications exposed to the model.
        function_call: Function-call mode forwarded to the API (default "auto").

    Returns:
        The raw API response; unpack with ``response.choices[0].message.content``.

    Raises:
        Exception: If the finish reason is "length" (message contains
            "maximum context length"), or is anything other than "stop" /
            "function_call". Errors raised by the completion call itself
            propagate unchanged.
    """
    response = await acreate(
        model=model,
        messages=message_sequence,
        functions=functions,
        function_call=function_call,
    )
    finish_reason = response.choices[0].finish_reason

    # special case for 'length'
    if finish_reason == "length":
        raise Exception("Finish reason was length (maximum context length)")

    # catches for soft errors
    if finish_reason not in ["stop", "function_call"]:
        raise Exception(f"API call finish with bad finish reason: {response}")

    return response
async def call_function(function_to_call, **function_args):
    """Invoke ``function_to_call`` with keyword arguments, awaiting it when it is a coroutine function.

    Args:
        function_to_call: Either a plain callable or an ``async def`` function.
        **function_args: Keyword arguments forwarded to the callable.

    Returns:
        The callable's return value (the awaited result for coroutine functions).
    """
    outcome = function_to_call(**function_args)
    # Calling a coroutine function only creates the coroutine; it must be awaited to run.
    if inspect.iscoroutinefunction(function_to_call):
        outcome = await outcome
    return outcome
class Agent(object):
def __init__(
self,
config,
model,
system,
functions,
interface,
persistence_manager,
persona_notes,
human_notes,
messages_total=None,
persistence_manager_init=True,
first_message_verify_mono=True,
):
# agent config
self.config = config
# gpt-4, gpt-3.5-turbo
self.model = model
# Store the system instructions (used to rebuild memory)
self.system = system
# Store the functions spec
self.functions = functions
# Initialize the memory object
self.memory = initialize_memory(persona_notes, human_notes)
# Once the memory object is initialize, use it to "bake" the system message
self._messages = initialize_message_sequence(
self.model,
self.system,
self.memory,
)
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
# self.messages_total_init = self.messages_total
self.messages_total_init = len(self._messages) - 1
printd(f"AgentAsync initialized, self.messages_total={self.messages_total}")
# Interface must implement:
# - internal_monologue
# - assistant_message
# - function_message
# ...
# Different interfaces can handle events differently
# e.g., print in CLI vs send a discord message with a discord bot
self.interface = interface
# Persistence manager must implement:
# - set_messages
# - get_messages
# - append_to_messages
self.persistence_manager = persistence_manager
if persistence_manager_init:
# creates a new agent object in the database
self.persistence_manager.init(self)
# State needed for heartbeat pausing
self.pause_heartbeats_start = None
self.pause_heartbeats_minutes = 0
self.first_message_verify_mono = first_message_verify_mono
# Controls if the convo memory pressure warning is triggered
# When an alert is sent in the message queue, set this to True (to avoid repeat alerts)
# When the summarizer is run, set this back to False (to reset)
self.agent_alerted_about_memory_pressure = False
self.init_avail_functions()
def init_avail_functions(self):
"""
Allows subclasses to overwrite this dictionary with overriden methods.
"""
self.available_functions = {
# These functions aren't all visible to the LLM
# To see what functions the LLM sees, check self.functions
"send_message": self.send_ai_message,
"edit_memory": self.edit_memory,
"edit_memory_append": self.edit_memory_append,
"edit_memory_replace": self.edit_memory_replace,
"pause_heartbeats": self.pause_heartbeats,
"core_memory_append": self.edit_memory_append,
"core_memory_replace": self.edit_memory_replace,
"recall_memory_search": self.recall_memory_search,
"recall_memory_search_date": self.recall_memory_search_date,
"conversation_search": self.recall_memory_search,
"conversation_search_date": self.recall_memory_search_date,
"archival_memory_insert": self.archival_memory_insert,
"archival_memory_search": self.archival_memory_search,
# extras
"read_from_text_file": self.read_from_text_file,
"append_to_text_file": self.append_to_text_file,
"http_request": self.http_request,
}
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, value):
raise Exception("Modifying message list directly not allowed")
def trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
self.persistence_manager.trim_messages(num)
new_messages = [self.messages[0]] + self.messages[num:]
self._messages = new_messages
def prepend_to_messages(self, added_messages):
"""Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager"""
self.persistence_manager.prepend_to_messages(added_messages)
new_messages = [self.messages[0]] + added_messages + self.messages[1:] # prepend (no system)
self._messages = new_messages
self.messages_total += len(added_messages) # still should increment the message counter (summaries are additions too)
def append_to_messages(self, added_messages):
"""Wrapper around self.messages.append to allow additional calls to a state/persistence manager"""
self.persistence_manager.append_to_messages(added_messages)
# strip extra metadata if it exists
for msg in added_messages:
msg.pop("api_response", None)
msg.pop("api_args", None)
new_messages = self.messages + added_messages # append
self._messages = new_messages
self.messages_total += len(added_messages)
def swap_system_message(self, new_system_message):
assert new_system_message["role"] == "system", new_system_message
assert self.messages[0]["role"] == "system", self.messages
self.persistence_manager.swap_system_message(new_system_message)
new_messages = [new_system_message] + self.messages[1:] # swap index 0 (system)
self._messages = new_messages
def rebuild_memory(self):
"""Rebuilds the system message with the latest memory object"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
new_system_message = initialize_message_sequence(
self.model,
self.system,
self.memory,
archival_memory=self.persistence_manager.archival_memory,
recall_memory=self.persistence_manager.recall_memory,
)[0]
diff = united_diff(curr_system_message["content"], new_system_message["content"])
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Store the memory change (if stateful)
self.persistence_manager.update_memory(self.memory)
# Swap the system message out
self.swap_system_message(new_system_message)
### Local state management
def to_dict(self):
return {
"model": self.model,
"system": self.system,
"functions": self.functions,
"messages": self.messages,
"messages_total": self.messages_total,
"memory": self.memory.to_dict(),
}
def save_to_json_file(self, filename):
with open(filename, "w") as file:
json.dump(self.to_dict(), file)
def save(self):
"""Save agent state locally"""
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
agent_name = self.config.name # TODO: fix
# save agent state
filename = f"{timestamp}.json"
os.makedirs(self.config.save_state_dir(), exist_ok=True)
self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename))
# save the persistence manager too
filename = f"{timestamp}.persistence.pickle"
os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename))
@classmethod
def load_agent(cls, interface, agent_config: AgentConfig):
"""Load saved agent state"""
# TODO: support loading from specific file
agent_name = agent_config.name
# load state
directory = agent_config.save_state_dir()
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_name}")
# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
state = json.load(open(filename, "r"))
# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)
messages = state["messages"]
agent = cls(
config=agent_config,
model=state["model"],
system=state["system"],
functions=state["functions"],
interface=interface,
persistence_manager=persistence_manager,
persistence_manager_init=False,
persona_notes=state["memory"]["persona"],
human_notes=state["memory"]["human"],
messages_total=state["messages_total"] if "messages_total" in state else len(messages) - 1,
)
agent._messages = messages
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
return agent
@classmethod
def load(cls, state, interface, persistence_manager):
model = state["model"]
system = state["system"]
functions = state["functions"]
messages = state["messages"]
try:
messages_total = state["messages_total"]
except KeyError:
messages_total = len(messages) - 1
# memory requires a nested load
memory_dict = state["memory"]
persona_notes = memory_dict["persona"]
human_notes = memory_dict["human"]
# Two-part load
new_agent = cls(
model=model,
system=system,
functions=functions,
interface=interface,
persistence_manager=persistence_manager,
persistence_manager_init=False,
persona_notes=persona_notes,
human_notes=human_notes,
messages_total=messages_total,
)
new_agent._messages = messages
return new_agent
def load_inplace(self, state):
self.model = state["model"]
self.system = state["system"]
self.functions = state["functions"]
# memory requires a nested load
memory_dict = state["memory"]
persona_notes = memory_dict["persona"]
human_notes = memory_dict["human"]
self.memory = initialize_memory(persona_notes, human_notes)
# messages also
self._messages = state["messages"]
try:
self.messages_total = state["messages_total"]
except KeyError:
self.messages_total = len(self.messages) - 1 # -system
@classmethod
def load_from_json(cls, json_state, interface, persistence_manager):
state = json.loads(json_state)
return cls.load(state, interface, persistence_manager)
@classmethod
def load_from_json_file(cls, json_file, interface, persistence_manager):
with open(json_file, "r") as file:
state = json.load(file)
return cls.load(state, interface, persistence_manager)
def load_from_json_file_inplace(self, json_file):
# Load in-place
# No interface arg needed, we can use the current one
with open(json_file, "r") as file:
state = json.load(file)
self.load_inplace(state)
def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
"""Can be used to enforce that the first message always uses send_message"""
response_message = response.choices[0].message
# First message should be a call to send_message with a non-empty content
if require_send_message and not response_message.get("function_call"):
printd(f"First message didn't include function call: {response_message}")
return False
function_name = response_message["function_call"]["name"]
if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")
return False
if require_monologue and (
not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""
):
printd(f"First message missing internal monologue: {response_message}")
return False
if response_message.get("content"):
### Extras
monologue = response_message.get("content")
def contains_special_characters(s):
special_characters = '(){}[]"'
return any(char in s for char in special_characters)
if contains_special_characters(monologue):
printd(f"First message internal monologue contained special characters: {response_message}")
return False
# if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower():
if "functions" in monologue or "send_message" in monologue:
# Sometimes the syntax won't be correct and internal syntax will leak into message.context
printd(f"First message internal monologue contained reserved words: {response_message}")
return False
return True
def handle_ai_response(self, response_message):
"""Handles parsing and function execution"""
messages = [] # append these to the history when done
# Step 2: check if LLM wanted to call a function
if response_message.get("function_call"):
# The content if then internal monologue, not chat
self.interface.internal_monologue(response_message.content)
messages.append(response_message) # extend conversation with assistant's reply
# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors
# Failure case 1: function name is wrong
function_name = response_message["function_call"]["name"]
try:
function_to_call = self.available_functions[function_name]
except KeyError as e:
error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
{
"role": "function",
"name": function_name,
"content": function_response,
}
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# Failure case 2: function name is OK, but function args are bad JSON
try:
raw_function_args = response_message["function_call"]["arguments"]
function_args = parse_json(raw_function_args)
except Exception as e:
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {raw_function_args}"
function_response = package_function_response(False, error_msg)
messages.append(
{
"role": "function",
"name": function_name,
"content": function_response,
}
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# (Still parsing function args)
# Handle requests for immediate heartbeat
heartbeat_request = function_args.pop("request_heartbeat", None)
if not (isinstance(heartbeat_request, bool) or heartbeat_request is None):
printd(
f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
)
heartbeat_request = None
# Failure case 3: function failed during execution
self.interface.function_message(f"Running {function_name}({function_args})")
try:
function_response_string = function_to_call(**function_args)
function_response = package_function_response(True, function_response_string)
function_failed = False
except Exception as e:
error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
messages.append(
{
"role": "function",
"name": function_name,
"content": function_response,
}
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}")
return messages, None, True # force a heartbeat to allow agent to handle error
# If no failures happened along the way: ...
# Step 4: send the info on the function call and function response to GPT
self.interface.function_message(f"Success: {function_response_string}")
messages.append(
{
"role": "function",
"name": function_name,
"content": function_response,
}
) # extend conversation with function response
else:
# Standard non-function reply
self.interface.internal_monologue(response_message.content)
messages.append(response_message) # extend conversation with assistant's reply
heartbeat_request = None
function_failed = None
return messages, heartbeat_request, function_failed
def step(self, user_message, first_message=False, first_message_retry_limit=FIRST_MESSAGE_ATTEMPTS, skip_verify=False):
"""Top-level event message handler for the MemGPT agent"""
try:
# Step 0: add user message
if user_message is not None:
self.interface.user_message(user_message)
packed_user_message = {"role": "user", "content": user_message}
input_message_sequence = self.messages + [packed_user_message]
else:
input_message_sequence = self.messages
if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
# Step 1: send the conversation and available functions to GPT
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = get_ai_reply(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
counter += 1
if counter > first_message_retry_limit:
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
else:
response = get_ai_reply(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
# Step 2: check if LLM wanted to call a function
# (if yes) Step 3: call the function
# (if yes) Step 4: send the info on the function call and function response to LLM
response_message = response.choices[0].message
response_message_copy = response_message.copy()
all_response_messages, heartbeat_request, function_failed = self.handle_ai_response(response_message)
# Add the extra metadata to the assistant response
# (e.g. enough metadata to enable recreating the API call)
assert "api_response" not in all_response_messages[0]
all_response_messages[0]["api_response"] = response_message_copy
assert "api_args" not in all_response_messages[0]
all_response_messages[0]["api_args"] = {
"model": self.model,
"messages": input_message_sequence,
"functions": self.functions,
}
# Step 4: extend the message history
if user_message is not None:
all_new_messages = [packed_user_message] + all_response_messages
else:
all_new_messages = all_response_messages
# Check the memory pressure and potentially issue a memory pressure warning
current_total_tokens = response["usage"]["total_tokens"]
active_memory_warning = False
if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
# Only deliver the alert if we haven't already (this period)
if not self.agent_alerted_about_memory_pressure:
active_memory_warning = True
self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this
else:
printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_TOKENS}")
self.append_to_messages(all_new_messages)
return all_new_messages, heartbeat_request, function_failed, active_memory_warning
except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
# If we got a context alert, try trimming the messages length, then try again
if "maximum context length" in str(e):
# A separate API call to run a summarizer
self.summarize_messages_inplace()
# Try step again
return self.step(user_message, first_message=first_message)
else:
printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'")
raise e
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
token_counts = [count_tokens(str(msg)) for msg in self.messages]
message_buffer_token_count = sum(token_counts[1:]) # no system message
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
candidate_messages_to_summarize = self.messages[1:]
token_counts = token_counts[1:]
if preserve_last_N_messages:
candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
printd(f"token_counts={token_counts}")
printd(f"message_buffer_token_count={message_buffer_token_count}")
printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}")
# If at this point there's nothing to summarize, throw an error
if len(candidate_messages_to_summarize) == 0:
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
)
# Walk down the message buffer (front-to-back) until we hit the target token count
tokens_so_far = 0
cutoff = 0
for i, msg in enumerate(candidate_messages_to_summarize):
cutoff = i
tokens_so_far += token_counts[i]
if tokens_so_far > desired_token_count_to_summarize:
break
# Account for system message
cutoff += 1
# Try to make an assistant message come after the cutoff
try:
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
if self.messages[cutoff]["role"] == "user":
new_cutoff = cutoff + 1
if self.messages[new_cutoff]["role"] == "user":
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
cutoff = new_cutoff
except IndexError:
pass
message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
summary = summarize_messages(self.model, message_sequence_to_summarize)
printd(f"Got summary: {summary}")
# Metadata that's useful for the agent to see
all_time_message_count = self.messages_total
remaining_message_count = len(self.messages[cutoff:])
hidden_message_count = all_time_message_count - remaining_message_count
summary_message_count = len(message_sequence_to_summarize)
summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
printd(f"Packaged into message: {summary_message}")
prior_len = len(self.messages)
self.trim_messages(cutoff)
packed_summary_message = {"role": "user", "content": summary_message}
self.prepend_to_messages([packed_summary_message])
# reset alert
self.agent_alerted_about_memory_pressure = False
printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
def send_ai_message(self, message):
"""AI wanted to send a message"""
self.interface.assistant_message(message)
return None
def edit_memory(self, name, content):
"""Edit memory.name <= content"""
new_len = self.memory.edit(name, content)
self.rebuild_memory()
return None
def edit_memory_append(self, name, content):
new_len = self.memory.edit_append(name, content)
self.rebuild_memory()
return None
def edit_memory_replace(self, name, old_content, new_content):
new_len = self.memory.edit_replace(name, old_content, new_content)
self.rebuild_memory()
return None
def recall_memory_search(self, query, count=5, page=0):
results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted)}"
return results_str
def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted)}"
return results_str
def archival_memory_insert(self, content):
self.persistence_manager.archival_memory.insert(content)
return None
def archival_memory_search(self, query, count=5, page=0):
results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted)}"
return results_str
def message_chatgpt(self, message):
"""Base call to GPT API w/ functions"""
message_sequence = [
{"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
{"role": "user", "content": str(message)},
]
response = create(
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
messages=message_sequence,
# functions=functions,
# function_call=function_call,
)
reply = response.choices[0].message.content
return reply
def read_from_text_file(self, filename, line_start, num_lines=1, max_chars=500, trunc_message=True):
if not os.path.exists(filename):
raise FileNotFoundError(f"The file '{filename}' does not exist.")
if line_start < 1 or num_lines < 1:
raise ValueError("Both line_start and num_lines must be positive integers.")
lines = []
chars_read = 0
with open(filename, "r") as file:
for current_line_number, line in enumerate(file, start=1):
if line_start <= current_line_number < line_start + num_lines:
chars_to_add = len(line)
if max_chars is not None and chars_read + chars_to_add > max_chars:
# If adding this line exceeds MAX_CHARS, truncate the line if needed and stop reading further.
excess_chars = (chars_read + chars_to_add) - max_chars
lines.append(line[:-excess_chars].rstrip("\n"))
if trunc_message:
lines.append(f"[SYSTEM ALERT - max chars ({max_chars}) reached during file read]")
break
else:
lines.append(line.rstrip("\n"))
chars_read += chars_to_add
if current_line_number >= line_start + num_lines - 1:
break
return "\n".join(lines)
def append_to_text_file(self, filename, content):
if not os.path.exists(filename):
raise FileNotFoundError(f"The file '{filename}' does not exist.")
with open(filename, "a") as file:
file.write(content + "\n")
def http_request(self, method, url, payload_json=None):
"""
Makes an HTTP request based on the specified method, URL, and JSON payload.
Args:
method (str): The HTTP method (e.g., 'GET', 'POST').
url (str): The URL for the request.
payload_json (str): A JSON string representing the request payload.
Returns:
dict: The response from the HTTP request.
"""
try:
headers = {"Content-Type": "application/json"}
# For GET requests, ignore the payload
if method.upper() == "GET":
print(f"[HTTP] launching GET request to {url}")
response = requests.get(url, headers=headers)
else:
# Validate and convert the payload for other types of requests
if payload_json:
payload = json.loads(payload_json)
else:
payload = {}
print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2)}")
response = requests.request(method, url, json=payload, headers=headers)
return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text}
except Exception as e:
return {"error": str(e)}
def pause_heartbeats(self, minutes, max_pause=MAX_PAUSE_HEARTBEATS):
"""Pause timed heartbeats for N minutes"""
minutes = min(max_pause, minutes)
# Record the current time
self.pause_heartbeats_start = datetime.datetime.now()
# And record how long the pause should go for
self.pause_heartbeats_minutes = int(minutes)
return f"Pausing timed heartbeats for {minutes} min"
def heartbeat_is_paused(self):
    """Check if there's a requested pause on timed heartbeats.

    Returns:
        bool: True while fewer than ``pause_heartbeats_minutes`` minutes have
        passed since ``pause_heartbeats_start``; False once that time is up or
        when no pause was ever started.
    """
    pause_started_at = self.pause_heartbeats_start
    if pause_started_at is None:
        # No pause has been initiated
        return False

    # Seconds elapsed since the pause began, compared against the pause length in seconds
    seconds_since_pause = (datetime.datetime.now() - pause_started_at).total_seconds()
    pause_length_seconds = self.pause_heartbeats_minutes * 60
    return seconds_since_pause < pause_length_seconds
class AgentAsync(Agent):
"""Core logic for an async MemGPT agent"""
def __init__(self, **kwargs):
    """Construct the agent through the parent class, then set up its functions.

    Args:
        **kwargs: Forwarded unchanged to the parent class constructor.
    """
    super().__init__(**kwargs)
    # NOTE(review): presumably builds the `available_functions` lookup used when
    # dispatching function calls - confirm against init_avail_functions
    self.init_avail_functions()
async def handle_ai_response(self, response_message):
    """Handles parsing and function execution for one assistant reply.

    The reply's content is always shown as internal monologue. When the reply
    carries a ``function_call``, the named function is looked up in
    ``self.available_functions``, its JSON arguments are parsed, and it is
    awaited through ``call_function``; the outcome is recorded as a
    ``"function"`` role message.

    Args:
        response_message: The assistant message. It is read both as a mapping
            (``.get("function_call")``) and through its ``.content`` attribute.

    Returns:
        tuple: ``(messages, heartbeat_request, function_failed)`` where
            messages (list): Messages to append to the history (the assistant
                reply, plus the function response when a call was requested).
            heartbeat_request: The ``request_heartbeat`` argument if it was a
                bool, otherwise None. Always None on failure and for plain replies.
            function_failed: True if the call failed at any stage, False if it
                succeeded, None when the reply contained no function call.
    """
    messages = []  # append these to the history when done

    # Step 2: check if LLM wanted to call a function
    function_call = response_message.get("function_call")

    # With a function call the content is internal monologue, not chat;
    # a plain reply is surfaced through the same channel
    await self.interface.internal_monologue(response_message.content)
    messages.append(response_message)  # extend conversation with assistant's reply

    if not function_call:
        # Standard non-function reply
        return messages, None, None

    # Step 3: call the function
    # Note: the JSON response may not always be valid; be sure to handle errors

    # Failure case 1: function name is wrong
    function_name = function_call["name"]
    try:
        function_to_call = self.available_functions[function_name]
    except KeyError:
        return await self._fail_function_call(messages, function_name, f"No function named {function_name}")

    # Failure case 2: function name is OK, but function args are bad JSON
    # Bound up front so the error message can still be built if reading the arguments itself fails
    raw_function_args = None
    try:
        raw_function_args = function_call["arguments"]
        function_args = parse_json(raw_function_args)
    except Exception:
        error_msg = f"Error parsing JSON for function '{function_name}' arguments: {raw_function_args}"
        return await self._fail_function_call(messages, function_name, error_msg)

    # (Still parsing function args)
    # Handle requests for immediate heartbeat
    heartbeat_request = function_args.pop("request_heartbeat", None)
    if not (isinstance(heartbeat_request, bool) or heartbeat_request is None):
        printd(
            f"Warning: 'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
        )
        heartbeat_request = None

    # Failure case 3: function failed during execution
    await self.interface.function_message(f"Running {function_name}({function_args})")
    try:
        function_response_string = await call_function(function_to_call, **function_args)
        function_response = package_function_response(True, function_response_string)
    except Exception as e:
        error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
        # The traceback only goes to the debug log; the agent is given the short message
        printd(f"{error_msg}\n{traceback.format_exc()}")
        return await self._fail_function_call(messages, function_name, error_msg)

    # If no failures happened along the way: ...
    # Step 4: send the info on the function call and function response to GPT
    if function_response_string:
        await self.interface.function_message(f"Success: {function_response_string}")
    else:
        await self.interface.function_message("Success")
    messages.append(
        {
            "role": "function",
            "name": function_name,
            "content": function_response,
        }
    )  # extend conversation with function response
    return messages, heartbeat_request, False

async def _fail_function_call(self, messages, function_name, error_msg):
    """Record a failed function call in ``messages`` and build the failure return value."""
    messages.append(
        {
            "role": "function",
            "name": function_name,
            "content": package_function_response(False, error_msg),
        }
    )  # extend conversation with function response
    await self.interface.function_message(f"Error: {error_msg}")
    return messages, None, True  # force a heartbeat to allow agent to handle error
async def step(self, user_message, first_message=False, first_message_retry_limit=FIRST_MESSAGE_ATTEMPTS, skip_verify=False):
    """Top-level event message handler for the MemGPT agent.

    Sends the current message history (plus ``user_message``, when given) to
    the model, runs the reply through ``handle_ai_response``, appends the new
    messages to the history and checks the reported token usage against
    ``MESSAGE_SUMMARY_WARNING_TOKENS``.

    If the attempt fails with an error mentioning "maximum context length",
    the history is summarized in place and the step is retried with the same
    arguments.

    Args:
        user_message: Content of the user message to add, or None to run on
            the existing history only.
        first_message (bool): Force the first-message verification loop.
        first_message_retry_limit (int): An error is raised once more than
            this many replies have failed first-message verification.
        skip_verify (bool): Skip the first-message verification entirely.

    Returns:
        tuple: ``(all_new_messages, heartbeat_request, function_failed,
        active_memory_warning)``. ``active_memory_warning`` is True only the
        first time the token threshold is crossed since the flag was last reset.

    Raises:
        Exception: Any failure that is not a context-length overflow is
            printed and re-raised, including hitting the first-message retry limit.
    """
    try:
        # Step 0: add user message
        if user_message is not None:
            await self.interface.user_message(user_message)
            packed_user_message = {"role": "user", "content": user_message}
            input_message_sequence = self.messages + [packed_user_message]
        else:
            input_message_sequence = self.messages

        if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
            printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
            from pprint import pprint

            pprint(input_message_sequence[-1])

        # Step 1: send the conversation and available functions to GPT
        if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
            printd(f"This is the first message. Running extra verifier on AI response.")
            counter = 0
            while True:
                response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)
                if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
                    break

                counter += 1
                if counter > first_message_retry_limit:
                    raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
        else:
            response = await get_ai_reply_async(model=self.model, message_sequence=input_message_sequence, functions=self.functions)

        # Step 2: check if LLM wanted to call a function
        # (if yes) Step 3: call the function
        # (if yes) Step 4: send the info on the function call and function response to LLM
        response_message = response.choices[0].message
        # Copy taken before handling so the stored metadata reflects the reply as received
        response_message_copy = response_message.copy()
        all_response_messages, heartbeat_request, function_failed = await self.handle_ai_response(response_message)

        # Add the extra metadata to the assistant response
        # (e.g. enough metadata to enable recreating the API call)
        assert "api_response" not in all_response_messages[0], f"api_response already in {all_response_messages[0]}"
        all_response_messages[0]["api_response"] = response_message_copy
        assert "api_args" not in all_response_messages[0], f"api_args already in {all_response_messages[0]}"
        all_response_messages[0]["api_args"] = {
            "model": self.model,
            "messages": input_message_sequence,
            "functions": self.functions,
        }

        # Step 4: extend the message history
        if user_message is not None:
            all_new_messages = [packed_user_message] + all_response_messages
        else:
            all_new_messages = all_response_messages

        # Check the memory pressure and potentially issue a memory pressure warning
        current_total_tokens = response["usage"]["total_tokens"]
        active_memory_warning = False
        if current_total_tokens > MESSAGE_SUMMARY_WARNING_TOKENS:
            printd(f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_TOKENS}")
            # Only deliver the alert if we haven't already (this period)
            if not self.agent_alerted_about_memory_pressure:
                active_memory_warning = True
                self.agent_alerted_about_memory_pressure = True  # it's up to the outer loop to handle this
        else:
            printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_TOKENS}")

        self.append_to_messages(all_new_messages)
        return all_new_messages, heartbeat_request, function_failed, active_memory_warning

    except Exception as e:
        printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
        print(f"step() failed\nuser_message = {user_message}\nerror = {e}")

        # If we got a context alert, try trimming the messages length, then try again
        if "maximum context length" in str(e):
            # A separate API call to run a summarizer
            await self.summarize_messages_inplace()

            # Try step again, keeping the caller's retry limit and verification choice
            return await self.step(
                user_message,
                first_message=first_message,
                first_message_retry_limit=first_message_retry_limit,
                skip_verify=skip_verify,
            )
        else:
            printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'")
            print(e)
            raise
async def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
# Replace the oldest part of the message history with a single summary message.
#
# Messages after the system message are collected front-to-back until their token
# count exceeds MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC of the buffer; that prefix is
# summarized with a_summarize_messages, removed via self.trim_messages(cutoff), and
# the packaged summary is handed to self.prepend_to_messages as a "user" message.
#
# Args:
#   cutoff: NOTE(review): never read - it is overwritten by `cutoff = 0` below, so a
#       caller-supplied value has no effect. Confirm whether that is intended.
#   preserve_last_N_messages: when true, the last MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST
#       messages are excluded from the candidates for summarization.
#
# Raises:
#   LLMError: if there are no candidate messages, or the chosen cutoff leaves
#       nothing to summarize.
#   AssertionError: if the first message is not the system message.
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
# Token counts are taken over the string form of each message
token_counts = [count_tokens(str(msg)) for msg in self.messages]
message_buffer_token_count = sum(token_counts[1:]) # no system message
token_counts = token_counts[1:]
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
candidate_messages_to_summarize = self.messages[1:]
if preserve_last_N_messages:
# Drop the tail from both lists so token_counts stays index-aligned with the candidates
candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
printd(f"token_counts={token_counts}")
printd(f"message_buffer_token_count={message_buffer_token_count}")
printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}")
# If at this point there's nothing to summarize, throw an error
if len(candidate_messages_to_summarize) == 0:
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
)
# Walk down the message buffer (front-to-back) until we hit the target token count
tokens_so_far = 0
cutoff = 0
for i, msg in enumerate(candidate_messages_to_summarize):
cutoff = i
tokens_so_far += token_counts[i]
if tokens_so_far > desired_token_count_to_summarize:
break
# Account for system message
# (cutoff was an index into the candidates, which start at self.messages[1])
cutoff += 1
# Try to make an assistant message come after the cutoff
try:
# NOTE(review): this line is logged before the role check below, i.e. whether or not
# the message at the cutoff is actually a 'user' message
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
if self.messages[cutoff]["role"] == "user":
new_cutoff = cutoff + 1
if self.messages[new_cutoff]["role"] == "user":
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
cutoff = new_cutoff
except IndexError:
# Indexing past the end of the history is tolerated; the cutoff stays as it was at that point
pass
message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) == 0:
printd(f"message_sequence_to_summarize is len 0, skipping summarize")
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, cutoff={cutoff}]"
)
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
summary = await a_summarize_messages(self.model, message_sequence_to_summarize)
printd(f"Got summary: {summary}")
# Metadata that's useful for the agent to see
all_time_message_count = self.messages_total
remaining_message_count = len(self.messages[cutoff:])
hidden_message_count = all_time_message_count - remaining_message_count
summary_message_count = len(message_sequence_to_summarize)
summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
printd(f"Packaged into message: {summary_message}")
# Captured before trimming, only for the before/after debug line at the end
prior_len = len(self.messages)
self.trim_messages(cutoff)
packed_summary_message = {"role": "user", "content": summary_message}
self.prepend_to_messages([packed_summary_message])
# reset alert
# (so a later memory-pressure warning can be issued again)
self.agent_alerted_about_memory_pressure = False
printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
async def free_step(self, user_message, limit=None):
    """Allow agent to manage its own control flow (past a single LLM call).

    Not currently used, instead this is handled in the CLI main.py logic

    Runs ``step`` once on ``user_message`` and then keeps stepping with a
    heartbeat message for as long as the previous step either failed a
    function call or requested a heartbeat.

    Args:
        user_message: Message passed to the first ``step`` call.
        limit (int): Maximum total number of steps; None means no limit.

    Returns:
        tuple: ``(new_messages, heartbeat_request, function_failed)`` from the
        last step that ran.
    """
    # step() is a coroutine and returns more values than are used here (the trailing
    # memory-pressure flag), so it must be awaited and the extras discarded
    new_messages, heartbeat_request, function_failed, *_ = await self.step(user_message)
    step_count = 1

    while limit is None or step_count < limit:
        # A failed function call takes priority over an explicit heartbeat request
        if function_failed:
            user_message = get_heartbeat("Function call failed")
        elif heartbeat_request:
            user_message = get_heartbeat("AI requested")
        else:
            break

        new_messages, heartbeat_request, function_failed, *_ = await self.step(user_message)
        step_count += 1

    return new_messages, heartbeat_request, function_failed
### Functions / tools the agent can use
# All functions should return a response string (or None)
# If the function fails, throw an exception
async def send_ai_message(self, message):
    """Deliver a message from the AI through the interface's assistant channel.

    Args:
        message: The message to hand to ``self.interface.assistant_message``.

    Returns:
        None
    """
    deliver = self.interface.assistant_message
    await deliver(message)
async def recall_memory_search(self, query, count=5, page=0):
    """Text-search recall memory and describe one page of hits as a string.

    Args:
        query: Search text passed to the recall memory's ``a_text_search``.
        count (int): Page size.
        page (int): Zero-based page number; the search starts at ``page * count``.

    Returns:
        str: "No results found." for an empty page, otherwise a header line
        followed by a JSON list with one formatted string per hit.
    """
    offset = page * count
    hits, total = await self.persistence_manager.recall_memory.a_text_search(query, count=count, start=offset)
    last_page = math.ceil(total / count) - 1  # 0 index

    if len(hits) == 0:
        return f"No results found."

    header = f"Showing {len(hits)} of {total} results (page {page}/{last_page}):"
    formatted_hits = []
    for hit in hits:
        formatted_hits.append(f"timestamp: {hit['timestamp']}, {hit['message']['role']} - {hit['message']['content']}")
    return f"{header} {json.dumps(formatted_hits)}"
async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
    """Search recall memory by date range and describe one page of hits as a string.

    Args:
        start_date: Start of the range, passed through to ``a_date_search``.
        end_date: End of the range, passed through to ``a_date_search``.
        count (int): Page size.
        page (int): Zero-based page number; the search starts at ``page * count``.

    Returns:
        str: "No results found." for an empty page, otherwise a header line
        followed by a JSON list with one formatted string per hit.
    """
    offset = page * count
    hits, total = await self.persistence_manager.recall_memory.a_date_search(start_date, end_date, count=count, start=offset)
    last_page = math.ceil(total / count) - 1  # 0 index

    if len(hits) == 0:
        return f"No results found."

    header = f"Showing {len(hits)} of {total} results (page {page}/{last_page}):"
    formatted_hits = []
    for hit in hits:
        formatted_hits.append(f"timestamp: {hit['timestamp']}, {hit['message']['role']} - {hit['message']['content']}")
    return f"{header} {json.dumps(formatted_hits)}"
async def archival_memory_insert(self, content):
    """Insert ``content`` into archival memory.

    Args:
        content: The value handed to the archival memory's ``a_insert``.

    Returns:
        None
    """
    archive = self.persistence_manager.archival_memory
    await archive.a_insert(content)
async def archival_memory_search(self, query, count=5, page=0):
    """Search archival memory and describe one page of hits as a string.

    Args:
        query: Search query passed to the archival memory's ``a_search``.
        count (int): Page size.
        page (int): Zero-based page number; the search starts at ``page * count``.

    Returns:
        str: "No results found." for an empty page, otherwise a header line
        followed by a JSON list with one formatted string per hit.
    """
    offset = page * count
    hits, total = await self.persistence_manager.archival_memory.a_search(query, count=count, start=offset)
    last_page = math.ceil(total / count) - 1  # 0 index

    if len(hits) == 0:
        return f"No results found."

    header = f"Showing {len(hits)} of {total} results (page {page}/{last_page}):"
    formatted_hits = []
    for hit in hits:
        formatted_hits.append(f"timestamp: {hit['timestamp']}, memory: {hit['content']}")
    return f"{header} {json.dumps(formatted_hits)}"
async def message_chatgpt(self, message):
    """Send one message to a separate chat model and return the reply text.

    The request is a fresh two-message conversation: the fixed
    ``MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE`` system prompt followed by
    ``message`` (converted with ``str``) as the user turn. No function
    definitions are passed along.

    Args:
        message: The content to send as the user turn.

    Returns:
        The ``content`` of the first choice's message in the response.
    """
    system_turn = {"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE}
    user_turn = {"role": "user", "content": str(message)}

    completion = await acreate(
        model=MESSAGE_CHATGPT_FUNCTION_MODEL,
        messages=[system_turn, user_turn],
    )

    first_choice = completion.choices[0]
    return first_choice.message.content