feat: Finish adding min/max buffer to voice sleeptime manager group config (#1954)
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""Add buffer length min max for voice sleeptime
|
||||
|
||||
Revision ID: c56081a05371
|
||||
Revises: 28b8765bdd0a
|
||||
Create Date: 2025-04-30 16:03:41.213750
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c56081a05371"
|
||||
down_revision: Union[str, None] = "28b8765bdd0a"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("groups", sa.Column("max_message_buffer_length", sa.Integer(), nullable=True))
|
||||
op.add_column("groups", sa.Column("min_message_buffer_length", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("groups", "min_message_buffer_length")
|
||||
op.drop_column("groups", "max_message_buffer_length")
|
||||
# ### end Alembic commands ###
|
||||
@@ -69,8 +69,6 @@ class VoiceAgent(BaseAgent):
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
message_buffer_limit: int,
|
||||
message_buffer_min: int,
|
||||
):
|
||||
super().__init__(
|
||||
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
|
||||
@@ -81,8 +79,6 @@ class VoiceAgent(BaseAgent):
|
||||
self.passage_manager = passage_manager
|
||||
# TODO: This is not guaranteed to exist!
|
||||
self.summary_block_label = "human"
|
||||
self.message_buffer_limit = message_buffer_limit
|
||||
self.message_buffer_min = message_buffer_min
|
||||
|
||||
# Cached archival memory/message size
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
@@ -109,8 +105,8 @@ class VoiceAgent(BaseAgent):
|
||||
target_block_label=self.summary_block_label,
|
||||
message_transcripts=[],
|
||||
),
|
||||
message_buffer_limit=self.message_buffer_limit,
|
||||
message_buffer_min=self.message_buffer_min,
|
||||
message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
|
||||
message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,
|
||||
)
|
||||
|
||||
return summarizer
|
||||
|
||||
@@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
|
||||
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
|
||||
|
||||
@@ -32,6 +32,14 @@ class Group(GroupBase):
|
||||
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
|
||||
turns_counter: Optional[int] = Field(None, description="")
|
||||
last_processed_message_id: Optional[str] = Field(None, description="")
|
||||
max_message_buffer_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
|
||||
)
|
||||
min_message_buffer_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
|
||||
)
|
||||
|
||||
|
||||
class ManagerConfig(BaseModel):
|
||||
@@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig):
|
||||
class VoiceSleeptimeManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
|
||||
manager_agent_id: str = Field(..., description="")
|
||||
max_message_buffer_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
|
||||
)
|
||||
min_message_buffer_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceSleeptimeManagerUpdate(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
|
||||
manager_agent_id: Optional[str] = Field(None, description="")
|
||||
max_message_buffer_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
|
||||
)
|
||||
min_message_buffer_length: Optional[int] = Field(
|
||||
None,
|
||||
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
|
||||
)
|
||||
|
||||
|
||||
# class SwarmGroup(ManagerConfig):
|
||||
|
||||
@@ -56,8 +56,6 @@ async def create_voice_chat_completions(
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
message_buffer_limit=8,
|
||||
message_buffer_min=4,
|
||||
)
|
||||
|
||||
# Return the streaming generator
|
||||
|
||||
@@ -862,6 +862,8 @@ class SyncServer(Server):
|
||||
agent_ids=[voice_sleeptime_agent.id],
|
||||
manager_config=VoiceSleeptimeManager(
|
||||
manager_agent_id=main_agent.id,
|
||||
max_message_buffer_length=30,
|
||||
min_message_buffer_length=15,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
|
||||
@@ -80,6 +80,12 @@ class GroupManager:
|
||||
case ManagerType.voice_sleeptime:
|
||||
new_group.manager_type = ManagerType.voice_sleeptime
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
max_message_buffer_length = group.manager_config.max_message_buffer_length
|
||||
min_message_buffer_length = group.manager_config.min_message_buffer_length
|
||||
# Safety check for buffer length range
|
||||
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
|
||||
new_group.max_message_buffer_length = max_message_buffer_length
|
||||
new_group.min_message_buffer_length = min_message_buffer_length
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
|
||||
@@ -97,6 +103,8 @@ class GroupManager:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
sleeptime_agent_frequency = None
|
||||
max_message_buffer_length = None
|
||||
min_message_buffer_length = None
|
||||
max_turns = None
|
||||
termination_token = None
|
||||
manager_agent_id = None
|
||||
@@ -117,11 +125,24 @@ class GroupManager:
|
||||
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
|
||||
if sleeptime_agent_frequency and group.turns_counter is None:
|
||||
group.turns_counter = -1
|
||||
case ManagerType.sleeptime:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
max_message_buffer_length = group_update.manager_config.max_message_buffer_length
|
||||
min_message_buffer_length = group_update.manager_config.min_message_buffer_length
|
||||
if sleeptime_agent_frequency and group.turns_counter is None:
|
||||
group.turns_counter = -1
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
|
||||
|
||||
# Safety check for buffer length range
|
||||
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
|
||||
|
||||
if sleeptime_agent_frequency:
|
||||
group.sleeptime_agent_frequency = sleeptime_agent_frequency
|
||||
if max_message_buffer_length:
|
||||
group.max_message_buffer_length = max_message_buffer_length
|
||||
if min_message_buffer_length:
|
||||
group.min_message_buffer_length = min_message_buffer_length
|
||||
if max_turns:
|
||||
group.max_turns = max_turns
|
||||
if termination_token:
|
||||
@@ -274,3 +295,24 @@ class GroupManager:
|
||||
if manager_agent:
|
||||
for block in blocks:
|
||||
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
|
||||
|
||||
@staticmethod
|
||||
def ensure_buffer_length_range_valid(
|
||||
max_value: Optional[int],
|
||||
min_value: Optional[int],
|
||||
max_name: str = "max_message_buffer_length",
|
||||
min_name: str = "min_message_buffer_length",
|
||||
) -> None:
|
||||
"""
|
||||
1) If one of max_value/min_value is set, the other must also be set.
|
||||
2) If both are set, max_value must be greater than min_value.
|
||||
"""
|
||||
# 1) require both-or-none
|
||||
if (max_value is None) != (min_value is None):
|
||||
raise ValueError(
|
||||
f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})"
|
||||
)
|
||||
|
||||
# 2) valid range
|
||||
if max_value is not None and min_value is not None and max_value <= min_value:
|
||||
raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")
|
||||
|
||||
Reference in New Issue
Block a user