fix: shared memory without requiring send message (#2068)
This commit is contained in:
@@ -1208,6 +1208,30 @@ class Agent(BaseAgent):
|
||||
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
|
||||
self._messages = new_messages
|
||||
|
||||
def update_memory_blocks_from_db(self):
|
||||
for block in self.memory.to_dict()["memory"].values():
|
||||
if block.get("templates", False):
|
||||
# we don't expect to update shared memory blocks that
|
||||
# are templates. this is something we could update in the
|
||||
# future if we expect templates to change often.
|
||||
continue
|
||||
block_id = block.get("id")
|
||||
|
||||
# TODO: This is really hacky and we should probably figure out how to
|
||||
db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user)
|
||||
if db_block is None:
|
||||
# this case covers if someone has deleted a shared block by interacting
|
||||
# with some other agent.
|
||||
# in that case we should remove this shared block from the agent currently being
|
||||
# evaluated.
|
||||
printd(f"removing block: {block_id=}")
|
||||
continue
|
||||
if not isinstance(db_block.value, str):
|
||||
printd(f"skipping block update, unexpected value: {block_id=}")
|
||||
continue
|
||||
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
|
||||
self.memory.update_block_value(label=block.get("label", ""), value=db_block.value)
|
||||
|
||||
def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates"""
|
||||
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
|
||||
@@ -1219,28 +1243,7 @@ class Agent(BaseAgent):
|
||||
return
|
||||
|
||||
if ms:
|
||||
for block in self.memory.to_dict()["memory"].values():
|
||||
if block.get("templates", False):
|
||||
# we don't expect to update shared memory blocks that
|
||||
# are templates. this is something we could update in the
|
||||
# future if we expect templates to change often.
|
||||
continue
|
||||
block_id = block.get("id")
|
||||
|
||||
# TODO: This is really hacky and we should probably figure out how to
|
||||
db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user)
|
||||
if db_block is None:
|
||||
# this case covers if someone has deleted a shared block by interacting
|
||||
# with some other agent.
|
||||
# in that case we should remove this shared block from the agent currently being
|
||||
# evaluated.
|
||||
printd(f"removing block: {block_id=}")
|
||||
continue
|
||||
if not isinstance(db_block.value, str):
|
||||
printd(f"skipping block update, unexpected value: {block_id=}")
|
||||
continue
|
||||
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
|
||||
self.memory.update_block_value(label=block.get("label", ""), value=db_block.value)
|
||||
self.update_memory_blocks_from_db()
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
|
||||
@@ -1385,8 +1385,9 @@ class SyncServer(Server):
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
assert isinstance(letta_agent.memory, Memory)
|
||||
agent_state = letta_agent.agent_state.model_copy(deep=True)
|
||||
|
||||
letta_agent.update_memory_blocks_from_db()
|
||||
agent_state = letta_agent.agent_state.model_copy(deep=True)
|
||||
# Load the tags in for the agent_state
|
||||
agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user)
|
||||
return agent_state
|
||||
|
||||
@@ -432,3 +432,40 @@ def test_tool_creation_langchain_missing_imports(client: LocalClient):
|
||||
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
with pytest.raises(RuntimeError):
|
||||
ToolCreate.from_langchain(langchain_tool)
|
||||
|
||||
|
||||
def test_shared_blocks_without_send_message(client: LocalClient):
|
||||
from letta import BasicBlockMemory
|
||||
from letta.client.client import Block, create_client
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
client = create_client()
|
||||
shared_memory_block = Block(name="shared_memory", label="shared_memory", value="[empty]", limit=2000)
|
||||
memory = BasicBlockMemory(blocks=[shared_memory_block])
|
||||
|
||||
agent_1 = client.create_agent(
|
||||
agent_type=AgentType.memgpt_agent,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
agent_2 = client.create_agent(
|
||||
agent_type=AgentType.memgpt_agent,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
agent_1.memory.update_block_value(label="shared_memory", value="I am no longer an [empty] memory")
|
||||
|
||||
block_id = agent_1.memory.get_block("shared_memory").id
|
||||
client.update_block(block_id, text="I am no longer an [empty] memory")
|
||||
client.update_agent(agent_id=agent_1.id, memory=agent_1.memory)
|
||||
agent_1 = client.get_agent(agent_1.id)
|
||||
agent_2 = client.get_agent(agent_2.id)
|
||||
client.update_agent(agent_id=agent_2.id, memory=agent_2.memory)
|
||||
assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"
|
||||
assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"
|
||||
|
||||
Reference in New Issue
Block a user