Files
letta-server/letta/services/block_manager.py
2025-03-31 16:39:23 -07:00

202 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from typing import List, Optional
from sqlalchemy import func
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types, list_human_files, list_persona_files
class BlockManager:
    """Manager class to handle business logic related to Blocks.

    Each method opens its own session through ``self.session_maker`` and
    converts ORM rows to Pydantic schemas with ``to_pydantic()`` before
    returning them.
    """

    def __init__(self):
        # Fetching the db_context similarly as in ToolManager.
        # NOTE(review): the import is function-local, presumably to avoid an
        # import cycle with letta.server -- confirm before hoisting it.
        from letta.server.db import db_context

        self.session_maker = db_context

    @enforce_types
    def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
        """Create the block if its ID is unknown, otherwise update the stored block.

        Args:
            block: Desired block state; ``block.id`` selects the row to update.
            actor: User performing the operation. New blocks are created under
                ``actor.organization_id``.

        Returns:
            The persisted block, on both the create and the update path.
        """
        existing = self.get_block_by_id(block.id, actor)
        data = block.model_dump(to_orm=True, exclude_none=True)

        if existing:
            # Hand back what update_block persisted instead of dropping it, so
            # the update path honours the declared return type.
            return self.update_block(block.id, BlockUpdate(**data), actor)

        with self.session_maker() as session:
            # Use a separate name for the ORM row so the Pydantic ``block``
            # argument is not rebound to a different type.
            new_block = BlockModel(**data, organization_id=actor.organization_id)
            new_block.create(session, actor=actor)
            return new_block.to_pydantic()

    @enforce_types
    def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
        """Update a block by its ID with the given BlockUpdate object.

        Only the fields dumped with ``exclude_unset=True, exclude_none=True``
        are applied, so a field cannot be reset to ``None`` through this method.

        Args:
            block_id: ID of the block to update.
            block_update: Fields to change.
            actor: User performing the update.

        Returns:
            The block after the update.

        Raises:
            NoResultFound: Presumably propagated from ``BlockModel.read`` when
                the block does not exist (``get_block_by_id`` catches it from
                the same call).
        """
        with self.session_maker() as session:
            # Safety check for block: it must be readable by this actor.
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

            changes = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
            for field_name, new_value in changes.items():
                setattr(block, field_name, new_value)

            block.update(db_session=session, actor=actor)
            return block.to_pydantic()

    @enforce_types
    def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
        """Delete a block by its ID.

        Args:
            block_id: ID of the block to delete.
            actor: User performing the deletion.

        Returns:
            The state of the block that was deleted.
        """
        with self.session_maker() as session:
            # Pass the actor so this lookup is scoped the same way as every
            # other read in this manager (presumably enforcing actor access).
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
            block.hard_delete(db_session=session, actor=actor)
            return block.to_pydantic()

    @enforce_types
    def get_blocks(
        self,
        actor: PydanticUser,
        label: Optional[str] = None,
        is_template: Optional[bool] = None,
        template_name: Optional[str] = None,
        identifier_keys: Optional[List[str]] = None,
        identity_id: Optional[str] = None,
        id: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
    ) -> List[PydanticBlock]:
        """Retrieve blocks based on various optional filters.

        Results are always restricted to ``actor.organization_id``. ``label``,
        ``template_name`` and ``id`` are applied only when truthy;
        ``is_template`` is applied whenever it is not ``None`` (so ``False``
        is a real filter). ``after``, ``limit``, ``identifier_keys`` and
        ``identity_id`` are forwarded unchanged to ``BlockModel.list``.

        Returns:
            The matching blocks as Pydantic objects.
        """
        # Prepare filters; the organization scope is mandatory.
        filters = {"organization_id": actor.organization_id}
        if label:
            filters["label"] = label
        if is_template is not None:
            filters["is_template"] = is_template
        if template_name:
            filters["template_name"] = template_name
        if id:
            filters["id"] = id

        with self.session_maker() as session:
            blocks = BlockModel.list(
                db_session=session,
                after=after,
                limit=limit,
                identifier_keys=identifier_keys,
                identity_id=identity_id,
                **filters,
            )
            return [block.to_pydantic() for block in blocks]

    @enforce_types
    def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
        """Retrieve a block by its ID.

        Returns:
            The block, or ``None`` when ``BlockModel.read`` raises
            ``NoResultFound``.
        """
        with self.session_maker() as session:
            try:
                block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
                return block.to_pydantic()
            except NoResultFound:
                return None

    @enforce_types
    def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
        """Retrieve blocks by their IDs.

        Returns:
            The blocks that were found, followed by one ``None`` per requested
            ID that was not returned. The ``None`` entries are appended at the
            end, so positions do not necessarily line up with ``block_ids``.
        """
        with self.session_maker() as session:
            blocks = [
                row.to_pydantic()
                for row in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)
            ]
            # backwards compatibility. previous implementation added None for every block not found.
            blocks.extend([None] * (len(block_ids) - len(blocks)))
            return blocks

    @enforce_types
    def add_default_blocks(self, actor: PydanticUser):
        """Create or update a template block for each persona and human file.

        Persona files are processed first, then human files. The template name
        is the file's base name with every ``".txt"`` occurrence removed.
        """
        template_sources = (
            (list_persona_files, Persona),
            (list_human_files, Human),
        )
        for list_files, template_cls in template_sources:
            for path in list_files():
                with open(path, "r", encoding="utf-8") as f:
                    text = f.read()
                name = os.path.basename(path).replace(".txt", "")
                self.create_or_update_block(template_cls(template_name=name, value=text, is_template=True), actor=actor)

    @enforce_types
    def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
        """Retrieve all agents associated with a given block.

        Args:
            block_id: ID of the block whose agents are requested.
            actor: User performing the lookup.

        Returns:
            The agents in ``block.agents`` converted to Pydantic agent states.
        """
        with self.session_maker() as session:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
            # Convert inside the session so the relationship can still be loaded.
            return [agent.to_pydantic() for agent in block.agents]

    # Block History Functions

    @enforce_types
    def checkpoint_block(
        self,
        block_id: str,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        use_preloaded_block: Optional[BlockModel] = None,  # TODO: Useful for testing concurrency
    ) -> PydanticBlock:
        """Create a new checkpoint for the given Block.

        Copies the block's current state into BlockHistory, using SQLAlchemy's
        built-in version_id_col for concurrency checks.

        Note: We only have a single commit at the end, to avoid weird
        intermediate states, e.g. created a BlockHistory, but the block
        update failed.

        Args:
            block_id: ID of the block to checkpoint.
            actor: User performing the checkpoint.
            agent_id: When given, the history entry is attributed to this
                agent (``ActorType.LETTA_AGENT``); otherwise it is attributed
                to ``actor`` (``ActorType.LETTA_USER``).
            use_preloaded_block: If given, it is merged into the session and
                the block is not re-read from the DB.

        Returns:
            The block's new state after the checkpoint.
        """
        with self.session_maker() as session:
            # 1) Load the block via the ORM
            if use_preloaded_block is not None:
                block = session.merge(use_preloaded_block)
            else:
                block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

            # 2) Create a new sequence number for BlockHistory.
            # Filter on block.id (not the block_id argument) so the sequence is
            # computed for the same block the history row is written for, even
            # when a preloaded block is supplied.
            current_max_seq = (
                session.query(func.max(BlockHistory.sequence_number)).filter(BlockHistory.block_id == block.id).scalar()
            )
            next_seq = (current_max_seq or 0) + 1

            # 3) Create a snapshot in BlockHistory
            history_entry = BlockHistory(
                organization_id=actor.organization_id,
                block_id=block.id,
                sequence_number=next_seq,
                description=block.description,
                label=block.label,
                value=block.value,
                limit=block.limit,
                metadata_=block.metadata_,
                actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER,
                actor_id=agent_id if agent_id else actor.id,
            )
            history_entry.create(session, actor=actor, no_commit=True)

            # 4) Update the block's pointer
            block.current_history_entry_id = history_entry.id

            # 5) Stage the block update without committing. The concurrency
            # check is expected to come from the model's version_id_col (see
            # the docstring) -- TODO confirm on BlockModel.
            block = block.update(db_session=session, actor=actor, no_commit=True)

            # Single commit covering both the history entry and the block.
            session.commit()

            # Return the block's new state
            return block.to_pydantic()