feat: add hidden property to group and blocks [PRO-1145] (#4415)
* feat: add hidden property to group and blocks * feat: add hidden property to group and blocks * chore: bup * chore: add hidden property * chore: next --------- Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
"""add_hidden_property_to_groups_and_blocks
|
||||
|
||||
Revision ID: 5b804970e6a0
|
||||
Revises: ddb69be34a72
|
||||
Create Date: 2025-09-03 22:19:03.825077
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5b804970e6a0"
|
||||
down_revision: Union[str, None] = "ddb69be34a72"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add hidden column to groups table
|
||||
op.add_column("groups", sa.Column("hidden", sa.Boolean(), nullable=True))
|
||||
|
||||
# Add hidden column to block table
|
||||
op.add_column("block", sa.Column("hidden", sa.Boolean(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove hidden column from block table
|
||||
op.drop_column("block", "hidden")
|
||||
|
||||
# Remove hidden column from groups table
|
||||
op.drop_column("groups", "hidden")
|
||||
@@ -41,6 +41,7 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
|
||||
|
||||
# permissions of the agent
|
||||
read_only: Mapped[bool] = mapped_column(doc="whether the agent has read-only access to the block", default=False)
|
||||
hidden: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="If set to True, the block will be hidden.")
|
||||
|
||||
# history pointers / locking mechanisms
|
||||
current_history_entry_id: Mapped[Optional[str]] = mapped_column(
|
||||
|
||||
@@ -24,6 +24,7 @@ class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
hidden: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="If set to True, the group will be hidden.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
|
||||
|
||||
@@ -38,6 +38,10 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
||||
# metadata
|
||||
description: Optional[str] = Field(None, description="Description of the block.")
|
||||
metadata: Optional[dict] = Field({}, description="Metadata of the block.")
|
||||
hidden: Optional[bool] = Field(
|
||||
None,
|
||||
description="If set to True, the block will be hidden.",
|
||||
)
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self.value)
|
||||
|
||||
@@ -49,6 +49,10 @@ class Group(GroupBase):
|
||||
None,
|
||||
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
|
||||
)
|
||||
hidden: Optional[bool] = Field(
|
||||
None,
|
||||
description="If set to True, the group will be hidden.",
|
||||
)
|
||||
|
||||
@property
|
||||
def manager_config(self) -> ManagerConfig:
|
||||
@@ -170,6 +174,10 @@ class GroupCreate(BaseModel):
|
||||
manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="")
|
||||
project_id: Optional[str] = Field(None, description="The associated project id.")
|
||||
shared_block_ids: List[str] = Field([], description="")
|
||||
hidden: Optional[bool] = Field(
|
||||
None,
|
||||
description="If set to True, the group will be hidden.",
|
||||
)
|
||||
|
||||
|
||||
class InternalTemplateGroupCreate(GroupCreate):
|
||||
|
||||
@@ -68,6 +68,11 @@ async def list_blocks(
|
||||
"If provided, returns blocks that have exactly this number of connected agents."
|
||||
),
|
||||
),
|
||||
show_hidden_blocks: bool | None = Query(
|
||||
False,
|
||||
include_in_schema=False,
|
||||
description="If set to True, include blocks marked as hidden in the results.",
|
||||
),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
@@ -89,6 +94,7 @@ async def list_blocks(
|
||||
connected_to_agents_count_eq=connected_to_agents_count_eq,
|
||||
limit=limit,
|
||||
after=after,
|
||||
show_hidden_blocks=show_hidden_blocks,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,11 @@ async def list_groups(
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(None, description="Limit for pagination"),
|
||||
project_id: Optional[str] = Query(None, description="Search groups by project id"),
|
||||
show_hidden_groups: bool | None = Query(
|
||||
False,
|
||||
include_in_schema=False,
|
||||
description="If set to True, include groups marked as hidden in the results.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Fetch all multi-agent groups matching query.
|
||||
@@ -37,6 +42,7 @@ async def list_groups(
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
show_hidden_groups=show_hidden_groups,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ class BlockManager:
|
||||
connected_to_agents_count_lt: Optional[int] = None,
|
||||
connected_to_agents_count_eq: Optional[List[int]] = None,
|
||||
ascending: bool = True,
|
||||
show_hidden_blocks: Optional[bool] = None,
|
||||
) -> List[PydanticBlock]:
|
||||
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
|
||||
from sqlalchemy import select
|
||||
@@ -228,6 +229,10 @@ class BlockManager:
|
||||
if value_search:
|
||||
query = query.where(BlockModel.value.ilike(f"%{value_search}%"))
|
||||
|
||||
# Apply hidden filter
|
||||
if not show_hidden_blocks:
|
||||
query = query.where((BlockModel.hidden.is_(None)) | (BlockModel.hidden == False))
|
||||
|
||||
needs_distinct = False
|
||||
|
||||
needs_agent_count_join = any(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import and_, asc, delete, desc, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -13,6 +14,7 @@ from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
@@ -27,20 +29,34 @@ class GroupManager:
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
show_hidden_groups: Optional[bool] = None,
|
||||
) -> list[PydanticGroup]:
|
||||
async with db_registry.async_session() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
from sqlalchemy import select
|
||||
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
|
||||
query = select(GroupModel)
|
||||
query = GroupModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
|
||||
# Apply filters
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
query = query.where(GroupModel.project_id == project_id)
|
||||
if manager_type:
|
||||
filters["manager_type"] = manager_type
|
||||
groups = await GroupModel.list_async(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
**filters,
|
||||
)
|
||||
query = query.where(GroupModel.manager_type == manager_type)
|
||||
|
||||
# Apply hidden filter
|
||||
if not show_hidden_groups:
|
||||
query = query.where((GroupModel.hidden.is_(None)) | (GroupModel.hidden == False))
|
||||
|
||||
# Apply pagination
|
||||
query = await _apply_group_pagination_async(query, before, after, session, ascending=True)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
groups = result.scalars().all()
|
||||
return [group.to_pydantic() for group in groups]
|
||||
|
||||
@enforce_types
|
||||
@@ -561,3 +577,50 @@ class GroupManager:
|
||||
# 3) ordering
|
||||
if max_value <= min_value:
|
||||
raise ValueError(f"'{max_name}' must be greater than '{min_name}' (got {max_name}={max_value} <= {min_name}={min_value})")
|
||||
|
||||
|
||||
def _cursor_filter(sort_col, id_col, ref_sort_col, ref_id, forward: bool):
|
||||
"""
|
||||
Returns a SQLAlchemy filter expression for cursor-based pagination for groups.
|
||||
|
||||
If `forward` is True, returns records after the reference.
|
||||
If `forward` is False, returns records before the reference.
|
||||
"""
|
||||
if forward:
|
||||
return or_(
|
||||
sort_col > ref_sort_col,
|
||||
and_(sort_col == ref_sort_col, id_col > ref_id),
|
||||
)
|
||||
else:
|
||||
return or_(
|
||||
sort_col < ref_sort_col,
|
||||
and_(sort_col == ref_sort_col, id_col < ref_id),
|
||||
)
|
||||
|
||||
|
||||
async def _apply_group_pagination_async(query, before: Optional[str], after: Optional[str], session, ascending: bool = True) -> any:
|
||||
"""Apply cursor-based pagination to group queries."""
|
||||
sort_column = GroupModel.created_at
|
||||
|
||||
if after:
|
||||
result = (await session.execute(select(sort_column, GroupModel.id).where(GroupModel.id == after))).first()
|
||||
if result:
|
||||
after_sort_value, after_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
|
||||
after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
query = query.where(_cursor_filter(sort_column, GroupModel.id, after_sort_value, after_id, forward=ascending))
|
||||
|
||||
if before:
|
||||
result = (await session.execute(select(sort_column, GroupModel.id).where(GroupModel.id == before))).first()
|
||||
if result:
|
||||
before_sort_value, before_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
|
||||
before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
query = query.where(_cursor_filter(sort_column, GroupModel.id, before_sort_value, before_id, forward=not ascending))
|
||||
|
||||
# Apply ordering
|
||||
order_fn = asc if ascending else desc
|
||||
query = query.order_by(order_fn(sort_column), order_fn(GroupModel.id))
|
||||
return query
|
||||
|
||||
Reference in New Issue
Block a user