From ca43ffb474ae89a3d95da5785667f45a4eaae883 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 12 Feb 2025 15:06:56 -0800 Subject: [PATCH] feat: Add `message_buffer_autoclear` field to Agent (#978) --- ...a08_add_stateless_option_for_agentstate.py | 36 +++++++++++++++++++ letta/orm/agent.py | 8 ++++- letta/schemas/agent.py | 14 ++++++++ letta/services/agent_manager.py | 4 +++ tests/helpers/utils.py | 4 +++ tests/test_managers.py | 2 ++ 6 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/7980d239ea08_add_stateless_option_for_agentstate.py diff --git a/alembic/versions/7980d239ea08_add_stateless_option_for_agentstate.py b/alembic/versions/7980d239ea08_add_stateless_option_for_agentstate.py new file mode 100644 index 00000000..9693940d --- /dev/null +++ b/alembic/versions/7980d239ea08_add_stateless_option_for_agentstate.py @@ -0,0 +1,36 @@ +"""Add message_buffer_autoclear option for AgentState + +Revision ID: 7980d239ea08 +Revises: dfafcf8210ca +Create Date: 2025-02-12 14:02:00.918226 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "7980d239ea08" +down_revision: Union[str, None] = "dfafcf8210ca" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add the column with a temporary nullable=True so we can backfill + op.add_column("agents", sa.Column("message_buffer_autoclear", sa.Boolean(), nullable=True)) + + # Backfill existing rows to set message_buffer_autoclear to False where it's NULL + op.execute("UPDATE agents SET message_buffer_autoclear = false WHERE message_buffer_autoclear IS NULL") + + # Now, enforce nullable=False after backfilling + op.alter_column("agents", "message_buffer_autoclear", nullable=False) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("agents", "message_buffer_autoclear") + # ### end Alembic commands ### diff --git a/letta/orm/agent.py b/letta/orm/agent.py index a4d08f71..07b3917b 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -1,7 +1,7 @@ import uuid from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, Index, String +from sqlalchemy import JSON, Boolean, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block @@ -62,6 +62,11 @@ class Agent(SqlalchemyBase, OrganizationMixin): # Tool rules tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.") + # Stateless + message_buffer_autoclear: Mapped[bool] = mapped_column( + Boolean, doc="If set to True, the agent will not remember previous messages. Not recommended unless you have an advanced use case." + ) + # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( @@ -146,6 +151,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): "project_id": self.project_id, "template_id": self.template_id, "base_template_id": self.base_template_id, + "message_buffer_autoclear": self.message_buffer_autoclear, } return self.__pydantic_model__(**state) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 9269742d..d1c9d8b8 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -85,6 +85,12 @@ class AgentState(OrmMetadataBase, validate_assignment=True): template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + # An advanced configuration that makes it so this agent does not remember any previous messages + message_buffer_autoclear: bool = Field( + False, + description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", + ) + def get_agent_env_vars_as_dict(self) -> Dict[str, str]: # Get environment variables for this agent specifically per_agent_env_vars = {} @@ -146,6 +152,10 @@ class CreateAgent(BaseModel, validate_assignment=True): # project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + message_buffer_autoclear: bool = Field( + False, + description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", + ) @field_validator("name") @classmethod @@ -216,6 +226,10 @@ class UpdateAgent(BaseModel): project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + message_buffer_autoclear: Optional[bool] = Field( + None, + description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", + ) class Config: extra = "ignore" # Ignores extra fields diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3c965386..3024123d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -123,6 +123,7 @@ class AgentManager: project_id=agent_create.project_id, template_id=agent_create.template_id, base_template_id=agent_create.base_template_id, + message_buffer_autoclear=agent_create.message_buffer_autoclear, ) # If there are provided environment variables, add them in @@ -185,6 +186,7 @@ class AgentManager: project_id: Optional[str] = None, template_id: Optional[str] = None, base_template_id: Optional[str] = None, + message_buffer_autoclear: bool = False, ) -> PydanticAgentState: """Create a new agent.""" with self.session_maker() as session: @@ -202,6 +204,7 @@ class AgentManager: "project_id": project_id, "template_id": template_id, "base_template_id": base_template_id, + "message_buffer_autoclear": message_buffer_autoclear, } # Create the new agent using SqlalchemyBase.create @@ -263,6 +266,7 @@ class AgentManager: "project_id", "template_id", "base_template_id", + "message_buffer_autoclear", } for field in scalar_fields: value = getattr(agent_update, field, None) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 167a39ee..f4868fda 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -151,3 +151,7 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up assert all( any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules ), f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}" + + # Assert message_buffer_autoclear + if not request.message_buffer_autoclear is None: + assert agent.message_buffer_autoclear == request.message_buffer_autoclear diff --git a/tests/test_managers.py b/tests/test_managers.py index 540717dc..cce3e449 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -447,6 +447,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too tool_rules=[InitToolRule(tool_name=print_tool.name)], initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")], tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"}, + message_buffer_autoclear=True, ) created_agent = server.agent_manager.create_agent( create_agent_request, @@ -601,6 +602,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe message_ids=["10", "20"], metadata={"train_key": "train_value"}, tool_exec_environment_variables={"test_env_var_key_a": "a", "new_tool_exec_key": "n"}, + message_buffer_autoclear=False, ) last_updated_timestamp = agent.updated_at