Swap the order of the @trace_method and @raise_on_invalid_id decorators across all service managers. After the swap, @trace_method is always the innermost decorator, placed directly above the method definition, and is therefore applied to the function first. Because @raise_on_invalid_id now wraps the traced function, ID validation runs before tracing begins, which is the intended execution order. Files modified: - agent_manager.py (23 occurrences) - archive_manager.py (11 occurrences) - block_manager.py (7 occurrences) - file_manager.py (6 occurrences) - group_manager.py (9 occurrences) - identity_manager.py (10 occurrences) - job_manager.py (7 occurrences) - message_manager.py (2 occurrences) - provider_manager.py (3 occurrences) - sandbox_config_manager.py (7 occurrences) - source_manager.py (5 occurrences) - step_manager.py (13 occurrences)
810 lines
36 KiB
Python
810 lines
36 KiB
Python
import asyncio
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional
|
||
|
||
from sqlalchemy import and_, delete, func, or_, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from letta.errors import LettaInvalidArgumentError
|
||
from letta.log import get_logger
|
||
from letta.orm.agent import Agent as AgentModel
|
||
from letta.orm.block import Block as BlockModel
|
||
from letta.orm.block_history import BlockHistory
|
||
from letta.orm.blocks_agents import BlocksAgents
|
||
from letta.orm.errors import NoResultFound
|
||
from letta.otel.tracing import trace_method
|
||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
|
||
from letta.schemas.enums import ActorType, PrimitiveType
|
||
from letta.schemas.user import User as PydanticUser
|
||
from letta.server.db import db_registry
|
||
from letta.settings import DatabaseChoice, settings
|
||
from letta.utils import enforce_types
|
||
from letta.validators import raise_on_invalid_id
|
||
|
||
# Module-level logger; bulk_update_block_values_async uses it to warn about skipped or truncated blocks.
logger = get_logger(__name__)
|
||
|
||
|
||
def validate_block_limit_constraint(update_data: dict, existing_block: BlockModel) -> None:
    """
    Check that a pending block update keeps the block's value within its character limit.

    Exactly one rule is applied per call:
      - If ``update_data`` carries a new ``limit``, that limit must be at least the length of
        the resulting value (the new ``value`` when supplied, otherwise the block's current one).
      - Otherwise, if ``update_data`` carries a new ``value`` and the block has a truthy
        ``limit``, the new value must not be longer than that limit.

    Args:
        update_data: Mapping of field names to their new values.
        existing_block: The block as currently stored.

    Raises:
        LettaInvalidArgumentError: If the update would violate the limit constraint.
    """
    if "limit" not in update_data:
        # Only the value changes: compare it against the limit already stored on the block.
        if "value" in update_data and existing_block.limit:
            new_length = len(update_data["value"])
            if new_length > existing_block.limit:
                raise LettaInvalidArgumentError(
                    f"Value length ({new_length} characters) exceeds block limit ({existing_block.limit} characters)",
                    argument_name="value",
                )
        return

    # The limit changes: it must accommodate whichever value the block will end up holding.
    resulting_value = update_data.get("value", existing_block.value)
    new_limit = update_data["limit"]
    if resulting_value and new_limit < len(resulting_value):
        raise LettaInvalidArgumentError(
            f"Limit ({new_limit}) cannot be less than current value length ({len(resulting_value)} characters)",
            argument_name="limit",
        )
|
||
|
||
|
||
def validate_block_creation(block_data: dict) -> None:
    """
    Check that a block about to be created fits within its own character limit.

    The check only runs when both ``value`` and ``limit`` are present and truthy.
    In that case the value's length may not exceed the limit.

    Args:
        block_data: Mapping of block fields that will be used to create the block.

    Raises:
        LettaInvalidArgumentError: If the value is longer than the limit.
    """
    value, limit = block_data.get("value"), block_data.get("limit")
    if not (value and limit):
        # Nothing to compare against; creation is allowed.
        return
    if len(value) > limit:
        raise LettaInvalidArgumentError(
            f"Block limit ({limit}) must be greater than or equal to value length ({len(value)} characters)", argument_name="limit"
        )
