From d15e49808dae86bcfbf9d676fd7e0679d02eed77 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 13 May 2025 20:54:52 -0700 Subject: [PATCH] chore: add a concept (and enforcement of) read-only blocks (#2161) --- .../220856bbf43b_add_read_only_column.py | 35 +++++++++++++++++ letta/agent.py | 12 ++++++ letta/constants.py | 3 ++ letta/orm/block.py | 3 ++ letta/schemas/agent.py | 12 +++++- letta/schemas/block.py | 3 ++ letta/schemas/memory.py | 9 ++++- letta/services/tool_executor/tool_executor.py | 20 +++++++++- tests/test_memory.py | 17 -------- tests/test_sdk_client.py | 39 +++++++++++++++++++ 10 files changed, 131 insertions(+), 22 deletions(-) create mode 100644 alembic/versions/220856bbf43b_add_read_only_column.py diff --git a/alembic/versions/220856bbf43b_add_read_only_column.py b/alembic/versions/220856bbf43b_add_read_only_column.py new file mode 100644 index 00000000..a8a962de --- /dev/null +++ b/alembic/versions/220856bbf43b_add_read_only_column.py @@ -0,0 +1,35 @@ +"""add read-only column + +Revision ID: 220856bbf43b +Revises: 1dc0fee72dea +Create Date: 2025-05-13 14:42:17.353614 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "220856bbf43b" +down_revision: Union[str, None] = "1dc0fee72dea" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # add default value of `False` + op.add_column("block", sa.Column("read_only", sa.Boolean(), nullable=True)) + op.execute( + f""" + UPDATE block + SET read_only = False + """ + ) + op.alter_column("block", "read_only", nullable=False) + + +def downgrade() -> None: + op.drop_column("block", "read_only") diff --git a/letta/agent.py b/letta/agent.py index 50a8ed20..df755c23 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -179,6 +179,15 @@ class Agent(BaseAgent): raise ValueError(f"Invalid JSON format in message: {text_content}") return None + def ensure_read_only_block_not_modified(self, new_memory: Memory) -> None: + """ + Throw an error if a read-only block has been modified + """ + for label in self.agent_state.memory.list_block_labels(): + if self.agent_state.memory.get_block(label).read_only: + if new_memory.get_block(label).value != self.agent_state.memory.get_block(label).value: + raise ValueError(READ_ONLY_BLOCK_EDIT_ERROR) + def update_memory_if_changed(self, new_memory: Memory) -> bool: """ Update internal memory object and system prompt if there have been modifications. @@ -1277,6 +1286,9 @@ class Agent(BaseAgent): agent_state_copy = self.agent_state.__deepcopy__() function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked function_response = callable_func(**function_args) + self.ensure_read_only_block_not_modified( + new_memory=agent_state_copy.memory + ) # memory editing tools cannot edit read-only blocks self.update_memory_if_changed(agent_state_copy.memory) elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: action_name = generate_composio_action_from_func_name(target_letta_tool.name) diff --git a/letta/constants.py b/letta/constants.py index 448277f8..1068c614 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -195,6 +195,9 @@ DATA_SOURCE_ATTACH_ALERT = ( "[ALERT] New data was just uploaded to archival memory. You can view this data by calling the archival_memory_search tool." ) +# Throw an error message when a read-only block is edited +READ_ONLY_BLOCK_EDIT_ERROR = f"{ERROR_MESSAGE_PREFIX} This block is read-only and cannot be edited." + # The ackknowledgement message used in the summarize sequence MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the message (and only the summary, nothing else) once I receive the conversation history. I'm ready." diff --git a/letta/orm/block.py b/letta/orm/block.py index 30b2f1ab..271a9baa 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -39,6 +39,9 @@ class Block(OrganizationMixin, SqlalchemyBase): limit: Mapped[BigInteger] = mapped_column(Integer, default=CORE_MEMORY_BLOCK_CHAR_LIMIT, doc="Character limit of the block.") metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default={}, doc="arbitrary information related to the block.") + # permissions of the agent + read_only: Mapped[bool] = mapped_column(doc="whether the agent has read-only access to the block", default=False) + # history pointers / locking mechanisms current_history_entry_id: Mapped[Optional[str]] = mapped_column( String, ForeignKey("block_history.id", name="fk_block_current_history_entry", use_alter=True), nullable=True, index=True diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 13f74d82..cccd048f 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -312,9 +312,17 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): ) return ( "{% for block in blocks %}" - '<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n' + "<{{ block.label }}>\n" + "\n" + "{{ block.description }}\n" + "\n" + "\n" + '{% if block.read_only %}read_only="true" {% endif %}chars_current="{{ block.value|length }}" chars_limit="{{ block.limit }}"\n' + "\n" + "\n" "{{ block.value }}\n" - "" + "\n" + "\n" "{% if not loop.last %}\n{% endif %}" "{% endfor %}" ) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 3e2fbb7e..babcd803 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -25,6 +25,9 @@ class BaseBlock(LettaBase, validate_assignment=True): # context window label label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona') in the context window.") + # permissions of the agent + read_only: bool = Field(False, description="Whether the agent has read-only access to the block.") + # metadata description: Optional[str] = Field(None, description="Description of the block.") metadata: Optional[dict] = Field({}, description="Metadata of the block.") diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 1f60a09a..e64533be 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -69,9 +69,14 @@ class Memory(BaseModel, validate_assignment=True): # Memory.template is a Jinja2 template for compiling memory module into a prompt string. prompt_template: str = Field( default="{% for block in blocks %}" - '<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n' + "<{{ block.label }}>\n" + "" + 'read_only="{{ block.read_only}}" chars_current="{{ block.value|length }}" chars_limit="{{ block.limit }}"' + "" + "" "{{ block.value }}\n" - "" + "" + "\n" "{% if not loop.last %}\n{% endif %}" "{% endfor %}", description="Jinja2 template for compiling memory blocks into a prompt string", diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index 50879e57..9424520c 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -3,7 +3,12 @@ import traceback from abc import ABC, abstractmethod from typing import Any, Dict, Optional -from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE +from letta.constants import ( + COMPOSIO_ENTITY_ENV_VAR_KEY, + CORE_MEMORY_LINE_NUMBER_WARNING, + READ_ONLY_BLOCK_EDIT_ERROR, + RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, +) from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name from letta.helpers.composio_helpers import get_composio_api_key @@ -203,6 +208,8 @@ class LettaCoreToolExecutor(ToolExecutor): Returns: Optional[str]: None is always returned as this function does not produce a response. """ + if agent_state.memory.get_block(label).read_only: + raise ValueError(f"{READ_ONLY_BLOCK_EDIT_ERROR}") current_value = str(agent_state.memory.get_block(label).value) new_value = current_value + "\n" + str(content) agent_state.memory.update_block_value(label=label, value=new_value) @@ -228,6 +235,8 @@ class LettaCoreToolExecutor(ToolExecutor): Returns: Optional[str]: None is always returned as this function does not produce a response. """ + if agent_state.memory.get_block(label).read_only: + raise ValueError(f"{READ_ONLY_BLOCK_EDIT_ERROR}") current_value = str(agent_state.memory.get_block(label).value) if old_content not in current_value: raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'") @@ -260,6 +269,9 @@ class LettaCoreToolExecutor(ToolExecutor): """ import re + if agent_state.memory.get_block(label).read_only: + raise ValueError(f"{READ_ONLY_BLOCK_EDIT_ERROR}") + if bool(re.search(r"\nLine \d+: ", old_str)): raise ValueError( "old_str contains a line number prefix, which is not allowed. " @@ -349,6 +361,9 @@ class LettaCoreToolExecutor(ToolExecutor): """ import re + if agent_state.memory.get_block(label).read_only: + raise ValueError(f"{READ_ONLY_BLOCK_EDIT_ERROR}") + if bool(re.search(r"\nLine \d+: ", new_str)): raise ValueError( "new_str contains a line number prefix, which is not allowed. Do not " @@ -426,6 +441,9 @@ class LettaCoreToolExecutor(ToolExecutor): """ import re + if agent_state.memory.get_block(label).read_only: + raise ValueError(f"{READ_ONLY_BLOCK_EDIT_ERROR}") + if bool(re.search(r"\nLine \d+: ", new_memory)): raise ValueError( "new_memory contains a line number prefix, which is not allowed. Do not " diff --git a/tests/test_memory.py b/tests/test_memory.py index 85e12e80..87e02a0e 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -25,23 +25,6 @@ def test_memory_limit_validation(sample_memory: Memory): sample_memory.get_block("persona").value = "x " * 10000 -def test_memory_jinja2_template(sample_memory: Memory): - """Test to make sure the jinja2 template string is equivalent to the old __repr__ method""" - - def old_repr(self: Memory) -> str: - """Generate a string representation of the memory in-context""" - section_strs = [] - for block in sample_memory.get_blocks(): - section = block.label - module = block - section_strs.append(f'<{section} characters="{len(module.value)}/{module.limit}">\n{module.value}\n') - return "\n".join(section_strs) - - old_repr_str = old_repr(sample_memory) - new_repr_str = sample_memory.compile() - assert new_repr_str == old_repr_str, f"Expected '{old_repr_str}' to be '{new_repr_str}'" - - def test_memory_jinja2_set_template(sample_memory: Memory): """Test setting the template for the memory""" diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index d51ad028..9482b5a4 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -122,6 +122,45 @@ def test_shared_blocks(client: LettaSDKClient): client.agents.delete(agent_state2.id) +def test_read_only_block(client: LettaSDKClient): + block_value = "username: sarah" + agent = client.agents.create( + memory_blocks=[ + CreateBlock( + label="human", + value=block_value, + read_only=True, + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + ) + + # make sure agent cannot update read-only block + client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="my name is actually charles", + ) + ], + ) + + # make sure block value is still the same + block = client.agents.blocks.retrieve(agent_id=agent.id, block_label="human") + assert block.value == block_value + + # make sure can update from client + new_value = "hello" + client.agents.blocks.modify(agent_id=agent.id, block_label="human", value=new_value) + block = client.agents.blocks.retrieve(agent_id=agent.id, block_label="human") + assert block.value == new_value + + # cleanup + client.agents.delete(agent.id) + + def test_add_and_manage_tags_for_agent(client: LettaSDKClient): """ Comprehensive happy path test for adding, retrieving, and managing tags on an agent.