feat: refactor agent memory representation and modify routes for editing blocks (#2094)

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Sarah Wooders
2024-11-27 16:04:13 -08:00
committed by GitHub
parent 122faa78ea
commit 07bb536018
44 changed files with 1326 additions and 1219 deletions

View File

@@ -31,7 +31,7 @@ from letta.metadata import MetadataStore
from letta.orm import User
from letta.persistence_manager import LocalStateManager
from letta.schemas.agent import AgentState, AgentStepResponse
from letta.schemas.block import Block
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory
@@ -235,11 +235,8 @@ class Agent(BaseAgent):
def __init__(
self,
interface: Optional[Union[AgentInterface, StreamingRefreshCLIInterface]],
# agents can be created from providing agent_state
agent_state: AgentState,
tools: List[Tool],
agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables)
user: User,
# memory: Memory,
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
@@ -253,7 +250,7 @@ class Agent(BaseAgent):
self.user = user
# link tools
self.link_tools(tools)
self.link_tools(agent_state.tools)
# initialize a tool rules solver
if agent_state.tool_rules:
@@ -265,26 +262,14 @@ class Agent(BaseAgent):
# add default rule for having send_message be a terminal tool
if agent_state.tool_rules is None:
agent_state.tool_rules = []
# Define the rule to add
send_message_terminal_rule = TerminalToolRule(tool_name="send_message")
# Check if an equivalent rule is already present
if not any(
isinstance(rule, TerminalToolRule) and rule.tool_name == send_message_terminal_rule.tool_name for rule in agent_state.tool_rules
):
agent_state.tool_rules.append(send_message_terminal_rule)
self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)
# gpt-4, gpt-3.5-turbo, ...
self.model = self.agent_state.llm_config.model
# Store the system instructions (used to rebuild memory)
self.system = self.agent_state.system
# Initialize the memory object
self.memory = self.agent_state.memory
assert isinstance(self.memory, Memory), f"Memory object is not of type Memory: {type(self.memory)}"
printd("Initialized memory object", self.memory.compile())
# state managers
self.block_manager = BlockManager()
# Interface must implement:
# - internal_monologue
@@ -322,8 +307,8 @@ class Agent(BaseAgent):
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
model=self.model,
system=self.system,
memory=self.memory,
system=self.agent_state.system,
memory=self.agent_state.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
@@ -345,8 +330,8 @@ class Agent(BaseAgent):
# Basic "more human than human" initial message sequence
init_messages = initialize_message_sequence(
model=self.model,
system=self.system,
memory=self.memory,
system=self.agent_state.system,
memory=self.agent_state.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
@@ -380,6 +365,76 @@ class Agent(BaseAgent):
# Create the agent in the DB
self.update_state()
def update_memory_if_change(self, new_memory: Memory) -> bool:
    """
    Persist changed memory blocks and rebuild the system prompt if `new_memory` differs.

    The compiled form of `new_memory` is compared with the compiled form of
    `self.agent_state.memory`. When they differ, each block of the current memory
    whose value changed is written through `self.block_manager`, the agent's memory
    is re-read from the block manager by block id, and the system prompt is rebuilt.

    Args:
        new_memory (Memory): the new memory object to compare to the current memory object

    Returns:
        modified (bool): whether the memory was updated
    """
    current_memory = self.agent_state.memory
    if current_memory.compile() == new_memory.compile():
        # nothing changed: no writes, no prompt rebuild
        return False

    # update the blocks (LRW) in the DB - only the ones whose value actually changed
    for label in current_memory.list_block_labels():
        updated_value = new_memory.get_block(label).value
        current_block = current_memory.get_block(label)
        if updated_value != current_block.value:
            self.block_manager.update_block(
                block_id=current_block.id, block_update=BlockUpdate(value=updated_value), actor=self.user
            )

    # refresh memory from DB (using block ids)
    self.agent_state.memory = Memory(
        blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in current_memory.get_blocks()]
    )

    # rebuild the system prompt so it reflects the refreshed memory
    # NOTE(review): an earlier note here said the rebuild is handled at the start of the step,
    # yet the call below still runs - confirm whether it is still needed
    # TODO: pass in update timestamp from block edit time
    self.rebuild_system_prompt()
    return True
