From e51c21070e6347159f393c69b5f08a96409ecf5a Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 5 Dec 2024 11:52:51 -0800 Subject: [PATCH] fix: patch system message creation spam (#2169) Co-authored-by: Sarah Wooders --- .github/workflows/check_for_new_prints.yml | 6 +- letta/agent.py | 16 ++++- letta/server/server.py | 8 ++- tests/test_server.py | 70 +++++++++++++++++++++- 4 files changed, 94 insertions(+), 6 deletions(-) diff --git a/.github/workflows/check_for_new_prints.yml b/.github/workflows/check_for_new_prints.yml index 2ed7ff64..75ef2e27 100644 --- a/.github/workflows/check_for_new_prints.yml +++ b/.github/workflows/check_for_new_prints.yml @@ -1,4 +1,4 @@ -name: Check for new print statements +name: Check for Print Statements on: pull_request: paths: @@ -23,8 +23,8 @@ jobs: # Get the files changed in this PR git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} > changed_files.txt - # Filter for only Python files - grep "\.py$" changed_files.txt > python_files.txt || true + # Filter for only Python files, excluding tests directory + grep "\.py$" changed_files.txt | grep -v "^tests/" > python_files.txt || true # Initialize error flag ERROR=0 diff --git a/letta/agent.py b/letta/agent.py index 7c4ff97e..6aea2829 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1253,9 +1253,23 @@ class Agent(BaseAgent): self._messages = new_messages def rebuild_system_prompt(self, force=False, update_timestamp=True): - """Rebuilds the system message with the latest memory object and any shared memory block updates""" + """Rebuilds the system message with the latest memory object and any shared memory block updates + + Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object + + Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages + """ + curr_system_message = self.messages[0] # this is the system + memory bank, not just 
the system prompt + # note: we only update the system prompt if the core memory is changed + # this means that the archival/recall memory statistics may be somewhat out of date + curr_memory_str = self.agent_state.memory.compile() + if curr_memory_str in curr_system_message["content"] and not force: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + printd(f"Memory hasn't changed, skipping system prompt rebuild") + return + # If the memory didn't update, we probably don't want to update the timestamp inside # For example, if we're doing a system prompt swap, this should probably be False if update_timestamp: diff --git a/letta/server/server.py b/letta/server/server.py index 60a7a9bc..d5c5f8d4 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -395,6 +395,10 @@ class SyncServer(Server): agent_lock = self.per_agent_lock_manager.get_lock(agent_id) with agent_lock: agent_state = self.get_agent(agent_id=agent_id) + if agent_state is None: + raise ValueError(f"Agent (agent_id={agent_id}) does not exist") + elif agent_state.user_id is None: + raise ValueError(f"Agent (agent_id={agent_id}) does not have a user_id") actor = self.user_manager.get_user_by_id(user_id=agent_state.user_id) interface = interface or self.default_interface_factory() @@ -882,7 +886,7 @@ class SyncServer(Server): in_memory_agent_state = self.get_agent(agent_state.id) return in_memory_agent_state - def get_agent(self, agent_id: str) -> AgentState: + def get_agent(self, agent_id: str) -> Optional[AgentState]: """ Retrieve the full agent state from the DB. This gathers data accross multiple tables to provide the full state of an agent, which is passed into the `Agent` object for creation. 
@@ -893,6 +897,8 @@ class SyncServer(Server): if agent_state is None: # agent does not exist return None + if agent_state.user_id is None: + raise ValueError(f"Agent {agent_id} does not have a user_id") user = self.user_manager.get_user_by_id(user_id=agent_state.user_id) # construct the in-memory, full agent state - this gather data stored in different tables but that needs to be passed to `Agent` diff --git a/tests/test_server.py b/tests/test_server.py index 2811efed..44c08c5c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,11 +1,13 @@ import json import uuid import warnings +from typing import List, Tuple import pytest import letta.utils as utils -from letta.constants import BASE_TOOLS +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.schemas.block import CreateBlock from letta.schemas.enums import MessageRole from letta.schemas.letta_message import ( FunctionCallMessage, @@ -668,3 +670,69 @@ def test_composio_client_simple(server): # Assert there's some amount of actions assert len(actions) > 0 + + +def test_memory_rebuild_count(server, user_id): + """Test that the memory rebuild is generating the correct number of role=system messages""" + + # create agent + agent_state = server.create_agent( + request=CreateAgent( + name="memory_rebuild_test_agent", + tools=BASE_TOOLS + BASE_MEMORY_TOOLS, + memory_blocks=[ + CreateBlock(label="human", value="The human's name is Bob."), + CreateBlock(label="persona", value="My name is Alice."), + ], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=server.get_user_or_default(user_id), + ) + print(f"Created agent\n{agent_state}") + + def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]: + + # At this stage, there should only be 1 system message inside of recall storage + letta_messages = server.get_agent_recall_cursor( + user_id=user_id, + agent_id=agent_state.id, + limit=1000, + # 
reverse=reverse, + return_message_object=False, + ) + assert all(isinstance(m, LettaMessage) for m in letta_messages) + + print("LETTA_MESSAGES:") + for i, m in enumerate(letta_messages): + print(f"{i}: {type(m)} ...{str(m)[-50:]}") + + # Collect system messages and their texts + system_messages = [m for m in letta_messages if m.message_type == "system_message"] + return len(system_messages), letta_messages + + try: + + # At this stage, recall storage is expected to contain 2 system messages + num_system_messages, all_messages = count_system_messages_in_recall() + # assert num_system_messages == 1, (num_system_messages, all_messages) + assert num_system_messages == 2, (num_system_messages, all_messages) + + # Assuming the core memory append runs correctly, this should add one more system message (3 in total) + server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory") + + # At this stage, recall storage is expected to contain 3 system messages + num_system_messages, all_messages = count_system_messages_in_recall() + # assert num_system_messages == 2, (num_system_messages, all_messages) + assert num_system_messages == 3, (num_system_messages, all_messages) + + # Run server.load_agent, and make sure that the number of system messages is still 3 + server.load_agent(agent_id=agent_state.id) + + num_system_messages, all_messages = count_system_messages_in_recall() + # assert num_system_messages == 2, (num_system_messages, all_messages) + assert num_system_messages == 3, (num_system_messages, all_messages) + + finally: + # cleanup + server.delete_agent(user_id, agent_state.id)