feat: Finish adding min/max buffer to voice sleeptime manager group config (#1954)

This commit is contained in:
Matthew Zhou
2025-04-30 17:20:53 -07:00
committed by GitHub
parent 84f66aedd3
commit 57218d2b8f
7 changed files with 105 additions and 8 deletions

View File

@@ -0,0 +1,33 @@
"""Add buffer length min max for voice sleeptime
Revision ID: c56081a05371
Revises: 28b8765bdd0a
Create Date: 2025-04-30 16:03:41.213750
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# Revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the ID of the
# migration it revises (see "Revises" in the module docstring).
revision: str = "c56081a05371"
down_revision: Union[str, None] = "28b8765bdd0a"
# No branch labels or extra dependencies are declared for this migration.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the nullable integer buffer-length columns to the ``groups`` table.

    Columns are added in the order max, then min. Both are nullable, so
    existing rows need no backfill.
    """
    for column_name in ("max_message_buffer_length", "min_message_buffer_length"):
        op.add_column("groups", sa.Column(column_name, sa.Integer(), nullable=True))
def downgrade() -> None:
    """Drop the buffer-length columns from the ``groups`` table.

    Columns are dropped in the reverse of the order ``upgrade`` adds them
    (min first, then max).
    """
    for column_name in ("min_message_buffer_length", "max_message_buffer_length"):
        op.drop_column("groups", column_name)

View File

@@ -69,8 +69,6 @@ class VoiceAgent(BaseAgent):
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
message_buffer_limit: int,
message_buffer_min: int,
):
super().__init__(
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
@@ -81,8 +79,6 @@ class VoiceAgent(BaseAgent):
self.passage_manager = passage_manager
# TODO: This is not guaranteed to exist!
self.summary_block_label = "human"
self.message_buffer_limit = message_buffer_limit
self.message_buffer_min = message_buffer_min
# Cached archival memory/message size
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
@@ -109,8 +105,8 @@ class VoiceAgent(BaseAgent):
target_block_label=self.summary_block_label,
message_transcripts=[],
),
message_buffer_limit=self.message_buffer_limit,
message_buffer_min=self.message_buffer_min,
message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,
)
return summarizer

View File

@@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")

View File

@@ -32,6 +32,14 @@ class Group(GroupBase):
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
turns_counter: Optional[int] = Field(None, description="")
last_processed_message_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class ManagerConfig(BaseModel):
@@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig):
class VoiceSleeptimeManager(ManagerConfig):
    # Manager config for creating a voice-sleeptime group.
    # `manager_agent_id` is required; the two buffer bounds are optional and
    # default to None. (A `#` comment is used instead of a class docstring so
    # the generated schema description is left untouched.)
    manager_type: Literal[ManagerType.voice_sleeptime] = Field(default=ManagerType.voice_sleeptime, description="")
    manager_agent_id: str = Field(..., description="")

    # Upper bound on the conversation agent's in-context message count (best effort).
    max_message_buffer_length: Optional[int] = Field(
        default=None,
        description=(
            "The desired maximum length of messages in the context window of the convo agent. "
            "This is a best effort, and may be off slightly due to user/assistant interleaving."
        ),
    )

    # Lower bound on the conversation agent's in-context message count (best effort).
    min_message_buffer_length: Optional[int] = Field(
        default=None,
        description=(
            "The desired minimum length of messages in the context window of the convo agent. "
            "This is a best effort, and may be off-by-one due to user/assistant interleaving."
        ),
    )