def execute_tool_and_persist_state(self, function_name, function_to_call, function_args):
    """
    Execute tool modifications and persist the state of the agent.

    Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data

    Tools listed in `BASE_TOOLS` are called directly with this agent injected as the
    `self` argument (this mutates `function_args`). Every other tool is run through
    `ToolExecutionSandbox` on a deep copy of the agent state; any memory the sandbox
    hands back is synced via `update_memory_if_change`.

    Args:
        function_name: name of the tool to run
        function_to_call: callable used for base tools (not used on the sandbox path)
        function_args: keyword arguments for the tool

    Returns:
        The tool's return value.

    Raises:
        ValueError: if anything raised during execution; the message is truncated to
            MAX_ERROR_MESSAGE_CHAR_LIMIT characters and the original exception is chained.
    """
    # TODO: add agent manager here
    orig_memory_str = self.agent_state.memory.compile()

    # TODO: need to have an AgentState object that actually has full access to the block data
    # this is because the sandbox tools need to be able to access block.value to edit this data
    try:
        if function_name in BASE_TOOLS:
            # base tools are allowed to access the `Agent` object and run on the database
            function_args["self"] = self  # need to attach self to arg since it's dynamically linked
            function_response = function_to_call(**function_args)
        else:
            # execute tool in a sandbox
            # TODO: allow agent_state to specify which sandbox to execute tools in
            sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
                agent_state=self.agent_state.__deepcopy__()
            )
            function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
            assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
            # a sandbox result without an agent state has nothing to sync, and must not
            # turn an otherwise successful tool call into an execution error
            if updated_agent_state is not None:
                self.update_memory_if_change(updated_agent_state.memory)
    except Exception as e:
        # Need to catch error here, or else truncation won't happen
        # TODO: modify to function execution error
        from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

        error_msg = f"Error executing tool {function_name}: {e}"
        if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
            error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
        # chain the cause so the original traceback is preserved for debugging
        raise ValueError(error_msg) from e

    return function_response
@property
def messages(self) -> List[dict]:
"""Getter method that converts the internal Message list into OpenAI-style dicts"""
@@ -392,16 +447,6 @@ class Agent(BaseAgent):
def link_tools(self, tools: List[Tool]):
"""Bind a tool object (schema + python function) to the agent object"""
# tools
for tool in tools:
assert tool, f"Tool is None - must be error in querying tool from DB"
assert tool.name in self.agent_state.tools, f"Tool {tool} not found in agent_state.tools"
for tool_name in self.agent_state.tools:
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
# Update tools
self.tools = tools
# Store the functions schemas (this is passed as an argument to ChatCompletion)
self.functions = []
self.functions_python = {}
@@ -416,9 +461,8 @@ class Agent(BaseAgent):
exec(tool.source_code, env)
self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]]
self.functions.append(tool.json_schema)
except Exception as e:
except Exception:
warnings.warn(f"WARNING: tool {tool.name} failed to link")
print(e)
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
@@ -727,27 +771,10 @@ class Agent(BaseAgent):
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])
# TODO: This needs to be rethought, how do we allow functions that modify agent state/db?
# TODO: There should probably be two types of tools: stateless/stateful
if function_name in BASE_TOOLS:
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
else:
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
agent_state=self.agent_state
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
# update agent state
if self.agent_state != updated_agent_state and updated_agent_state is not None:
self.agent_state = updated_agent_state
self.memory = self.agent_state.memory # TODO: don't duplicate
# rebuild memory
self.rebuild_memory()
# handle tool execution (sandbox) and state updates
function_response = self.execute_tool_and_persist_state(function_name, function_to_call, function_args)
# handle truncation
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
# with certain functions we rely on the paging mechanism to handle overflow
truncate = False
@@ -820,7 +847,7 @@ class Agent(BaseAgent):
# rebuild memory
# TODO: @charles please check this
self.rebuild_memory()
self.rebuild_system_prompt()
# Update ToolRulesSolver state with last called function
self.tool_rules_solver.update_tool_usage(function_name)
@@ -936,17 +963,10 @@ class Agent(BaseAgent):
# Step 0: update core memory
# only pulling latest block data if shared memory is being used
# TODO: ensure we're passing in metadata store from all surfaces
if ms is not None:
should_update = False
for block in self.agent_state.memory.to_dict()["memory"].values():
if not block.get("template", False):
should_update = True
if should_update:
# TODO: the force=True can be optimized away
# once we ensure we're correctly comparing whether in-memory core
# data is different than persisted core data.
self.rebuild_memory(force=True, ms=ms)
current_persisted_memory = Memory(
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()]
) # read blocks from DB
self.update_memory_if_change(current_persisted_memory)
# Step 1: add user message
if isinstance(messages, Message):
@@ -1229,43 +1249,10 @@ class Agent(BaseAgent):
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages
def update_memory_blocks_from_db(self):
    """
    Refresh the values of this agent's in-memory blocks from the block manager.

    Each block in `self.memory` is looked up by id through `BlockManager`; when the
    stored block exists and its value is a string, that value is copied onto the
    in-memory block with the same label. Missing blocks and non-string values are
    only logged via `printd` - nothing is removed or modified for them here.
    """
    for block_dict in self.memory.to_dict()["memory"].values():
        # we don't expect to update shared memory blocks that
        # are templates. this is something we could update in the
        # future if we expect templates to change often.
        # NOTE(review): the key checked here is "templates", while the step logic elsewhere
        # in this file checks "template" - confirm which key the block dict really carries
        if not block_dict.get("templates", False):
            block_id = block_dict.get("id")

            # TODO: This is really hacky (the original note is cut off at this point)
            persisted = BlockManager().get_block_by_id(block_id=block_id, actor=self.user)

            if persisted is None:
                # this case covers if someone has deleted a shared block by interacting
                # with some other agent; the intent is for this agent to drop the block too,
                # but only a log line is emitted here.
                printd(f"removing block: {block_id=}")
            elif not isinstance(persisted.value, str):
                printd(f"skipping block update, unexpected value: {block_id=}")
            else:
                # TODO: we may want to update which columns we're updating from shared memory e.g. the limit
                self.memory.update_block_value(label=block_dict.get("label", ""), value=persisted.value)
