fix: Update block label also updates the BlocksAgents table (#2106)

This commit is contained in:
Matthew Zhou
2024-11-26 10:21:30 -08:00
committed by GitHub
parent 8711e1dc00
commit 4d9b4eef9d
7 changed files with 63 additions and 6 deletions

View File

@@ -10,7 +10,7 @@ from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import Human, Persona
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm import BlocksAgents, Organization
class Block(OrganizationMixin, SqlalchemyBase):
@@ -35,6 +35,7 @@ class Block(OrganizationMixin, SqlalchemyBase):
# relationships
organization: Mapped[Optional["Organization"]] = relationship("Organization")
blocks_agents: Mapped[list["BlocksAgents"]] = relationship("BlocksAgents", back_populates="block", cascade="all, delete")
def to_pydantic(self) -> Type:
match self.label:

View File

@@ -1,5 +1,5 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
@@ -27,3 +27,6 @@ class BlocksAgents(SqlalchemyBase):
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
block_id: Mapped[str] = mapped_column(String, primary_key=True)
block_label: Mapped[str] = mapped_column(String, primary_key=True)
# relationships
block: Mapped["Block"] = relationship("Block", back_populates="blocks_agents")

View File

@@ -180,6 +180,19 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
"""Handle database errors and raise appropriate custom exceptions."""
orig = e.orig # Extract the original error from the DBAPIError
error_code = None
error_message = str(orig) if orig else str(e)
logger.info(f"Handling DBAPIError: {error_message}")
# Handle SQLite-specific errors
if "UNIQUE constraint failed" in error_message:
raise UniqueConstraintViolationError(
f"A unique constraint was violated for {cls.__name__}. Check your input for duplicates: {e}"
) from e
if "FOREIGN KEY constraint failed" in error_message:
raise ForeignKeyConstraintViolationError(
f"A foreign key constraint was violated for {cls.__name__}. Check your input for missing or invalid references: {e}"
) from e
# For psycopg2
if hasattr(orig, "pgcode"):

View File

@@ -30,7 +30,7 @@ class BaseBlock(LettaBase, validate_assignment=True):
@model_validator(mode="after")
def verify_char_limit(self) -> Self:
if len(self.value) > self.limit:
if self.value and len(self.value) > self.limit:
error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}."
raise ValueError(error_msg)

View File

@@ -7,6 +7,7 @@ from letta.schemas.block import Block
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.utils import enforce_types, list_human_files, list_persona_files
@@ -38,13 +39,28 @@ class BlockManager:
@enforce_types
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
blocks_agents_manager = BlocksAgentsManager()
agent_ids = []
if block_update.label:
agent_ids = blocks_agents_manager.list_agent_ids_with_block(block_id=block_id)
for agent_id in agent_ids:
blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id)
with self.session_maker() as session:
# Update block
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(block, key, value)
block.update(db_session=session, actor=actor)
return block.to_pydantic()
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
if block_update.label:
for agent_id in agent_ids:
blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block_update.label)
return block.to_pydantic()
@enforce_types
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:

View File

@@ -71,11 +71,18 @@ class BlocksAgentsManager:
@enforce_types
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
"""List all blocks associated with a specific agent."""
"""List all block ids associated with a specific agent."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.block_id for record in blocks_agents_record]
@enforce_types
def list_block_labels_for_agent(self, agent_id: str) -> List[str]:
"""List all block labels associated with a specific agent."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.block_label for record in blocks_agents_record]
@enforce_types
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
"""List all agents associated with a specific block."""

View File

@@ -925,7 +925,6 @@ def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user):
# Assertions
assert e2b_config.timeout == 5 * 60
assert e2b_config.template
assert e2b_config.template == tool_settings.e2b_sandbox_template_id
@@ -1063,6 +1062,24 @@ def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
assert block_association.block_label == default_block.label
def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agent, default_user, default_block):
# Add the block
block_association = server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label
)
assert block_association.block_label == default_block.label
# Change the block label
new_label = "banana"
block = server.block_manager.update_block(block_id=default_block.id, block_update=BlockUpdate(label=new_label), actor=default_user)
assert block.label == new_label
# Get the association
labels = server.blocks_agents_manager.list_block_labels_for_agent(agent_id=sarah_agent.id)
assert new_label in labels
assert default_block.label not in labels
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.blocks_agents_manager.add_block_to_agent(