class VoiceSleeptimeManagerUpdate(ManagerConfig):
    # Partial-update payload for a voice-sleeptime group: every field except
    # `manager_type` is optional and defaults to None. (A `#` comment is used
    # instead of a class docstring so the generated schema description is
    # left untouched.)
    manager_type: Literal[ManagerType.voice_sleeptime] = Field(default=ManagerType.voice_sleeptime, description="")
    manager_agent_id: Optional[str] = Field(default=None, description="")

    # Upper bound on the conversation agent's in-context message count (best effort).
    max_message_buffer_length: Optional[int] = Field(
        default=None,
        description=(
            "The desired maximum length of messages in the context window of the convo agent. "
            "This is a best effort, and may be off slightly due to user/assistant interleaving."
        ),
    )

    # Lower bound on the conversation agent's in-context message count (best effort).
    min_message_buffer_length: Optional[int] = Field(
        default=None,
        description=(
            "The desired minimum length of messages in the context window of the convo agent. "
            "This is a best effort, and may be off-by-one due to user/assistant interleaving."
        ),
    )
# class SwarmGroup(ManagerConfig):

View File

@@ -56,8 +56,6 @@ async def create_voice_chat_completions(
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
message_buffer_limit=8,
message_buffer_min=4,
)
# Return the streaming generator

View File

@@ -862,6 +862,8 @@ class SyncServer(Server):
agent_ids=[voice_sleeptime_agent.id],
manager_config=VoiceSleeptimeManager(
manager_agent_id=main_agent.id,
max_message_buffer_length=30,
min_message_buffer_length=15,
),
),
actor=actor,

View File

@@ -80,6 +80,12 @@ class GroupManager:
case ManagerType.voice_sleeptime:
new_group.manager_type = ManagerType.voice_sleeptime
new_group.manager_agent_id = group.manager_config.manager_agent_id
max_message_buffer_length = group.manager_config.max_message_buffer_length
min_message_buffer_length = group.manager_config.min_message_buffer_length
# Safety check for buffer length range
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
new_group.max_message_buffer_length = max_message_buffer_length
new_group.min_message_buffer_length = min_message_buffer_length
case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
@@ -97,6 +103,8 @@ class GroupManager:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
sleeptime_agent_frequency = None
max_message_buffer_length = None
min_message_buffer_length = None
max_turns = None
termination_token = None
manager_agent_id = None
@@ -117,11 +125,24 @@ class GroupManager:
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case ManagerType.sleeptime:
manager_agent_id = group_update.manager_config.manager_agent_id
max_message_buffer_length = group_update.manager_config.max_message_buffer_length
min_message_buffer_length = group_update.manager_config.min_message_buffer_length
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case _:
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
# Safety check for buffer length range
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
if sleeptime_agent_frequency:
group.sleeptime_agent_frequency = sleeptime_agent_frequency
if max_message_buffer_length:
group.max_message_buffer_length = max_message_buffer_length
if min_message_buffer_length:
group.min_message_buffer_length = min_message_buffer_length
if max_turns:
group.max_turns = max_turns
if termination_token:
@@ -274,3 +295,24 @@ class GroupManager:
if manager_agent:
for block in blocks:
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
@staticmethod
def ensure_buffer_length_range_valid(
    max_value: Optional[int],
    min_value: Optional[int],
    max_name: str = "max_message_buffer_length",
    min_name: str = "min_message_buffer_length",
) -> None:
    """Validate a (max, min) buffer-length pair.

    Rules enforced:
      * The two values are set together or not at all.
      * When both are set, ``max_value`` is strictly greater than ``min_value``.

    Args:
        max_value: Upper bound, or None when unset.
        min_value: Lower bound, or None when unset.
        max_name: Field name used for ``max_value`` in error messages.
        min_name: Field name used for ``min_value`` in error messages.

    Raises:
        ValueError: If exactly one value is None, or if ``max_value <= min_value``.
    """
    max_missing = max_value is None
    min_missing = min_value is None

    # Nothing configured: nothing to validate.
    if max_missing and min_missing:
        return

    # Exactly one side configured: reject the half-specified range.
    if max_missing or min_missing:
        raise ValueError(
            f"Both '{max_name}' and '{min_name}' must be provided together (got {max_name}={max_value}, {min_name}={min_value})"
        )

    # Both configured: the range must be non-empty (strict inequality).
    if max_value <= min_value:
        raise ValueError(f"'{max_name}' must be greater than '{min_name}' (got {max_name}={max_value} <= {min_name}={min_value})")