def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
def rebuild_system_prompt(self, force=False, update_timestamp=True):
"""Rebuilds the system message with the latest memory object and any shared memory block updates"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
# NOTE: This is a hacky way to check if the memory has changed
memory_repr = self.memory.compile()
if not force and memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
printd(f"Memory has not changed, not rebuilding system")
return
if ms:
self.update_memory_blocks_from_db()
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
if update_timestamp:
@@ -1276,8 +1263,8 @@ class Agent(BaseAgent):
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
system_prompt=self.system,
in_context_memory=self.memory,
system_prompt=self.agent_state.system,
in_context_memory=self.agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
archival_memory=self.persistence_manager.archival_memory,
recall_memory=self.persistence_manager.recall_memory,
@@ -1304,13 +1291,13 @@ class Agent(BaseAgent):
"""Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)"""
assert isinstance(new_system_prompt, str)
if new_system_prompt == self.system:
if new_system_prompt == self.agent_state.system:
return
self.system = new_system_prompt
self.agent_state.system = new_system_prompt
# updating the system prompt requires rebuilding the memory block inside the compiled system message
self.rebuild_memory(force=True, update_timestamp=False)
self.rebuild_system_prompt(force=True, update_timestamp=False)
# make sure to persist the change
_ = self.update_state()
@@ -1324,6 +1311,7 @@ class Agent(BaseAgent):
raise NotImplementedError
def update_state(self) -> AgentState:
# TODO: this should be removed and self._messages should be moved into self.agent_state.in_context_messages
message_ids = [msg.id for msg in self._messages]
# Assert that these are all strings
@@ -1331,12 +1319,8 @@ class Agent(BaseAgent):
warnings.warn(f"Non-string message IDs found in agent state: {message_ids}")
message_ids = [m_id for m_id in message_ids if isinstance(m_id, str)]
assert isinstance(self.memory, Memory), f"Memory is not a Memory object: {type(self.memory)}"
# override any fields that may have been updated
self.agent_state.message_ids = message_ids
self.agent_state.memory = self.memory
self.agent_state.system = self.system
return self.agent_state
@@ -1537,7 +1521,7 @@ class Agent(BaseAgent):
system_prompt = self.agent_state.system # TODO is this the current system or the initial system?
num_tokens_system = count_tokens(system_prompt)
core_memory = self.memory.compile()
core_memory = self.agent_state.memory.compile()
num_tokens_core_memory = count_tokens(core_memory)
# conversion of messages to OpenAI dict format, which is passed to the token counter
@@ -1629,37 +1613,15 @@ def save_agent(agent: Agent, ms: MetadataStore):
agent.update_state()
agent_state = agent.agent_state
agent_id = agent_state.id
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
# NOTE: we're saving agent memory before persisting the agent to ensure
# that allocated block_ids for each memory block are present in the agent model
save_agent_memory(agent=agent)
if ms.get_agent(agent_id=agent.agent_state.id):
ms.update_agent(agent_state)
# TODO: move this to agent manager
# convert to persisted model
persisted_agent_state = agent.agent_state.to_persisted_agent_state()
if ms.get_agent(agent_id=persisted_agent_state.id):
ms.update_agent(persisted_agent_state)
else:
ms.create_agent(agent_state)
agent.agent_state = ms.get_agent(agent_id=agent_id)
assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
def save_agent_memory(agent: Agent):
    """
    Persist every block of the agent's memory through the block manager.

    Memory is a collection of blocks; each one is rebuilt as a `Block` from its dict
    form and passed to `BlockManager().create_or_update_block` with the agent's user
    as the actor. A block whose value is None is stored with an empty string instead.

    NOTE: we are assuming agent.update_state has already been called.
    """
    serialized_blocks = agent.memory.to_dict()["memory"]
    for raw_block in serialized_blocks.values():
        # TODO: block creation should happen in one place to enforce these sort of constraints consistently.
        memory_block = Block(**raw_block)

        # FIXME: should we expect for block values to be None? If not, we need to figure out why that is
        # the case in some tests, if so we should relax the DB constraint.
        if memory_block.value is None:
            memory_block.value = ""

        BlockManager().create_or_update_block(memory_block, actor=agent.user)
ms.create_agent(persisted_agent_state)
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]: