feat: refactor agent memory representation and modify routes for editing blocks (#2094)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
242
letta/agent.py
242
letta/agent.py
@@ -31,7 +31,7 @@ from letta.metadata import MetadataStore
|
||||
from letta.orm import User
|
||||
from letta.persistence_manager import LocalStateManager
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
@@ -235,11 +235,8 @@ class Agent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
interface: Optional[Union[AgentInterface, StreamingRefreshCLIInterface]],
|
||||
# agents can be created from providing agent_state
|
||||
agent_state: AgentState,
|
||||
tools: List[Tool],
|
||||
agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables)
|
||||
user: User,
|
||||
# memory: Memory,
|
||||
# extras
|
||||
messages_total: Optional[int] = None, # TODO remove?
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
@@ -253,7 +250,7 @@ class Agent(BaseAgent):
|
||||
self.user = user
|
||||
|
||||
# link tools
|
||||
self.link_tools(tools)
|
||||
self.link_tools(agent_state.tools)
|
||||
|
||||
# initialize a tool rules solver
|
||||
if agent_state.tool_rules:
|
||||
@@ -265,26 +262,14 @@ class Agent(BaseAgent):
|
||||
# add default rule for having send_message be a terminal tool
|
||||
if agent_state.tool_rules is None:
|
||||
agent_state.tool_rules = []
|
||||
# Define the rule to add
|
||||
send_message_terminal_rule = TerminalToolRule(tool_name="send_message")
|
||||
# Check if an equivalent rule is already present
|
||||
if not any(
|
||||
isinstance(rule, TerminalToolRule) and rule.tool_name == send_message_terminal_rule.tool_name for rule in agent_state.tool_rules
|
||||
):
|
||||
agent_state.tool_rules.append(send_message_terminal_rule)
|
||||
|
||||
self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)
|
||||
|
||||
# gpt-4, gpt-3.5-turbo, ...
|
||||
self.model = self.agent_state.llm_config.model
|
||||
|
||||
# Store the system instructions (used to rebuild memory)
|
||||
self.system = self.agent_state.system
|
||||
|
||||
# Initialize the memory object
|
||||
self.memory = self.agent_state.memory
|
||||
assert isinstance(self.memory, Memory), f"Memory object is not of type Memory: {type(self.memory)}"
|
||||
printd("Initialized memory object", self.memory.compile())
|
||||
# state managers
|
||||
self.block_manager = BlockManager()
|
||||
|
||||
# Interface must implement:
|
||||
# - internal_monologue
|
||||
@@ -322,8 +307,8 @@ class Agent(BaseAgent):
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
model=self.model,
|
||||
system=self.system,
|
||||
memory=self.memory,
|
||||
system=self.agent_state.system,
|
||||
memory=self.agent_state.memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
@@ -345,8 +330,8 @@ class Agent(BaseAgent):
|
||||
# Basic "more human than human" initial message sequence
|
||||
init_messages = initialize_message_sequence(
|
||||
model=self.model,
|
||||
system=self.system,
|
||||
memory=self.memory,
|
||||
system=self.agent_state.system,
|
||||
memory=self.agent_state.memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
@@ -380,6 +365,76 @@ class Agent(BaseAgent):
|
||||
# Create the agent in the DB
|
||||
self.update_state()
|
||||
|
||||
def update_memory_if_change(self, new_memory: Memory) -> bool:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
|
||||
Args:
|
||||
new_memory (Memory): the new memory object to compare to the current memory object
|
||||
|
||||
Returns:
|
||||
modified (bool): whether the memory was updated
|
||||
"""
|
||||
if self.agent_state.memory.compile() != new_memory.compile():
|
||||
# update the blocks (LRW) in the DB
|
||||
for label in self.agent_state.memory.list_block_labels():
|
||||
updated_value = new_memory.get_block(label).value
|
||||
if updated_value != self.agent_state.memory.get_block(label).value:
|
||||
# update the block if it's changed
|
||||
block_id = self.agent_state.memory.get_block(label).id
|
||||
block = self.block_manager.update_block(
|
||||
block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user
|
||||
)
|
||||
|
||||
# refresh memory from DB (using block ids)
|
||||
self.agent_state.memory = Memory(
|
||||
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()]
|
||||
)
|
||||
|
||||
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
|
||||
# rebuild memory - this records the last edited timestamp of the memory
|
||||
# TODO: pass in update timestamp from block edit time
|
||||
self.rebuild_system_prompt()
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def execute_tool_and_persist_state(self, function_name, function_to_call, function_args):
|
||||
"""
|
||||
Execute tool modifications and persist the state of the agent.
|
||||
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
|
||||
"""
|
||||
# TODO: add agent manager here
|
||||
orig_memory_str = self.agent_state.memory.compile()
|
||||
|
||||
# TODO: need to have an AgentState object that actually has full access to the block data
|
||||
# this is because the sandbox tools need to be able to access block.value to edit this data
|
||||
try:
|
||||
if function_name in BASE_TOOLS:
|
||||
# base tools are allowed to access the `Agent` object and run on the database
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = function_to_call(**function_args)
|
||||
else:
|
||||
# execute tool in a sandbox
|
||||
# TODO: allow agent_state to specify which sandbox to execute tools in
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
|
||||
agent_state=self.agent_state.__deepcopy__()
|
||||
)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
self.update_memory_if_change(updated_agent_state.memory)
|
||||
except Exception as e:
|
||||
# Need to catch error here, or else trunction wont happen
|
||||
# TODO: modify to function execution error
|
||||
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
|
||||
|
||||
error_msg = f"Error executing tool {function_name}: {e}"
|
||||
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
|
||||
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return function_response
|
||||
|
||||
@property
|
||||
def messages(self) -> List[dict]:
|
||||
"""Getter method that converts the internal Message list into OpenAI-style dicts"""
|
||||
@@ -392,16 +447,6 @@ class Agent(BaseAgent):
|
||||
def link_tools(self, tools: List[Tool]):
|
||||
"""Bind a tool object (schema + python function) to the agent object"""
|
||||
|
||||
# tools
|
||||
for tool in tools:
|
||||
assert tool, f"Tool is None - must be error in querying tool from DB"
|
||||
assert tool.name in self.agent_state.tools, f"Tool {tool} not found in agent_state.tools"
|
||||
for tool_name in self.agent_state.tools:
|
||||
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
|
||||
|
||||
# Update tools
|
||||
self.tools = tools
|
||||
|
||||
# Store the functions schemas (this is passed as an argument to ChatCompletion)
|
||||
self.functions = []
|
||||
self.functions_python = {}
|
||||
@@ -416,9 +461,8 @@ class Agent(BaseAgent):
|
||||
exec(tool.source_code, env)
|
||||
self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]]
|
||||
self.functions.append(tool.json_schema)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
warnings.warn(f"WARNING: tool {tool.name} failed to link")
|
||||
print(e)
|
||||
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
|
||||
|
||||
def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
|
||||
@@ -727,27 +771,10 @@ class Agent(BaseAgent):
|
||||
if isinstance(function_args[name], dict):
|
||||
function_args[name] = spec[name](**function_args[name])
|
||||
|
||||
# TODO: This needs to be rethought, how do we allow functions that modify agent state/db?
|
||||
# TODO: There should probably be two types of tools: stateless/stateful
|
||||
|
||||
if function_name in BASE_TOOLS:
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = function_to_call(**function_args)
|
||||
else:
|
||||
# execute tool in a sandbox
|
||||
# TODO: allow agent_state to specify which sandbox to execute tools in
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
|
||||
agent_state=self.agent_state
|
||||
)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
# update agent state
|
||||
if self.agent_state != updated_agent_state and updated_agent_state is not None:
|
||||
self.agent_state = updated_agent_state
|
||||
self.memory = self.agent_state.memory # TODO: don't duplicate
|
||||
|
||||
# rebuild memory
|
||||
self.rebuild_memory()
|
||||
# handle tool execution (sandbox) and state updates
|
||||
function_response = self.execute_tool_and_persist_state(function_name, function_to_call, function_args)
|
||||
|
||||
# handle trunction
|
||||
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
|
||||
# with certain functions we rely on the paging mechanism to handle overflow
|
||||
truncate = False
|
||||
@@ -820,7 +847,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# rebuild memory
|
||||
# TODO: @charles please check this
|
||||
self.rebuild_memory()
|
||||
self.rebuild_system_prompt()
|
||||
|
||||
# Update ToolRulesSolver state with last called function
|
||||
self.tool_rules_solver.update_tool_usage(function_name)
|
||||
@@ -936,17 +963,10 @@ class Agent(BaseAgent):
|
||||
|
||||
# Step 0: update core memory
|
||||
# only pulling latest block data if shared memory is being used
|
||||
# TODO: ensure we're passing in metadata store from all surfaces
|
||||
if ms is not None:
|
||||
should_update = False
|
||||
for block in self.agent_state.memory.to_dict()["memory"].values():
|
||||
if not block.get("template", False):
|
||||
should_update = True
|
||||
if should_update:
|
||||
# TODO: the force=True can be optimized away
|
||||
# once we ensure we're correctly comparing whether in-memory core
|
||||
# data is different than persisted core data.
|
||||
self.rebuild_memory(force=True, ms=ms)
|
||||
current_persisted_memory = Memory(
|
||||
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()]
|
||||
) # read blocks from DB
|
||||
self.update_memory_if_change(current_persisted_memory)
|
||||
|
||||
# Step 1: add user message
|
||||
if isinstance(messages, Message):
|
||||
@@ -1229,43 +1249,10 @@ class Agent(BaseAgent):
|
||||
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
|
||||
self._messages = new_messages
|
||||
|
||||
def update_memory_blocks_from_db(self):
|
||||
for block in self.memory.to_dict()["memory"].values():
|
||||
if block.get("templates", False):
|
||||
# we don't expect to update shared memory blocks that
|
||||
# are templates. this is something we could update in the
|
||||
# future if we expect templates to change often.
|
||||
continue
|
||||
block_id = block.get("id")
|
||||
|
||||
# TODO: This is really hacky and we should probably figure out how to
|
||||
db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user)
|
||||
if db_block is None:
|
||||
# this case covers if someone has deleted a shared block by interacting
|
||||
# with some other agent.
|
||||
# in that case we should remove this shared block from the agent currently being
|
||||
# evaluated.
|
||||
printd(f"removing block: {block_id=}")
|
||||
continue
|
||||
if not isinstance(db_block.value, str):
|
||||
printd(f"skipping block update, unexpected value: {block_id=}")
|
||||
continue
|
||||
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
|
||||
self.memory.update_block_value(label=block.get("label", ""), value=db_block.value)
|
||||
|
||||
def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
|
||||
def rebuild_system_prompt(self, force=False, update_timestamp=True):
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates"""
|
||||
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
|
||||
|
||||
# NOTE: This is a hacky way to check if the memory has changed
|
||||
memory_repr = self.memory.compile()
|
||||
if not force and memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
|
||||
printd(f"Memory has not changed, not rebuilding system")
|
||||
return
|
||||
|
||||
if ms:
|
||||
self.update_memory_blocks_from_db()
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
if update_timestamp:
|
||||
@@ -1276,8 +1263,8 @@ class Agent(BaseAgent):
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats seperately)
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=self.system,
|
||||
in_context_memory=self.memory,
|
||||
system_prompt=self.agent_state.system,
|
||||
in_context_memory=self.agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
archival_memory=self.persistence_manager.archival_memory,
|
||||
recall_memory=self.persistence_manager.recall_memory,
|
||||
@@ -1304,13 +1291,13 @@ class Agent(BaseAgent):
|
||||
"""Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)"""
|
||||
assert isinstance(new_system_prompt, str)
|
||||
|
||||
if new_system_prompt == self.system:
|
||||
if new_system_prompt == self.agent_state.system:
|
||||
return
|
||||
|
||||
self.system = new_system_prompt
|
||||
self.agent_state.system = new_system_prompt
|
||||
|
||||
# updating the system prompt requires rebuilding the memory block inside the compiled system message
|
||||
self.rebuild_memory(force=True, update_timestamp=False)
|
||||
self.rebuild_system_prompt(force=True, update_timestamp=False)
|
||||
|
||||
# make sure to persist the change
|
||||
_ = self.update_state()
|
||||
@@ -1324,6 +1311,7 @@ class Agent(BaseAgent):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_state(self) -> AgentState:
|
||||
# TODO: this should be removed and self._messages should be moved into self.agent_state.in_context_messages
|
||||
message_ids = [msg.id for msg in self._messages]
|
||||
|
||||
# Assert that these are all strings
|
||||
@@ -1331,12 +1319,8 @@ class Agent(BaseAgent):
|
||||
warnings.warn(f"Non-string message IDs found in agent state: {message_ids}")
|
||||
message_ids = [m_id for m_id in message_ids if isinstance(m_id, str)]
|
||||
|
||||
assert isinstance(self.memory, Memory), f"Memory is not a Memory object: {type(self.memory)}"
|
||||
|
||||
# override any fields that may have been updated
|
||||
self.agent_state.message_ids = message_ids
|
||||
self.agent_state.memory = self.memory
|
||||
self.agent_state.system = self.system
|
||||
|
||||
return self.agent_state
|
||||
|
||||
@@ -1537,7 +1521,7 @@ class Agent(BaseAgent):
|
||||
|
||||
system_prompt = self.agent_state.system # TODO is this the current system or the initial system?
|
||||
num_tokens_system = count_tokens(system_prompt)
|
||||
core_memory = self.memory.compile()
|
||||
core_memory = self.agent_state.memory.compile()
|
||||
num_tokens_core_memory = count_tokens(core_memory)
|
||||
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
@@ -1629,37 +1613,15 @@ def save_agent(agent: Agent, ms: MetadataStore):
|
||||
|
||||
agent.update_state()
|
||||
agent_state = agent.agent_state
|
||||
agent_id = agent_state.id
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
|
||||
|
||||
# NOTE: we're saving agent memory before persisting the agent to ensure
|
||||
# that allocated block_ids for each memory block are present in the agent model
|
||||
save_agent_memory(agent=agent)
|
||||
|
||||
if ms.get_agent(agent_id=agent.agent_state.id):
|
||||
ms.update_agent(agent_state)
|
||||
# TODO: move this to agent manager
|
||||
# convert to persisted model
|
||||
persisted_agent_state = agent.agent_state.to_persisted_agent_state()
|
||||
if ms.get_agent(agent_id=persisted_agent_state.id):
|
||||
ms.update_agent(persisted_agent_state)
|
||||
else:
|
||||
ms.create_agent(agent_state)
|
||||
|
||||
agent.agent_state = ms.get_agent(agent_id=agent_id)
|
||||
assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
|
||||
|
||||
|
||||
def save_agent_memory(agent: Agent):
|
||||
"""
|
||||
Save agent memory to metadata store. Memory is a collection of blocks and each block is persisted to the block table.
|
||||
|
||||
NOTE: we are assuming agent.update_state has already been called.
|
||||
"""
|
||||
|
||||
for block_dict in agent.memory.to_dict()["memory"].values():
|
||||
# TODO: block creation should happen in one place to enforce these sort of constraints consistently.
|
||||
block = Block(**block_dict)
|
||||
# FIXME: should we expect for block values to be None? If not, we need to figure out why that is
|
||||
# the case in some tests, if so we should relax the DB constraint.
|
||||
if block.value is None:
|
||||
block.value = ""
|
||||
BlockManager().create_or_update_block(block, actor=agent.user)
|
||||
ms.create_agent(persisted_agent_state)
|
||||
|
||||
|
||||
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
|
||||
|
||||
Reference in New Issue
Block a user