diff --git a/letta/agent.py b/letta/agent.py index 5e3fd2df..d80fff04 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1208,6 +1208,30 @@ class Agent(BaseAgent): new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system) self._messages = new_messages + def update_memory_blocks_from_db(self): + for block in self.memory.to_dict()["memory"].values(): + if block.get("templates", False): + # we don't expect to update shared memory blocks that + # are templates. this is something we could update in the + # future if we expect templates to change often. + continue + block_id = block.get("id") + + # TODO: This is really hacky and we should probably figure out how to + db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user) + if db_block is None: + # this case covers if someone has deleted a shared block by interacting + # with some other agent. + # in that case we should remove this shared block from the agent currently being + # evaluated. + printd(f"removing block: {block_id=}") + continue + if not isinstance(db_block.value, str): + printd(f"skipping block update, unexpected value: {block_id=}") + continue + # TODO: we may want to update which columns we're updating from shared memory e.g. the limit + self.memory.update_block_value(label=block.get("label", ""), value=db_block.value) + def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None): """Rebuilds the system message with the latest memory object and any shared memory block updates""" curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt @@ -1219,28 +1243,7 @@ class Agent(BaseAgent): return if ms: - for block in self.memory.to_dict()["memory"].values(): - if block.get("templates", False): - # we don't expect to update shared memory blocks that - # are templates. this is something we could update in the - # future if we expect templates to change often. - continue - block_id = block.get("id") - - # TODO: This is really hacky and we should probably figure out how to - db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user) - if db_block is None: - # this case covers if someone has deleted a shared block by interacting - # with some other agent. - # in that case we should remove this shared block from the agent currently being - # evaluated. - printd(f"removing block: {block_id=}") - continue - if not isinstance(db_block.value, str): - printd(f"skipping block update, unexpected value: {block_id=}") - continue - # TODO: we may want to update which columns we're updating from shared memory e.g. the limit - self.memory.update_block_value(label=block.get("label", ""), value=db_block.value) + self.update_memory_blocks_from_db() # If the memory didn't update, we probably don't want to update the timestamp inside # For example, if we're doing a system prompt swap, this should probably be False diff --git a/letta/server/server.py b/letta/server/server.py index 183b33f2..2296c1e3 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1385,8 +1385,9 @@ class SyncServer(Server): # Get the agent object (loaded in memory) letta_agent = self._get_or_load_agent(agent_id=agent_id) assert isinstance(letta_agent.memory, Memory) - agent_state = letta_agent.agent_state.model_copy(deep=True) + letta_agent.update_memory_blocks_from_db() + agent_state = letta_agent.agent_state.model_copy(deep=True) # Load the tags in for the agent_state agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user) return agent_state diff --git a/tests/test_local_client.py b/tests/test_local_client.py index d8518a35..f38045f2 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -432,3 +432,40 @@ def test_tool_creation_langchain_missing_imports(client: LocalClient): # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"} with pytest.raises(RuntimeError): ToolCreate.from_langchain(langchain_tool) + + +def test_shared_blocks_without_send_message(client: LocalClient): + from letta import BasicBlockMemory + from letta.client.client import Block, create_client + from letta.schemas.agent import AgentType + from letta.schemas.embedding_config import EmbeddingConfig + from letta.schemas.llm_config import LLMConfig + + client = create_client() + shared_memory_block = Block(name="shared_memory", label="shared_memory", value="[empty]", limit=2000) + memory = BasicBlockMemory(blocks=[shared_memory_block]) + + agent_1 = client.create_agent( + agent_type=AgentType.memgpt_agent, + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), + memory=memory, + ) + + agent_2 = client.create_agent( + agent_type=AgentType.memgpt_agent, + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), + memory=memory, + ) + + agent_1.memory.update_block_value(label="shared_memory", value="I am no longer an [empty] memory") + + block_id = agent_1.memory.get_block("shared_memory").id + client.update_block(block_id, text="I am no longer an [empty] memory") + client.update_agent(agent_id=agent_1.id, memory=agent_1.memory) + agent_1 = client.get_agent(agent_1.id) + agent_2 = client.get_agent(agent_2.id) + client.update_agent(agent_id=agent_2.id, memory=agent_2.memory) + assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" + assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"