fix: shared memory without requiring send message (#2068)

This commit is contained in:
Kevin Lin
2024-11-21 13:58:03 -08:00
committed by GitHub
parent 6f1964c575
commit 8395c86f78
3 changed files with 64 additions and 23 deletions

View File

@@ -1208,6 +1208,30 @@ class Agent(BaseAgent):
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages
def update_memory_blocks_from_db(self):
for block in self.memory.to_dict()["memory"].values():
if block.get("templates", False):
# we don't expect to update shared memory blocks that
# are templates. this is something we could update in the
# future if we expect templates to change often.
continue
block_id = block.get("id")
# TODO: This is really hacky and we should probably figure out how to
db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user)
if db_block is None:
# this case covers if someone has deleted a shared block by interacting
# with some other agent.
# in that case we should remove this shared block from the agent currently being
# evaluated.
printd(f"removing block: {block_id=}")
continue
if not isinstance(db_block.value, str):
printd(f"skipping block update, unexpected value: {block_id=}")
continue
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
self.memory.update_block_value(label=block.get("label", ""), value=db_block.value)
def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
"""Rebuilds the system message with the latest memory object and any shared memory block updates"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
@@ -1219,28 +1243,7 @@ class Agent(BaseAgent):
return
if ms:
for block in self.memory.to_dict()["memory"].values():
if block.get("templates", False):
# we don't expect to update shared memory blocks that
# are templates. this is something we could update in the
# future if we expect templates to change often.
continue
block_id = block.get("id")
# TODO: This is really hacky and we should probably figure out how to
db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user)
if db_block is None:
# this case covers if someone has deleted a shared block by interacting
# with some other agent.
# in that case we should remove this shared block from the agent currently being
# evaluated.
printd(f"removing block: {block_id=}")
continue
if not isinstance(db_block.value, str):
printd(f"skipping block update, unexpected value: {block_id=}")
continue
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
self.memory.update_block_value(label=block.get("label", ""), value=db_block.value)
self.update_memory_blocks_from_db()
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False

View File

@@ -1385,8 +1385,9 @@ class SyncServer(Server):
# Get the agent object (loaded in memory)
letta_agent = self._get_or_load_agent(agent_id=agent_id)
assert isinstance(letta_agent.memory, Memory)
agent_state = letta_agent.agent_state.model_copy(deep=True)
letta_agent.update_memory_blocks_from_db()
agent_state = letta_agent.agent_state.model_copy(deep=True)
# Load the tags in for the agent_state
agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user)
return agent_state

View File

@@ -432,3 +432,40 @@ def test_tool_creation_langchain_missing_imports(client: LocalClient):
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
with pytest.raises(RuntimeError):
ToolCreate.from_langchain(langchain_tool)
def test_shared_blocks_without_send_message(client: LocalClient):
from letta import BasicBlockMemory
from letta.client.client import Block, create_client
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
client = create_client()
shared_memory_block = Block(name="shared_memory", label="shared_memory", value="[empty]", limit=2000)
memory = BasicBlockMemory(blocks=[shared_memory_block])
agent_1 = client.create_agent(
agent_type=AgentType.memgpt_agent,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
memory=memory,
)
agent_2 = client.create_agent(
agent_type=AgentType.memgpt_agent,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
memory=memory,
)
agent_1.memory.update_block_value(label="shared_memory", value="I am no longer an [empty] memory")
block_id = agent_1.memory.get_block("shared_memory").id
client.update_block(block_id, text="I am no longer an [empty] memory")
client.update_agent(agent_id=agent_1.id, memory=agent_1.memory)
agent_1 = client.get_agent(agent_1.id)
agent_2 = client.get_agent(agent_2.id)
client.update_agent(agent_id=agent_2.id, memory=agent_2.memory)
assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"
assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"