import datetime import uuid import inspect import json from pathlib import Path import traceback from typing import List, Tuple, Optional, cast, Union from tqdm import tqdm from memgpt.metadata import MetadataStore from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.data_types import AgentState, Message, LLMConfig, EmbeddingConfig, Passage, Preset from memgpt.models import chat_completion_response from memgpt.interface import AgentInterface from memgpt.persistence_manager import LocalStateManager from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from memgpt.memory import CoreMemory as InContextMemory, summarize_messages, ArchivalMemory, RecallMemory from memgpt.llm_api_tools import create, is_context_overflow_error from memgpt.utils import ( create_random_username, get_tool_call_id, get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response, verify_first_message_correctness, create_uuid_from_string, ) from memgpt.constants import ( FIRST_MESSAGE_ATTEMPTS, JSON_LOADS_STRICT, MESSAGE_SUMMARY_WARNING_FRAC, MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, LLM_MAX_TOKENS, CLI_WARNING_PREFIX, JSON_ENSURE_ASCII, ) from .errors import LLMError from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets def link_functions(function_schemas: list): """Link function definitions to list of function schemas""" # need to dynamically link the functions # the saved agent.functions will just have the schemas, but we need to # go through the functions library and pull the respective python functions # Available functions is a mapping from: # function_name -> { # json_schema: schema # python_function: function # } # agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create) # [{'name': ..., 'description': ...}, {...}] available_functions = load_all_function_sets() linked_function_set = {} for f_schema in function_schemas: # Attempt to find the function in the existing function library f_name = f_schema.get("name") if f_name is None: raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}") linked_function = available_functions.get(f_name) if linked_function is None: raise ValueError( f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}" ) # Once we find a matching function, make sure the schema is identical if json.dumps(f_schema, ensure_ascii=JSON_ENSURE_ASCII) != json.dumps( linked_function["json_schema"], ensure_ascii=JSON_ENSURE_ASCII ): # error_message = ( # f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different." # + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}" # + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2, ensure_ascii=JSON_ENSURE_ASCII)}" # ) schema_diff = get_schema_diff(f_schema, linked_function["json_schema"]) error_message = ( f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n" + "".join(schema_diff) ) # NOTE to handle old configs, instead of erroring here let's just warn # raise ValueError(error_message) printd(error_message) linked_function_set[f_name] = linked_function return linked_function_set def initialize_memory(ai_notes: Union[str, None], human_notes: Union[str, None]): if ai_notes is None: raise ValueError(ai_notes) if human_notes is None: raise ValueError(human_notes) memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT) memory.edit_persona(ai_notes) memory.edit_human(human_notes) return memory def construct_system_with_memory( system: str, memory: InContextMemory, memory_edit_timestamp: str, archival_memory: ArchivalMemory = None, recall_memory: RecallMemory = None, include_char_count: bool = True, ): full_system_message = "\n".join( [ system, "\n", f"### Memory [last modified: {memory_edit_timestamp.strip()}]", f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)", "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", f'' if include_char_count else "", memory.persona, "", f'' if include_char_count else "", memory.human, "", ] ) return full_system_message def initialize_message_sequence( model: str, system: str, memory: InContextMemory, archival_memory: ArchivalMemory = None, recall_memory: RecallMemory = None, memory_edit_timestamp: str = None, include_initial_boot_message: bool = True, ): if memory_edit_timestamp is None: memory_edit_timestamp = get_local_time() full_system_message = construct_system_with_memory( system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory ) first_user_message = get_login_event() # event letting MemGPT know the user just logged in if include_initial_boot_message: if model is not None and "gpt-3.5" in model: initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35") else: initial_boot_messages = get_initial_boot_messages("startup_with_send_message") messages = ( [ {"role": "system", "content": full_system_message}, ] + initial_boot_messages + [ {"role": "user", "content": first_user_message}, ] ) else: messages = [ {"role": "system", "content": full_system_message}, {"role": "user", "content": first_user_message}, ] return messages class Agent(object): def __init__( self, interface: AgentInterface, # agents can be created from providing agent_state agent_state: Optional[AgentState] = None, # or from providing a preset (requires preset + extra fields) preset: Optional[Preset] = None, created_by: Optional[uuid.UUID] = None, name: Optional[str] = None, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, # extras messages_total: Optional[int] = None, # TODO remove? first_message_verify_mono: bool = True, # TODO move to config? ): # An agent can be created from a Preset object if preset is not None: assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)" assert created_by is not None, "Must provide created_by field when creating an Agent from a Preset" assert llm_config is not None, "Must provide llm_config field when creating an Agent from a Preset" assert embedding_config is not None, "Must provide embedding_config field when creating an Agent from a Preset" # if agent_state is also provided, override any preset values init_agent_state = AgentState( name=name if name else create_random_username(), user_id=created_by, persona=preset.persona, human=preset.human, llm_config=llm_config, embedding_config=embedding_config, preset=preset.name, # TODO link via preset.id instead of name? state={ "persona": preset.persona, "human": preset.human, "system": preset.system, "functions": preset.functions_schema, "messages": None, }, ) # An agent can also be created directly from AgentState elif agent_state is not None: assert preset is None, "Can create an agent from a Preset or AgentState (but both were provided)" assert agent_state.state is not None and agent_state.state != {}, "AgentState.state cannot be empty" # Assume the agent_state passed in is formatted correctly init_agent_state = agent_state else: raise ValueError("Both Preset and AgentState were null (must provide one or the other)") # Hold a copy of the state that was used to init the agent self.agent_state = init_agent_state # gpt-4, gpt-3.5-turbo, ... self.model = self.agent_state.llm_config.model # Store the system instructions (used to rebuild memory) if "system" not in self.agent_state.state: raise ValueError(f"'system' not found in provided AgentState") self.system = self.agent_state.state["system"] if "functions" not in self.agent_state.state: raise ValueError(f"'functions' not found in provided AgentState") # Store the functions schemas (this is passed as an argument to ChatCompletion) self.functions = self.agent_state.state["functions"] # these are the schema # Link the actual python functions corresponding to the schemas self.functions_python = {k: v["python_function"] for k, v in link_functions(function_schemas=self.functions).items()} assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python # Initialize the memory object if "persona" not in self.agent_state.state: raise ValueError(f"'persona' not found in provided AgentState") if "human" not in self.agent_state.state: raise ValueError(f"'human' not found in provided AgentState") self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"]) # Interface must implement: # - internal_monologue # - assistant_message # - function_message # ... # Different interfaces can handle events differently # e.g., print in CLI vs send a discord message with a discord bot self.interface = interface # Create the persistence manager object based on the AgentState info # TODO self.persistence_manager = LocalStateManager(agent_state=self.agent_state) # State needed for heartbeat pausing self.pause_heartbeats_start = None self.pause_heartbeats_minutes = 0 self.first_message_verify_mono = first_message_verify_mono # Controls if the convo memory pressure warning is triggered # When an alert is sent in the message queue, set this to True (to avoid repeat alerts) # When the summarizer is run, set this back to False (to reset) self.agent_alerted_about_memory_pressure = False self._messages: List[Message] = [] # Once the memory object is initialized, use it to "bake" the system message if "messages" in self.agent_state.state and self.agent_state.state["messages"] is not None: # print(f"Agent.__init__ :: loading, state={agent_state.state['messages']}") if not isinstance(self.agent_state.state["messages"], list): raise ValueError(f"'messages' in AgentState was bad type: {type(self.agent_state.state['messages'])}") assert all([isinstance(msg, str) for msg in self.agent_state.state["messages"]]) # Convert to IDs, and pull from the database raw_messages = [ self.persistence_manager.recall_memory.storage.get(id=uuid.UUID(msg_id)) for msg_id in self.agent_state.state["messages"] ] assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"]) self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None]) else: # print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}") init_messages = initialize_message_sequence( self.model, self.system, self.memory, ) init_messages_objs = [] for msg in init_messages: init_messages_objs.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg ) ) assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages) self.messages_total = 0 self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None]) # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) # self.messages_total_init = self.messages_total self.messages_total_init = len(self._messages) - 1 printd(f"Agent initialized, self.messages_total={self.messages_total}") # Create the agent in the DB # self.save() self.update_state() @property def messages(self) -> List[dict]: """Getter method that converts the internal Message list into OpenAI-style dicts""" return [msg.to_openai_dict() for msg in self._messages] @messages.setter def messages(self, value): raise Exception("Modifying message list directly not allowed") def _trim_messages(self, num): """Trim messages from the front, not including the system message""" self.persistence_manager.trim_messages(num) new_messages = [self._messages[0]] + self._messages[num:] self._messages = new_messages def _prepend_to_messages(self, added_messages: List[Message]): """Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager""" assert all([isinstance(msg, Message) for msg in added_messages]) self.persistence_manager.prepend_to_messages(added_messages) new_messages = [self._messages[0]] + added_messages + self._messages[1:] # prepend (no system) self._messages = new_messages self.messages_total += len(added_messages) # still should increment the message counter (summaries are additions too) def _append_to_messages(self, added_messages: List[Message]): """Wrapper around self.messages.append to allow additional calls to a state/persistence manager""" assert all([isinstance(msg, Message) for msg in added_messages]) self.persistence_manager.append_to_messages(added_messages) # strip extra metadata if it exists # for msg in added_messages: # msg.pop("api_response", None) # msg.pop("api_args", None) new_messages = self._messages + added_messages # append self._messages = new_messages self.messages_total += len(added_messages) def append_to_messages(self, added_messages: List[dict]): """An external-facing message append, where dict-like messages are first converted to Message objects""" added_messages_objs = [ Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg, ) for msg in added_messages ] self._append_to_messages(added_messages_objs) def _swap_system_message(self, new_system_message: Message): assert isinstance(new_system_message, Message) assert new_system_message.role == "system", new_system_message assert self._messages[0].role == "system", self._messages self.persistence_manager.swap_system_message(new_system_message) new_messages = [new_system_message] + self._messages[1:] # swap index 0 (system) self._messages = new_messages def _get_ai_reply( self, message_sequence: List[dict], function_call: str = "auto", first_message: bool = False, # hint ) -> chat_completion_response.ChatCompletionResponse: """Get response from LLM API""" try: response = create( agent_state=self.agent_state, messages=message_sequence, functions=self.functions, functions_python=self.functions_python, function_call=function_call, # hint first_message=first_message, ) # special case for 'length' if response.choices[0].finish_reason == "length": raise Exception("Finish reason was length (maximum context length)") # catches for soft errors if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]: raise Exception(f"API call finish with bad finish reason: {response}") # unpack with response.choices[0].message.content return response except Exception as e: raise e def _handle_ai_response( self, response_message: chat_completion_response.Message, override_tool_call_id: bool = True ) -> Tuple[List[Message], bool, bool]: """Handles parsing and function execution""" messages = [] # append these to the history when done # Step 2: check if LLM wanted to call a function if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0): if response_message.function_call: raise DeprecationWarning(response_message) if response_message.tool_calls is not None and len(response_message.tool_calls) > 1: # raise NotImplementedError(f">1 tool call not supported") # TODO eventually support sequential tool calling printd(f">1 tool call not supported, using index=0 only\n{response_message.tool_calls}") response_message.tool_calls = [response_message.tool_calls[0]] assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0 # The content if then internal monologue, not chat self.interface.internal_monologue(response_message.content) # generate UUID for tool call if override_tool_call_id or response_message.function_call: tool_call_id = get_tool_call_id() # needs to be a string for JSON response_message.tool_calls[0].id = tool_call_id else: tool_call_id = response_message.tool_calls[0].id assert tool_call_id is not None # should be defined # only necessary to add the tool_cal_id to a function call (antipattern) # response_message_dict = response_message.model_dump() # response_message_dict["tool_call_id"] = tool_call_id # role: assistant (requesting tool call, set tool call ID) messages.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=response_message.model_dump(), ) ) # extend conversation with assistant's reply printd(f"Function call message: {messages[-1]}") # Step 3: call the function # Note: the JSON response may not always be valid; be sure to handle errors # Failure case 1: function name is wrong function_call = ( response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function ) function_name = function_call.name printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}") try: function_to_call = self.functions_python[function_name] except KeyError as e: error_msg = f"No function named {function_name}" function_response = package_function_response(False, error_msg) messages.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict={ "role": "tool", "name": function_name, "content": function_response, "tool_call_id": tool_call_id, }, ) ) # extend conversation with function response self.interface.function_message(f"Error: {error_msg}") return messages, False, True # force a heartbeat to allow agent to handle error # Failure case 2: function name is OK, but function args are bad JSON try: raw_function_args = function_call.arguments function_args = parse_json(raw_function_args) except Exception as e: error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}" function_response = package_function_response(False, error_msg) messages.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict={ "role": "tool", "name": function_name, "content": function_response, "tool_call_id": tool_call_id, }, ) ) # extend conversation with function response self.interface.function_message(f"Error: {error_msg}") return messages, False, True # force a heartbeat to allow agent to handle error # (Still parsing function args) # Handle requests for immediate heartbeat heartbeat_request = function_args.pop("request_heartbeat", None) if not (isinstance(heartbeat_request, bool) or heartbeat_request is None): printd( f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}" ) heartbeat_request = False # Failure case 3: function failed during execution self.interface.function_message(f"Running {function_name}({function_args})") try: spec = inspect.getfullargspec(function_to_call).annotations for name, arg in function_args.items(): if isinstance(function_args[name], dict): function_args[name] = spec[name](**function_args[name]) function_args["self"] = self # need to attach self to arg since it's dynamically linked function_response = function_to_call(**function_args) if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: # with certain functions we rely on the paging mechanism to handle overflow truncate = False else: # but by default, we add a truncation safeguard to prevent bad functions from # overflow the agent context window truncate = True function_response_string = validate_function_response(function_response, truncate=truncate) function_args.pop("self", None) function_response = package_function_response(True, function_response_string) function_failed = False except Exception as e: function_args.pop("self", None) # error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}" # Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise) error_msg = f"Error calling function {function_name}: {str(e)}" error_msg_user = f"{error_msg}\n{traceback.format_exc()}" printd(error_msg_user) function_response = package_function_response(False, error_msg) messages.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict={ "role": "tool", "name": function_name, "content": function_response, "tool_call_id": tool_call_id, }, ) ) # extend conversation with function response self.interface.function_message(f"Error: {error_msg}") return messages, False, True # force a heartbeat to allow agent to handle error # If no failures happened along the way: ... # Step 4: send the info on the function call and function response to GPT self.interface.function_message(f"Success: {function_response_string}") messages.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict={ "role": "tool", "name": function_name, "content": function_response, "tool_call_id": tool_call_id, }, ) ) # extend conversation with function response else: # Standard non-function reply self.interface.internal_monologue(response_message.content) messages.append( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=response_message.model_dump(), ) ) # extend conversation with assistant's reply heartbeat_request = False function_failed = False return messages, heartbeat_request, function_failed def step( self, user_message: Union[Message, str], # NOTE: should be json.dump(dict) first_message: bool = False, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, ) -> Tuple[List[dict], bool, bool, bool]: """Top-level event message handler for the MemGPT agent""" try: # Step 0: add user message if user_message is not None: if isinstance(user_message, Message): user_message_text = user_message.text elif isinstance(user_message, str): user_message_text = user_message else: raise ValueError(f"Bad type for user_message: {type(user_message)}") self.interface.user_message(user_message_text) packed_user_message = {"role": "user", "content": user_message_text} # Special handling for AutoGen messages with 'name' field try: user_message_json = json.loads(user_message_text, strict=JSON_LOADS_STRICT) # Special handling for AutoGen messages with 'name' field # Treat 'name' as a special field # If it exists in the input message, elevate it to the 'message' level if "name" in user_message_json: packed_user_message["name"] = user_message_json["name"] user_message_json.pop("name", None) packed_user_message["content"] = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII) except Exception as e: print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}") # Create the associated Message object (in the database) packed_user_message_obj = Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=packed_user_message, ) input_message_sequence = self.messages + [packed_user_message] # Alternatively, the requestor can send an empty user message else: input_message_sequence = self.messages packed_user_message = None if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user": printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue") # Step 1: send the conversation and available functions to GPT if not skip_verify and (first_message or self.messages_total == self.messages_total_init): printd(f"This is the first message. Running extra verifier on AI response.") counter = 0 while True: response = self._get_ai_reply( message_sequence=input_message_sequence, first_message=True, # passed through to the prompt formatter ) if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break counter += 1 if counter > first_message_retry_limit: raise Exception(f"Hit first message retry limit ({first_message_retry_limit})") else: response = self._get_ai_reply( message_sequence=input_message_sequence, ) # Step 2: check if LLM wanted to call a function # (if yes) Step 3: call the function # (if yes) Step 4: send the info on the function call and function response to LLM response_message = response.choices[0].message response_message_copy = response_message.copy() all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message) # Add the extra metadata to the assistant response # (e.g. enough metadata to enable recreating the API call) # assert "api_response" not in all_response_messages[0] # all_response_messages[0]["api_response"] = response_message_copy # assert "api_args" not in all_response_messages[0] # all_response_messages[0]["api_args"] = { # "model": self.model, # "messages": input_message_sequence, # "functions": self.functions, # } # Step 4: extend the message history if user_message is not None: if isinstance(user_message, Message): all_new_messages = [user_message] + all_response_messages else: all_new_messages = [ Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=packed_user_message, ) ] + all_response_messages else: all_new_messages = all_response_messages # Check the memory pressure and potentially issue a memory pressure warning current_total_tokens = response.usage.total_tokens active_memory_warning = False # We can't do summarize logic properly if context_window is undefined if self.agent_state.llm_config.context_window is None: # Fallback if for some reason context_window is missing, just set to the default print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") print(f"{self.agent_state}") self.agent_state.llm_config.context_window = ( LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"] ) if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window): printd( f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" ) # Only deliver the alert if we haven't already (this period) if not self.agent_alerted_about_memory_pressure: active_memory_warning = True self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this else: printd( f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" ) self._append_to_messages(all_new_messages) all_new_messages_dicts = [msg.to_openai_dict() for msg in all_new_messages] return all_new_messages_dicts, heartbeat_request, function_failed, active_memory_warning, response.usage.completion_tokens except Exception as e: printd(f"step() failed\nuser_message = {user_message}\nerror = {e}") # If we got a context alert, try trimming the messages length, then try again if is_context_overflow_error(e): # A separate API call to run a summarizer self.summarize_messages_inplace() # Try step again return self.step(user_message, first_message=first_message) else: printd(f"step() failed with an unrecognized exception: '{str(e)}'") raise e def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True): assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})" # Start at index 1 (past the system message), # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%) # Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling token_counts = [count_tokens(str(msg)) for msg in self.messages] message_buffer_token_count = sum(token_counts[1:]) # no system message desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC) candidate_messages_to_summarize = self.messages[1:] token_counts = token_counts[1:] if preserve_last_N_messages: candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST] token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST] # if disallow_tool_as_first: # # We have to make sure that a "tool" call is not sitting at the front (after system message), # # otherwise we'll get an error from OpenAI (if using the OpenAI API) # while len(candidate_messages_to_summarize) > 0: # if candidate_messages_to_summarize[0]["role"] in ["tool", "function"]: # candidate_messages_to_summarize.pop(0) # else: # break printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}") printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}") printd(f"token_counts={token_counts}") printd(f"message_buffer_token_count={message_buffer_token_count}") printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}") printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}") # If at this point there's nothing to summarize, throw an error if len(candidate_messages_to_summarize) == 0: raise LLMError( f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]" ) # Walk down the message buffer (front-to-back) until we hit the target token count tokens_so_far = 0 cutoff = 0 for i, msg in enumerate(candidate_messages_to_summarize): cutoff = i tokens_so_far += token_counts[i] if tokens_so_far > desired_token_count_to_summarize: break # Account for system message cutoff += 1 # Try to make an assistant message come after the cutoff try: printd(f"Selected cutoff {cutoff} was a 'user', shifting one...") if self.messages[cutoff]["role"] == "user": new_cutoff = cutoff + 1 if self.messages[new_cutoff]["role"] == "user": printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...") cutoff = new_cutoff except IndexError: pass # Make sure the cutoff isn't on a 'tool' or 'function' if disallow_tool_as_first: while self.messages[cutoff]["role"] in ["tool", "function"] and cutoff < len(self.messages): printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...") cutoff += 1 message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message if len(message_sequence_to_summarize) <= 1: # This prevents a potential infinite loop of summarizing the same message over and over raise LLMError( f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]" ) else: printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}") # We can't do summarize logic properly if context_window is undefined if self.agent_state.llm_config.context_window is None: # Fallback if for some reason context_window is missing, just set to the default print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") print(f"{self.agent_state}") self.agent_state.llm_config.context_window = ( LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"] ) summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize) printd(f"Got summary: {summary}") # Metadata that's useful for the agent to see all_time_message_count = self.messages_total remaining_message_count = len(self.messages[cutoff:]) hidden_message_count = all_time_message_count - remaining_message_count summary_message_count = len(message_sequence_to_summarize) summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count) printd(f"Packaged into message: {summary_message}") prior_len = len(self.messages) self._trim_messages(cutoff) packed_summary_message = {"role": "user", "content": summary_message} self._prepend_to_messages( [ Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=packed_summary_message, ) ] ) # reset alert self.agent_alerted_about_memory_pressure = False printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}") def heartbeat_is_paused(self): """Check if there's a requested pause on timed heartbeats""" # Check if the pause has been initiated if self.pause_heartbeats_start is None: return False # Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60 def rebuild_memory(self): """Rebuilds the system message with the latest memory object""" curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt new_system_message = initialize_message_sequence( self.model, self.system, self.memory, archival_memory=self.persistence_manager.archival_memory, recall_memory=self.persistence_manager.recall_memory, )[0] diff = united_diff(curr_system_message["content"], new_system_message["content"]) printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") # Swap the system message out self._swap_system_message( Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message ) ) # def to_agent_state(self) -> AgentState: # # The state may have change since the last time we wrote it # updated_state = { # "persona": self.memory.persona, # "human": self.memory.human, # "system": self.system, # "functions": self.functions, # "messages": [str(msg.id) for msg in self._messages], # } # agent_state = AgentState( # name=self.agent_state.name, # user_id=self.agent_state.user_id, # persona=self.agent_state.persona, # human=self.agent_state.human, # llm_config=self.agent_state.llm_config, # embedding_config=self.agent_state.embedding_config, # preset=self.agent_state.preset, # id=self.agent_state.id, # created_at=self.agent_state.created_at, # state=updated_state, # ) # return agent_state def add_function(self, function_name: str) -> str: if function_name in self.functions_python.keys(): msg = f"Function {function_name} already loaded" printd(msg) return msg available_functions = load_all_function_sets() if function_name not in available_functions.keys(): raise ValueError(f"Function {function_name} not found in function library") self.functions.append(available_functions[function_name]["json_schema"]) self.functions_python[function_name] = available_functions[function_name]["python_function"] msg = f"Added function {function_name}" # self.save() self.update_state() printd(msg) return msg def remove_function(self, function_name: str) -> str: if function_name not in self.functions_python.keys(): msg = f"Function {function_name} not loaded, ignoring" printd(msg) return msg # only allow removal of user defined functions user_func_path = Path(USER_FUNCTIONS_DIR) func_path = Path(inspect.getfile(self.functions_python[function_name])) is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts if not is_subpath: raise ValueError(f"Function {function_name} is not user defined and cannot be removed") self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name] self.functions_python.pop(function_name) msg = f"Removed function {function_name}" # self.save() self.update_state() printd(msg) return msg # def save(self): # """Save agent state locally""" # new_agent_state = self.to_agent_state() # # without this, even after Agent.__init__, agent.config.state["messages"] will be None # self.agent_state = new_agent_state # # Check if we need to create the agent # if not self.ms.get_agent(agent_id=new_agent_state.id, user_id=new_agent_state.user_id, agent_name=new_agent_state.name): # # print(f"Agent.save {new_agent_state.id} :: agent does not exist, creating...") # self.ms.create_agent(agent=new_agent_state) # # Otherwise, we should update the agent # else: # # print(f"Agent.save {new_agent_state.id} :: agent already exists, updating...") # print(f"Agent.save {new_agent_state.id} :: preupdate:\n\tmessages={new_agent_state.state['messages']}") # self.ms.update_agent(agent=new_agent_state) def update_state(self) -> AgentState: updated_state = { "persona": self.memory.persona, "human": self.memory.human, "system": self.system, "functions": self.functions, "messages": [str(msg.id) for msg in self._messages], } self.agent_state = AgentState( name=self.agent_state.name, user_id=self.agent_state.user_id, persona=self.agent_state.persona, human=self.agent_state.human, llm_config=self.agent_state.llm_config, embedding_config=self.agent_state.embedding_config, preset=self.agent_state.preset, id=self.agent_state.id, created_at=self.agent_state.created_at, state=updated_state, ) return self.agent_state def migrate_embedding(self, embedding_config: EmbeddingConfig): """Migrate the agent to a new embedding""" # TODO: archival memory # TODO: recall memory raise NotImplementedError() def attach_source(self, source_name, source_connector: StorageConnector, ms: MetadataStore): """Attach data with name `source_name` to the agent from source_connector.""" # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory filters = {"user_id": self.agent_state.user_id, "data_source": source_name} size = source_connector.size(filters) # typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN) page_size = 100 generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage] all_passages = [] for i in tqdm(range(0, size, page_size)): passages = next(generator) # need to associated passage with agent (for filtering) for passage in passages: assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}" passage.agent_id = self.agent_state.id # regenerate passage ID (avoid duplicates) passage.id = create_uuid_from_string(f"{source_name}_{str(passage.agent_id)}_{passage.text}") # insert into agent archival memory self.persistence_manager.archival_memory.storage.insert_many(passages) all_passages += passages assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}" # save destination storage self.persistence_manager.archival_memory.storage.save() # attach to agent source = ms.get_source(source_name=source_name, user_id=self.agent_state.user_id) assert source is not None, f"source does not exist for source_name={source_name}, user_id={self.agent_state.user_id}" source_id = source.id ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id) total_agent_passages = self.persistence_manager.archival_memory.storage.size() printd( f"Attached data source {source_name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.", ) def save_agent(agent: Agent, ms: MetadataStore): """Save agent to metadata store""" agent.update_state() agent_state = agent.agent_state if ms.get_agent(agent_id=agent_state.id): ms.update_agent(agent_state) else: ms.create_agent(agent_state)