|
||
|
||
|
||
class BlockManager:
    """Manager class to handle business logic related to Blocks."""

    @enforce_types
    @trace_method
    async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
        """Create ``block``, or update it in place if a block with the same ID already exists.

        Args:
            block: Block schema to persist; its ``id`` decides between create and update.
            actor: User performing the operation; new blocks are created in their organization.

        Returns:
            The created or updated block.

        Raises:
            LettaInvalidArgumentError: If the block's value/limit combination is invalid.
        """
        existing_block = await self.get_block_by_id_async(block.id, actor)
        if existing_block:
            update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
            return await self.update_block_async(block.id, update_data, actor)

        async with db_registry.async_session() as session:
            data = block.model_dump(to_orm=True, exclude_none=True)
            # Validate block creation constraints
            validate_block_creation(data)
            # Use a separate name so the ``block`` parameter is not shadowed by the ORM row.
            block_model = BlockModel(**data, organization_id=actor.organization_id)
            await block_model.create_async(session, actor=actor, no_commit=True, no_refresh=True)
            pydantic_block = block_model.to_pydantic()
            await session.commit()
            return pydantic_block

    @enforce_types
    @trace_method
    async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
        """
        Batch-create multiple Blocks in one transaction for better performance.

        Every block is validated before any is created, so one invalid block aborts the whole batch.

        Args:
            blocks: List of PydanticBlock schemas to create
            actor: The user performing the operation

        Returns:
            List of created PydanticBlock instances (with IDs, timestamps, etc.)

        Raises:
            LettaInvalidArgumentError: If any block's value exceeds its limit.
        """
        if not blocks:
            return []

        # Serialize each block once and reuse the dict for both validation and model construction.
        block_payloads = [block.model_dump(to_orm=True, exclude_none=True) for block in blocks]
        for payload in block_payloads:
            validate_block_creation(payload)

        async with db_registry.async_session() as session:
            block_models = [BlockModel(**payload, organization_id=actor.organization_id) for payload in block_payloads]
            created_models = await BlockModel.batch_create_async(
                items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
            )
            result = [m.to_pydantic() for m in created_models]
            await session.commit()
            return result

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
        """Update a block by its ID with the fields explicitly set on ``block_update``.

        Args:
            block_id: ID of the block to update.
            block_update: Partial update; unset and None fields are ignored.
            actor: User performing the update.

        Returns:
            The updated block.

        Raises:
            LettaInvalidArgumentError: If the update violates the block's limit constraint.
        """
        async with db_registry.async_session() as session:
            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
            update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

            # Validate limit constraints before updating
            validate_block_limit_constraint(update_data, block)

            for key, value in update_data.items():
                setattr(block, key, value)

            await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            pydantic_block = block.to_pydantic()
            await session.commit()
            return pydantic_block

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None:
        """Delete a block by its ID, first removing its rows from the ``blocks_agents`` table.

        Args:
            block_id: ID of the block to delete.
            actor: User performing the deletion.
        """
        async with db_registry.async_session() as session:
            # First, delete all references in blocks_agents table.
            # NOTE(review): this runs before the block is read, so the association rows are already
            # gone when the ORM object is loaded; keep this order unless BlockModel.agents loading
            # behaviour is confirmed to tolerate reading the block first.
            await session.execute(delete(BlocksAgents).where(BlocksAgents.block_id == block_id))
            await session.flush()

            # Then delete the block itself
            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
            await block.hard_delete_async(db_session=session, actor=actor)

    @staticmethod
    async def _apply_cursor_pagination(
        session: AsyncSession,
        query,
        model,
        before: Optional[str],
        after: Optional[str],
        ascending: bool,
    ):
        """Restrict ``query`` to rows strictly after ``after`` / before ``before`` in ``(created_at, id)`` order.

        Each cursor is the ID of a ``model`` row; a cursor ID that matches no row is ignored.
        The predicate is a proper keyset comparison,
        ``created_at > c OR (created_at == c AND id > cursor_id)`` (mirrored for the other
        direction), so rows with a later timestamp are kept regardless of how their IDs sort.

        Args:
            session: Open async session used to look up the cursor rows.
            query: SELECT over ``model`` to restrict.
            model: ORM class exposing ``created_at`` and ``id`` columns.
            before: Optional cursor ID; keep only rows ordered before it.
            after: Optional cursor ID; keep only rows ordered after it.
            ascending: Sort direction the caller will apply.

        Returns:
            The restricted query.
        """
        for cursor_id, is_after in ((after, True), (before, False)):
            if not cursor_id:
                continue
            row = (await session.execute(select(model.created_at, model.id).where(model.id == cursor_id))).first()
            if not row:
                continue
            sort_value, row_id = row
            # SQLite does not support as granular timestamping, so we need to round the timestamp
            if settings.database_engine is DatabaseChoice.SQLITE and isinstance(sort_value, datetime):
                sort_value = sort_value.strftime("%Y-%m-%d %H:%M:%S")

            # "after" in ascending order and "before" in descending order both select rows that sort higher.
            if is_after == ascending:
                query = query.where(or_(model.created_at > sort_value, and_(model.created_at == sort_value, model.id > row_id)))
            else:
                query = query.where(or_(model.created_at < sort_value, and_(model.created_at == sort_value, model.id < row_id)))
        return query

    @staticmethod
    def _apply_agent_count_filters(
        query,
        count_gt: Optional[int],
        count_lt: Optional[int],
        count_eq: Optional[List[int]],
    ):
        """Restrict a ``BlockModel`` query by the number of agents each block is attached to.

        Returns ``query`` unchanged when all three bounds are None. Per-block counts come from a
        subquery grouped by ``block_id``, so joining it adds at most one row per block.
        """
        if count_gt is None and count_lt is None and count_eq is None:
            return query

        # Create a subquery to count agents per block
        agent_count_subquery = (
            select(BlocksAgents.block_id, func.count(BlocksAgents.agent_id).label("agent_count"))
            .group_by(BlocksAgents.block_id)
            .subquery()
        )

        # Blocks with no agents have no row in the subquery, so a LEFT join is needed whenever
        # a bound can be satisfied by a count of 0.
        needs_left_join = count_lt is not None or (count_eq is not None and 0 in count_eq)
        if needs_left_join:
            query = query.outerjoin(agent_count_subquery, BlockModel.id == agent_count_subquery.c.block_id)
            # Use coalesce to treat NULL as 0 for blocks with no agents
            agent_count_expr = func.coalesce(agent_count_subquery.c.agent_count, 0)
        else:
            # Inner join since we don't need blocks with no agents
            query = query.join(agent_count_subquery, BlockModel.id == agent_count_subquery.c.block_id)
            agent_count_expr = agent_count_subquery.c.agent_count

        conditions = []
        if count_gt is not None:
            conditions.append(agent_count_expr > count_gt)
        if count_lt is not None:
            conditions.append(agent_count_expr < count_lt)
        if count_eq is not None:
            conditions.append(agent_count_expr.in_(count_eq))
        # Apply all conditions with AND logic
        return query.where(and_(*conditions))

    @enforce_types
    @trace_method
    async def get_blocks_async(
        self,
        actor: PydanticUser,
        label: Optional[str] = None,
        is_template: Optional[bool] = None,
        template_name: Optional[str] = None,
        identity_id: Optional[str] = None,
        identifier_keys: Optional[List[str]] = None,
        project_id: Optional[str] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        label_search: Optional[str] = None,
        description_search: Optional[str] = None,
        value_search: Optional[str] = None,
        connected_to_agents_count_gt: Optional[int] = None,
        connected_to_agents_count_lt: Optional[int] = None,
        connected_to_agents_count_eq: Optional[List[int]] = None,
        ascending: bool = True,
        show_hidden_blocks: Optional[bool] = None,
    ) -> List[PydanticBlock]:
        """Retrieve blocks in the actor's organization, narrowed by any of the optional filters.

        Args:
            actor: User whose organization and read access scope the query.
            label: Exact label match.
            is_template: Filter on the ``is_template`` flag when not None.
            template_name: Exact template-name match.
            identity_id: Keep blocks linked to the identity with this ID.
            identifier_keys: Keep blocks linked to an identity whose ``identifier_key`` is in this list.
            project_id: Exact project match.
            before: Cursor block ID; return blocks ordered before it.
            after: Cursor block ID; return blocks ordered after it.
            limit: Maximum number of blocks to return (falsy means no limit).
            label_search: Case-insensitive substring match on label (ignored when ``label`` is set).
            description_search: Case-insensitive substring match on description.
            value_search: Case-insensitive substring match on value.
            connected_to_agents_count_gt: Keep blocks attached to more than this many agents.
            connected_to_agents_count_lt: Keep blocks attached to fewer than this many agents.
            connected_to_agents_count_eq: Keep blocks whose agent count is one of these values.
            ascending: Sort by ``(created_at, id)`` ascending when True, descending otherwise.
            show_hidden_blocks: Include blocks whose ``hidden`` flag is set when truthy.

        Returns:
            Matching blocks ordered by ``(created_at, id)``.
        """
        from sqlalchemy.orm import noload

        from letta.orm.sqlalchemy_base import AccessType

        async with db_registry.async_session() as session:
            query = select(BlockModel)

            # Explicitly avoid loading relationships
            query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))

            # Apply access control
            query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)

            # Add filters
            query = query.where(BlockModel.organization_id == actor.organization_id)
            if label:
                query = query.where(BlockModel.label == label)
            if is_template is not None:
                query = query.where(BlockModel.is_template == is_template)
            if template_name:
                query = query.where(BlockModel.template_name == template_name)
            if project_id:
                query = query.where(BlockModel.project_id == project_id)
            if label_search and not label:
                query = query.where(BlockModel.label.ilike(f"%{label_search}%"))
            if description_search:
                query = query.where(BlockModel.description.ilike(f"%{description_search}%"))
            if value_search:
                query = query.where(BlockModel.value.ilike(f"%{value_search}%"))

            # Apply hidden filter
            if not show_hidden_blocks:
                query = query.where((BlockModel.hidden.is_(None)) | (BlockModel.hidden == False))  # noqa: E712

            query = self._apply_agent_count_filters(
                query, connected_to_agents_count_gt, connected_to_agents_count_lt, connected_to_agents_count_eq
            )

            # Identity filters use EXISTS sub-selects (relationship ``.any()``) instead of JOINs, so a block
            # linked to several matching identities still appears once. No DISTINCT is needed, and
            # the (created_at, id) ordering that the cursors rely on can always be used.
            identity_model = BlockModel.identities.property.mapper.class_
            if identifier_keys:
                query = query.where(BlockModel.identities.any(identity_model.identifier_key.in_(identifier_keys)))
            if identity_id:
                query = query.where(BlockModel.identities.any(identity_model.id == identity_id))

            query = await self._apply_cursor_pagination(session, query, BlockModel, before=before, after=after, ascending=ascending)

            if ascending:
                query = query.order_by(BlockModel.created_at.asc(), BlockModel.id.asc())
            else:
                query = query.order_by(BlockModel.created_at.desc(), BlockModel.id.desc())

            if limit:
                query = query.limit(limit)

            result = await session.execute(query)
            return [block.to_pydantic() for block in result.scalars().all()]

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
        """Retrieve a block by its ID.

        Args:
            block_id: ID of the block to fetch.
            actor: Optional user used for access scoping.

        Returns:
            The block, or None if no matching row is found.
        """
        async with db_registry.async_session() as session:
            try:
                block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
                return block.to_pydantic()
            except NoResultFound:
                return None

    @enforce_types
    @trace_method
    async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
        """Retrieve blocks by their IDs without loading unnecessary relationships.

        Args:
            block_ids: IDs to fetch.
            actor: Optional user; when given, organization read access is enforced.

        Returns:
            If every ID resolved, the blocks in database order. Otherwise, for backward
            compatibility, a list aligned with ``block_ids`` that holds None for each missing block.
        """
        from sqlalchemy.orm import noload

        from letta.orm.sqlalchemy_base import AccessType

        if not block_ids:
            return []

        async with db_registry.async_session() as session:
            query = (
                select(BlockModel)
                .where(BlockModel.id.in_(block_ids))
                # Explicitly avoid loading relationships
                .options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
            )

            # Apply access control if actor is provided
            if actor:
                query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)

            # TODO: Add soft delete filter if applicable
            # if hasattr(BlockModel, "is_deleted"):
            #     query = query.where(BlockModel.is_deleted == False)

            result = await session.execute(query)
            pydantic_blocks = [block.to_pydantic() for block in result.scalars().all()]

        # For backward compatibility, add None for missing blocks (one dict lookup per requested ID).
        if len(pydantic_blocks) < len(block_ids):
            blocks_by_id = {block.id: block for block in pydantic_blocks}
            return [blocks_by_id.get(block_id) for block_id in block_ids]

        return pydantic_blocks

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def get_agents_for_block_async(
        self,
        block_id: str,
        actor: PydanticUser,
        include_relationships: Optional[List[str]] = None,
        include: Optional[List[str]] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        ascending: bool = True,
    ) -> List[PydanticAgentState]:
        """
        Retrieve all agents associated with a given block with pagination support.

        Args:
            block_id: ID of the block to get agents for
            actor: User performing the operation
            include_relationships: List of relationships to include in the response
            include: Extra fields forwarded to ``to_pydantic_async`` (None means an empty list)
            before: Cursor for pagination (get items before this ID)
            after: Cursor for pagination (get items after this ID)
            limit: Maximum number of items to return
            ascending: Sort order (True for ascending, False for descending)

        Returns:
            List of agent states associated with the block
        """
        # Avoid a shared mutable default; callers omitting ``include`` get a fresh empty list.
        include = [] if include is None else include

        async with db_registry.async_session() as session:
            query = (
                select(AgentModel)
                .where(AgentModel.id.in_(select(BlocksAgents.agent_id).where(BlocksAgents.block_id == block_id)))
                .where(AgentModel.organization_id == actor.organization_id)
            )

            # Apply pagination using cursor-based approach
            query = await self._apply_cursor_pagination(session, query, AgentModel, before=before, after=after, ascending=ascending)

            if ascending:
                query = query.order_by(AgentModel.created_at.asc(), AgentModel.id.asc())
            else:
                query = query.order_by(AgentModel.created_at.desc(), AgentModel.id.desc())

            if limit:
                query = query.limit(limit)

            result = await session.execute(query)
            agents_orm = result.scalars().all()

            agents = await asyncio.gather(
                *[agent.to_pydantic_async(include_relationships=include_relationships, include=include) for agent in agents_orm]
            )
            return agents

    @enforce_types
    @trace_method
    async def size_async(self, actor: PydanticUser) -> int:
        """
        Get the total count of blocks for the given user.
        """
        async with db_registry.async_session() as session:
            return await BlockModel.size_async(db_session=session, actor=actor)

    # Block History Functions

    @enforce_types
    async def _move_block_to_sequence(self, session: AsyncSession, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel:
        """
        Internal helper that moves the 'block' to the specified 'target_seq' within BlockHistory.
        1) Find the BlockHistory row at sequence_number=target_seq
        2) Copy fields into the block
        3) Update and flush (no_commit=True) - the caller is responsible for final commit

        Raises:
            ValueError: if the block has no ID
            NoResultFound: if no BlockHistory row for (block_id, target_seq)
        """
        if not block.id:
            raise ValueError("Block is missing an ID. Cannot move sequence.")

        stmt = select(BlockHistory).filter(
            BlockHistory.block_id == block.id,
            BlockHistory.sequence_number == target_seq,
        )
        result = await session.execute(stmt)
        target_entry = result.scalar_one_or_none()
        if not target_entry:
            raise NoResultFound(f"No BlockHistory row found for block_id={block.id} at sequence={target_seq}")

        # Copy fields from target_entry to block
        block.description = target_entry.description  # type: ignore
        block.label = target_entry.label  # type: ignore
        block.value = target_entry.value  # type: ignore
        block.limit = target_entry.limit  # type: ignore
        block.metadata_ = target_entry.metadata_  # type: ignore
        block.current_history_entry_id = target_entry.id  # type: ignore

        # Update in DB (optimistic locking).
        # We'll do a flush now; the caller does final commit.
        updated_block = await block.update_async(db_session=session, actor=actor, no_commit=True)
        return updated_block

    @enforce_types
    @trace_method
    async def bulk_update_block_values_async(
        self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
    ) -> Optional[List[PydanticBlock]]:
        """
        Bulk-update the `value` field for multiple blocks in one transaction.

        IDs that are not found in the actor's organization are skipped with a warning, and values
        longer than their block's limit are truncated to the limit with a warning.

        Args:
            updates: mapping of block_id -> new value
            actor: the user performing the update (for org scoping, permissions, audit)
            return_hydrated: whether to return the pydantic Block objects that were updated
                (not implemented yet for the async path; None is always returned)

        Returns:
            None
        """
        if not updates:
            return None

        async with db_registry.async_session() as session:
            query = select(BlockModel).where(BlockModel.id.in_(list(updates)), BlockModel.organization_id == actor.organization_id)
            result = await session.execute(query)
            blocks = result.scalars().all()

            found_ids = {b.id for b in blocks}
            missing = set(updates) - found_ids
            if missing:
                logger.warning(f"Block IDs not found or inaccessible, skipping during bulk update: {missing!r}")

            for block in blocks:
                new_val = updates[block.id]
                # Guard against a missing limit instead of failing on ``len(...) > None``.
                if block.limit is not None and len(new_val) > block.limit:
                    logger.warning(f"Value length ({len(new_val)}) exceeds limit ({block.limit}) for block {block.id!r}, truncating...")
                    new_val = new_val[: block.limit]
                block.value = new_val

            await session.commit()

        if return_hydrated:
            # TODO: implement for async
            pass

        return None

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
    @trace_method
    async def checkpoint_block_async(
        self,
        block_id: str,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        use_preloaded_block: Optional[BlockModel] = None,  # For concurrency tests
    ) -> PydanticBlock:
        """
        Create a new checkpoint for the given Block by copying its
        current state into BlockHistory, using SQLAlchemy's built-in
        version_id_col for concurrency checks.

        - If the block was undone to an earlier checkpoint, we remove
          any "future" checkpoints beyond the current state to keep a
          strictly linear history.
        - A single commit at the end ensures atomicity.

        Args:
            block_id: ID of the block to checkpoint.
            actor: User performing the checkpoint; recorded as the author unless ``agent_id`` is given.
            agent_id: Optional agent recorded as the author of the checkpoint.
            use_preloaded_block: Optional block instance merged into the session instead of re-reading it.

        Returns:
            The block after its history pointer has been moved to the new checkpoint.
        """
        async with db_registry.async_session() as session:
            # 1) Load the Block
            if use_preloaded_block is not None:
                block = await session.merge(use_preloaded_block)
            else:
                block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)

            # 2) Identify the block's current checkpoint (if any)
            current_entry = None
            if block.current_history_entry_id:
                current_entry = await session.get(BlockHistory, block.current_history_entry_id)

            # The current sequence, or 0 if no checkpoints exist
            current_seq = current_entry.sequence_number if current_entry else 0

            # 3) Truncate any future checkpoints
            # If we are at seq=2, but there's a seq=3 or higher from a prior "redo chain",
            # remove those, so we maintain a strictly linear undo/redo stack.
            stmt = select(BlockHistory).filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq)
            result = await session.execute(stmt)
            for entry in result.scalars().all():
                # AsyncSession.delete is a coroutine; without ``await`` the row is never marked for deletion.
                await session.delete(entry)

            # Flush the deletes to ensure they're executed before we create a new entry
            await session.flush()

            # 4) Determine the next sequence number
            next_seq = current_seq + 1

            # 5) Create a new BlockHistory row reflecting the block's current state
            history_entry = BlockHistory(
                organization_id=actor.organization_id,
                block_id=block.id,
                sequence_number=next_seq,
                description=block.description,
                label=block.label,
                value=block.value,
                limit=block.limit,
                metadata_=block.metadata_,
                actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER,
                actor_id=agent_id if agent_id else actor.id,
            )
            await history_entry.create_async(session, actor=actor, no_commit=True)

            # 6) Update the block's pointer to the new checkpoint
            block.current_history_entry_id = history_entry.id

            # 7) Flush changes, then commit once
            block = await block.update_async(db_session=session, actor=actor, no_commit=True)
            await session.commit()

            return block.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def undo_checkpoint_block(
        self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
    ) -> PydanticBlock:
        """
        Move the block to the immediately previous checkpoint in BlockHistory.
        If older sequences have been pruned, we jump to the largest sequence
        number that is still < current_seq.

        Raises:
            LettaInvalidArgumentError: if the block has no history entry or is already at the earliest checkpoint.
            NoResultFound: if the block's current history entry row cannot be found.
        """
        async with db_registry.async_session() as session:
            # 1) Load the current block
            block = (
                await session.merge(use_preloaded_block)
                if use_preloaded_block
                else await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
            )

            if not block.current_history_entry_id:
                raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot undo.", argument_name="block_id")

            current_entry = await session.get(BlockHistory, block.current_history_entry_id)
            if not current_entry:
                raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}")

            current_seq = current_entry.sequence_number

            # 2) Find the largest sequence < current_seq
            stmt = (
                select(BlockHistory)
                .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number < current_seq)
                .order_by(BlockHistory.sequence_number.desc())
                .limit(1)
            )
            result = await session.execute(stmt)
            previous_entry = result.scalar_one_or_none()
            if not previous_entry:
                # No earlier checkpoint available
                raise LettaInvalidArgumentError(
                    f"Block {block_id} is already at the earliest checkpoint (seq={current_seq}). Cannot undo further.",
                    argument_name="block_id",
                )

            # 3) Move to that sequence
            block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor)

            # 4) Commit
            await session.commit()
            return block.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def redo_checkpoint_block(
        self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
    ) -> PydanticBlock:
        """
        Move the block to the next checkpoint if it exists.
        If some middle checkpoints have been pruned, we jump to the smallest
        sequence > current_seq that remains.

        Raises:
            LettaInvalidArgumentError: if the block has no (resolvable) history entry or is already at the latest checkpoint.
        """
        async with db_registry.async_session() as session:
            block = (
                await session.merge(use_preloaded_block)
                if use_preloaded_block
                else await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
            )

            if not block.current_history_entry_id:
                raise LettaInvalidArgumentError(f"Block {block_id} has no history entry - cannot redo.", argument_name="block_id")

            current_entry = await session.get(BlockHistory, block.current_history_entry_id)
            if not current_entry:
                raise LettaInvalidArgumentError(
                    f"BlockHistory row not found for id={block.current_history_entry_id}", argument_name="block_id"
                )

            current_seq = current_entry.sequence_number

            # Find the smallest sequence that is > current_seq
            stmt = (
                select(BlockHistory)
                .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq)
                .order_by(BlockHistory.sequence_number.asc())
                .limit(1)
            )
            result = await session.execute(stmt)
            next_entry = result.scalar_one_or_none()
            if not next_entry:
                raise LettaInvalidArgumentError(
                    f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.", argument_name="block_id"
                )

            block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor)

            await session.commit()
            return block.to_pydantic()
|