From 57218d2b8fb3dcd379118cedf9921b180822e948 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 30 Apr 2025 17:20:53 -0700 Subject: [PATCH] feat: Finish adding min/max buffer to voice sleeptime manager group config (#1954) --- ...71_add_buffer_length_min_max_for_voice_.py | 33 +++++++++++++++ letta/agents/voice_agent.py | 8 +--- letta/orm/group.py | 2 + letta/schemas/group.py | 24 +++++++++++ letta/server/rest_api/routers/v1/voice.py | 2 - letta/server/server.py | 2 + letta/services/group_manager.py | 42 +++++++++++++++++++ 7 files changed, 105 insertions(+), 8 deletions(-) create mode 100644 alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py diff --git a/alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py b/alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py new file mode 100644 index 00000000..44f9a87f --- /dev/null +++ b/alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py @@ -0,0 +1,33 @@ +"""Add buffer length min max for voice sleeptime + +Revision ID: c56081a05371 +Revises: 28b8765bdd0a +Create Date: 2025-04-30 16:03:41.213750 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c56081a05371" +down_revision: Union[str, None] = "28b8765bdd0a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("groups", sa.Column("max_message_buffer_length", sa.Integer(), nullable=True)) + op.add_column("groups", sa.Column("min_message_buffer_length", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("groups", "min_message_buffer_length") + op.drop_column("groups", "max_message_buffer_length") + # ### end Alembic commands ### diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index b2f48df5..39096460 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -69,8 +69,6 @@ class VoiceAgent(BaseAgent): block_manager: BlockManager, passage_manager: PassageManager, actor: User, - message_buffer_limit: int, - message_buffer_min: int, ): super().__init__( agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor @@ -81,8 +79,6 @@ class VoiceAgent(BaseAgent): self.passage_manager = passage_manager # TODO: This is not guaranteed to exist! self.summary_block_label = "human" - self.message_buffer_limit = message_buffer_limit - self.message_buffer_min = message_buffer_min # Cached archival memory/message size self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id) @@ -109,8 +105,8 @@ class VoiceAgent(BaseAgent): target_block_label=self.summary_block_label, message_transcripts=[], ), - message_buffer_limit=self.message_buffer_limit, - message_buffer_min=self.message_buffer_min, + message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length, + message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length, ) return summarizer diff --git a/letta/orm/group.py b/letta/orm/group.py index 48c3b65b..489e563f 100644 --- a/letta/orm/group.py +++ b/letta/orm/group.py @@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin): termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") diff --git a/letta/schemas/group.py b/letta/schemas/group.py index dce4a9e5..de40ba5d 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -32,6 +32,14 @@ class Group(GroupBase): sleeptime_agent_frequency: Optional[int] = Field(None, description="") turns_counter: Optional[int] = Field(None, description="") last_processed_message_id: Optional[str] = Field(None, description="") + max_message_buffer_length: Optional[int] = Field( + None, + description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.", + ) + min_message_buffer_length: Optional[int] = Field( + None, + description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", + ) class ManagerConfig(BaseModel): @@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig): class VoiceSleeptimeManager(ManagerConfig): manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="") manager_agent_id: str = Field(..., description="") + max_message_buffer_length: Optional[int] = Field( + None, + description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.", + ) + min_message_buffer_length: Optional[int] = Field( + None, + description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", + ) class VoiceSleeptimeManagerUpdate(ManagerConfig): manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="") manager_agent_id: Optional[str] = Field(None, description="") + max_message_buffer_length: Optional[int] = Field( + None, + description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.", + ) + min_message_buffer_length: Optional[int] = Field( + None, + description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", + ) # class SwarmGroup(ManagerConfig): diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 561081c9..7b3d7efd 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -56,8 +56,6 @@ async def create_voice_chat_completions( block_manager=server.block_manager, passage_manager=server.passage_manager, actor=actor, - message_buffer_limit=8, - message_buffer_min=4, ) # Return the streaming generator diff --git a/letta/server/server.py b/letta/server/server.py index 5a7deff7..2fa95034 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -862,6 +862,8 @@ class SyncServer(Server): agent_ids=[voice_sleeptime_agent.id], manager_config=VoiceSleeptimeManager( manager_agent_id=main_agent.id, + max_message_buffer_length=30, + min_message_buffer_length=15, ), ), actor=actor, diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index e24d508d..3e1ee023 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -80,6 +80,12 @@ class GroupManager: case ManagerType.voice_sleeptime: new_group.manager_type = ManagerType.voice_sleeptime new_group.manager_agent_id = group.manager_config.manager_agent_id + max_message_buffer_length = group.manager_config.max_message_buffer_length + min_message_buffer_length = group.manager_config.min_message_buffer_length + # Safety check for buffer length range + self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length) + new_group.max_message_buffer_length = max_message_buffer_length + new_group.min_message_buffer_length = min_message_buffer_length case _: raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") @@ -97,6 +103,8 @@ class GroupManager: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) sleeptime_agent_frequency = None + max_message_buffer_length = None + min_message_buffer_length = None max_turns = None termination_token = None manager_agent_id = None @@ -117,11 +125,24 @@ class GroupManager: sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency if sleeptime_agent_frequency and group.turns_counter is None: group.turns_counter = -1 + case ManagerType.sleeptime: + manager_agent_id = group_update.manager_config.manager_agent_id + max_message_buffer_length = group_update.manager_config.max_message_buffer_length + min_message_buffer_length = group_update.manager_config.min_message_buffer_length + if sleeptime_agent_frequency and group.turns_counter is None: + group.turns_counter = -1 case _: raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}") + # Safety check for buffer length range + self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length) + if sleeptime_agent_frequency: group.sleeptime_agent_frequency = sleeptime_agent_frequency + if max_message_buffer_length: + group.max_message_buffer_length = max_message_buffer_length + if min_message_buffer_length: + group.min_message_buffer_length = min_message_buffer_length if max_turns: group.max_turns = max_turns if termination_token: @@ -274,3 +295,24 @@ class GroupManager: if manager_agent: for block in blocks: session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label)) + + @staticmethod + def ensure_buffer_length_range_valid( + max_value: Optional[int], + min_value: Optional[int], + max_name: str = "max_message_buffer_length", + min_name: str = "min_message_buffer_length", + ) -> None: + """ + 1) If one of max_value/min_value is set, the other must also be set. + 2) If both are set, max_value must be greater than min_value. + """ + # 1) require both-or-none + if (max_value is None) != (min_value is None): + raise ValueError( + f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})" + ) + + # 2) valid range + if max_value is not None and min_value is not None and max_value <= min_value: + raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")