From 6b4533e7cbcc6619b0c9ffa031db01cb394cdbcd Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 12 Mar 2025 12:09:31 -0700 Subject: [PATCH] feat: add identities to blocks (#1219) --- .../167491cfb7a8_add_identities_for_blocks.py | 38 +++++ letta/orm/__init__.py | 1 + letta/orm/block.py | 8 + letta/orm/identities_blocks.py | 13 ++ letta/orm/identity.py | 9 ++ letta/orm/sqlalchemy_base.py | 8 +- letta/schemas/identity.py | 3 + letta/server/rest_api/routers/v1/agents.py | 4 +- letta/server/rest_api/routers/v1/blocks.py | 6 +- letta/services/agent_manager.py | 2 + letta/services/block_manager.py | 11 +- letta/services/identity_manager.py | 68 ++++++-- tests/test_managers.py | 152 +++++++++++++++--- tests/test_v1_routes.py | 2 + 14 files changed, 279 insertions(+), 46 deletions(-) create mode 100644 alembic/versions/167491cfb7a8_add_identities_for_blocks.py create mode 100644 letta/orm/identities_blocks.py diff --git a/alembic/versions/167491cfb7a8_add_identities_for_blocks.py b/alembic/versions/167491cfb7a8_add_identities_for_blocks.py new file mode 100644 index 00000000..8e0b8a17 --- /dev/null +++ b/alembic/versions/167491cfb7a8_add_identities_for_blocks.py @@ -0,0 +1,38 @@ +"""add identities for blocks + +Revision ID: 167491cfb7a8 +Revises: d211df879a5f +Create Date: 2025-03-07 17:51:24.843275 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "167491cfb7a8" +down_revision: Union[str, None] = "d211df879a5f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "identities_blocks", + sa.Column("identity_id", sa.String(), nullable=False), + sa.Column("block_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("identity_id", "block_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("identities_blocks") + # ### end Alembic commands ### diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 10c25253..5963d36c 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -5,6 +5,7 @@ from letta.orm.block import Block from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata from letta.orm.identities_agents import IdentitiesAgents +from letta.orm.identities_blocks import IdentitiesBlocks from letta.orm.identity import Identity from letta.orm.job import Job from letta.orm.job_messages import JobMessage diff --git a/letta/orm/block.py b/letta/orm/block.py index 3e8c8006..940a8ec9 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -12,6 +12,7 @@ from letta.schemas.block import Human, Persona if TYPE_CHECKING: from letta.orm import Organization + from letta.orm.identity import Identity class Block(OrganizationMixin, SqlalchemyBase): @@ -47,6 +48,13 @@ class Block(OrganizationMixin, SqlalchemyBase): back_populates="core_memory", doc="Agents associated with this block.", ) + identities: Mapped[List["Identity"]] = relationship( + "Identity", + secondary="identities_blocks", + lazy="selectin", + back_populates="blocks", + passive_deletes=True, + ) def to_pydantic(self) -> Type: match self.label: diff --git a/letta/orm/identities_blocks.py b/letta/orm/identities_blocks.py new file mode 100644 index 00000000..2c5a8ef0 --- /dev/null +++ b/letta/orm/identities_blocks.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.base import Base + + +class IdentitiesBlocks(Base): + """Identities may have one or many blocks associated with them.""" + + __tablename__ = "identities_blocks" + + identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True) + block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id", ondelete="CASCADE"), primary_key=True) diff --git a/letta/orm/identity.py b/letta/orm/identity.py index f92b053b..5f6e5606 100644 --- a/letta/orm/identity.py +++ b/letta/orm/identity.py @@ -40,12 +40,20 @@ class Identity(SqlalchemyBase, OrganizationMixin): agents: Mapped[List["Agent"]] = relationship( "Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities" ) + blocks: Mapped[List["Block"]] = relationship( + "Block", secondary="identities_blocks", lazy="selectin", passive_deletes=True, back_populates="identities" + ) @property def agent_ids(self) -> List[str]: """Get just the agent IDs without loading the full agent objects""" return [agent.id for agent in self.agents] + @property + def block_ids(self) -> List[str]: + """Get just the block IDs without loading the full agent objects""" + return [block.id for block in self.blocks] + def to_pydantic(self) -> PydanticIdentity: state = { "id": self.id, @@ -54,6 +62,7 @@ class Identity(SqlalchemyBase, OrganizationMixin): "identity_type": self.identity_type, "project_id": self.project_id, "agent_ids": self.agent_ids, + "block_ids": self.block_ids, "organization_id": self.organization_id, "properties": self.properties, } diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 5652e7a4..11ac3070 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -69,7 +69,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): join_model: Optional[Base] = None, join_conditions: Optional[Union[Tuple, List]] = None, identifier_keys: Optional[List[str]] = None, - identifier_id: Optional[str] = None, + identity_id: Optional[str] = None, **kwargs, ) -> List["SqlalchemyBase"]: """ @@ -148,9 +148,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if identifier_keys and hasattr(cls, "identities"): query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) - # given the identifier_id, we can find within the agents table any agents that have the identifier_id in their identity_ids - if identifier_id and hasattr(cls, "identities"): - query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identifier_id) + # given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids + if identity_id and hasattr(cls, "identities"): + query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id) # Apply filtering logic from kwargs for key, value in kwargs.items(): diff --git a/letta/schemas/identity.py b/letta/schemas/identity.py index 5b44fa67..017b9e30 100644 --- a/letta/schemas/identity.py +++ b/letta/schemas/identity.py @@ -46,6 +46,7 @@ class Identity(IdentityBase): identity_type: IdentityType = Field(..., description="The type of the identity.") project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.") agent_ids: List[str] = Field(..., description="The IDs of the agents associated with the identity.") + block_ids: List[str] = Field(..., description="The IDs of the blocks associated with the identity.") organization_id: Optional[str] = Field(None, description="The organization id of the user") properties: List[IdentityProperty] = Field(default_factory=list, description="List of properties associated with the identity") @@ -56,6 +57,7 @@ class IdentityCreate(LettaBase): identity_type: IdentityType = Field(..., description="The type of the identity.") project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.") agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.") + block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.") properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.") @@ -64,4 +66,5 @@ class IdentityUpdate(LettaBase): name: Optional[str] = Field(None, description="The name of the identity.") identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.") agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.") + block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.") properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 5de351a2..0c0afa51 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -53,7 +53,7 @@ def list_agents( project_id: Optional[str] = Query(None, description="Search agents by project id"), template_id: Optional[str] = Query(None, description="Search agents by template id"), base_template_id: Optional[str] = Query(None, description="Search agents by base template id"), - identifier_id: Optional[str] = Query(None, description="Search agents by identifier id"), + identity_id: Optional[str] = Query(None, description="Search agents by identifier id"), identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), ): """ @@ -84,7 +84,7 @@ def list_agents( tags=tags, match_all_tags=match_all_tags, identifier_keys=identifier_keys, - identifier_id=identifier_id, + identity_id=identity_id, **kwargs, ) return agents diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 322f323d..22ff8e35 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -20,11 +20,15 @@ def list_blocks( label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"), templates_only: bool = Query(True, description="Whether to include only templates"), name: Optional[str] = Query(None, description="Name of the block"), + identity_id: Optional[str] = Query(None, description="Search agents by identifier id"), + identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name) + return server.block_manager.get_blocks( + actor=actor, label=label, is_template=templates_only, template_name=name, identity_id=identity_id, identifier_keys=identifier_keys + ) @router.post("/", response_model=Block, operation_id="create_block") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 12ef7790..8ea31fa7 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -337,6 +337,7 @@ class AgentManager: match_all_tags: bool = False, query_text: Optional[str] = None, identifier_keys: Optional[List[str]] = None, + identity_id: Optional[str] = None, **kwargs, ) -> List[PydanticAgentState]: """ @@ -353,6 +354,7 @@ class AgentManager: organization_id=actor.organization_id if actor else None, query_text=query_text, identifier_keys=identifier_keys, + identity_id=identity_id, **kwargs, ) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index fe10671d..ff9b8507 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -64,6 +64,8 @@ class BlockManager: label: Optional[str] = None, is_template: Optional[bool] = None, template_name: Optional[str] = None, + identifier_keys: Optional[List[str]] = None, + identity_id: Optional[str] = None, id: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 50, @@ -81,7 +83,14 @@ class BlockManager: if id: filters["id"] = id - blocks = BlockModel.list(db_session=session, after=after, limit=limit, **filters) + blocks = BlockModel.list( + db_session=session, + after=after, + limit=limit, + identifier_keys=identifier_keys, + identity_id=identity_id, + **filters, + ) return [block.to_pydantic() for block in blocks] diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 42efa191..8e965399 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session from letta.orm.agent import Agent as AgentModel +from letta.orm.block import Block as BlockModel from letta.orm.identity import Identity as IdentityModel from letta.schemas.identity import Identity as PydanticIdentity from letta.schemas.identity import IdentityCreate, IdentityType, IdentityUpdate @@ -58,9 +59,24 @@ class IdentityManager: @enforce_types def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: with self.session_maker() as session: - new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids"}, exclude_unset=True)) + new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True)) new_identity.organization_id = actor.organization_id - self._process_agent_relationship(session=session, identity=new_identity, agent_ids=identity.agent_ids, allow_partial=False) + self._process_relationship( + session=session, + identity=new_identity, + relationship_name="agents", + model_class=AgentModel, + item_ids=identity.agent_ids, + allow_partial=False, + ) + self._process_relationship( + session=session, + identity=new_identity, + relationship_name="blocks", + model_class=BlockModel, + item_ids=identity.block_ids, + allow_partial=False, + ) new_identity.create(session, actor=actor) return new_identity.to_pydantic() @@ -124,9 +140,26 @@ class IdentityManager: new_properties = existing_identity.properties + [prop.model_dump() for prop in identity.properties] existing_identity.properties = new_properties - self._process_agent_relationship( - session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace - ) + if identity.agent_ids is not None: + self._process_relationship( + session=session, + identity=existing_identity, + relationship_name="agents", + model_class=AgentModel, + item_ids=identity.agent_ids, + allow_partial=False, + replace=replace, + ) + if identity.block_ids is not None: + self._process_relationship( + session=session, + identity=existing_identity, + relationship_name="blocks", + model_class=BlockModel, + item_ids=identity.block_ids, + allow_partial=False, + replace=replace, + ) existing_identity.update(session, actor=actor) return existing_identity.to_pydantic() @@ -141,26 +174,33 @@ class IdentityManager: session.delete(identity) session.commit() - def _process_agent_relationship( - self, session: Session, identity: IdentityModel, agent_ids: List[str], allow_partial=False, replace=True + def _process_relationship( + self, + session: Session, + identity: PydanticIdentity, + relationship_name: str, + model_class, + item_ids: List[str], + allow_partial=False, + replace=True, ): - current_relationship = getattr(identity, "agents", []) - if not agent_ids: + current_relationship = getattr(identity, relationship_name, []) + if not item_ids: if replace: - setattr(identity, "agents", []) + setattr(identity, relationship_name, []) return # Retrieve models for the provided IDs - found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all() + found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all() # Validate all items are found if allow_partial is False - if not allow_partial and len(found_items) != len(agent_ids): - missing = set(agent_ids) - {item.id for item in found_items} + if not allow_partial and len(found_items) != len(item_ids): + missing = set(item_ids) - {item.id for item in found_items} raise NoResultFound(f"Items not found in agents: {missing}") if replace: # Replace the relationship - setattr(identity, "agents", found_items) + setattr(identity, relationship_name, found_items) else: # Extend the relationship (only add new items) current_ids = {item.id for item in current_relationship} diff --git a/tests/test_managers.py b/tests/test_managers.py index ed86aa01..0b43e924 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -42,7 +42,6 @@ from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate from letta.server.server import SyncServer from letta.services.block_manager import BlockManager -from letta.services.identity_manager import IdentityManager from letta.services.organization_manager import OrganizationManager from letta.settings import tool_settings from tests.helpers.utils import comprehensive_agent_checks @@ -2243,7 +2242,6 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de def test_create_and_upsert_identity(server: SyncServer, default_user): - identity_manager = IdentityManager() identity_create = IdentityCreate( identifier_key="1234", name="caren", @@ -2254,7 +2252,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user): ], ) - identity = identity_manager.create_identity(identity_create, actor=default_user) + identity = server.identity_manager.create_identity(identity_create, actor=default_user) # Assertions to ensure the created identity matches the expected values assert identity.identifier_key == identity_create.identifier_key @@ -2265,48 +2263,46 @@ def test_create_and_upsert_identity(server: SyncServer, default_user): assert identity.project_id == None with pytest.raises(UniqueConstraintViolationError): - identity_manager.create_identity( + server.identity_manager.create_identity( IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user), actor=default_user, ) identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))] - identity = identity_manager.upsert_identity(identity_create, actor=default_user) + identity = server.identity_manager.upsert_identity(identity_create, actor=default_user) - identity = identity_manager.get_identity(identity_id=identity.id, actor=default_user) + identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user) assert len(identity.properties) == 1 assert identity.properties[0].key == "age" assert identity.properties[0].value == 29 - identity_manager.delete_identity(identity.id, actor=default_user) + server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) def test_get_identities(server, default_user): - identity_manager = IdentityManager() - # Create identities to retrieve later - user = identity_manager.create_identity( + user = server.identity_manager.create_identity( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user ) - org = identity_manager.create_identity( + org = server.identity_manager.create_identity( IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user ) # Retrieve identities by different filters - all_identities = identity_manager.list_identities(actor=default_user) + all_identities = server.identity_manager.list_identities(actor=default_user) assert len(all_identities) == 2 - user_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user) + user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user) assert len(user_identities) == 1 assert user_identities[0].name == user.name - org_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org) + org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org) assert len(org_identities) == 1 assert org_identities[0].name == org.name - identity_manager.delete_identity(user.id, actor=default_user) - identity_manager.delete_identity(org.id, actor=default_user) + server.identity_manager.delete_identity(identity_id=user.id, actor=default_user) + server.identity_manager.delete_identity(identity_id=org.id, actor=default_user) def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user): @@ -2333,7 +2329,7 @@ def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default agent_state = server.agent_manager.get_agent_by_id(agent_id=charles_agent.id, actor=default_user) assert identity.id in agent_state.identity_ids - server.identity_manager.delete_identity(identity.id, actor=default_user) + server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user): @@ -2360,29 +2356,137 @@ def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, defa assert not identity.id in agent_state.identity_ids -def test_get_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user): +def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user): identity = server.identity_manager.create_identity( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]), actor=default_user, ) - # Get the agents for identity id - agent_states = server.agent_manager.list_agents(identifier_id=identity.id, actor=default_user) - assert len(agent_states) == 2 + agent_with_identity = server.create_agent( + CreateAgent( + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + identity_ids=[identity.id], + ), + actor=default_user, + ) + agent_without_identity = server.create_agent( + CreateAgent( + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) - # Check both agents are in the list + # Get the agents for identity id + agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user) + assert len(agent_states) == 3 + + # Check all agents are in the list agent_state_ids = [a.id for a in agent_states] assert sarah_agent.id in agent_state_ids assert charles_agent.id in agent_state_ids + assert agent_with_identity.id in agent_state_ids + assert not agent_without_identity.id in agent_state_ids # Get the agents for identifier key agent_states = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user) - assert len(agent_states) == 2 + assert len(agent_states) == 3 - # Check both agents are in the list + # Check all agents are in the list agent_state_ids = [a.id for a in agent_states] assert sarah_agent.id in agent_state_ids assert charles_agent.id in agent_state_ids + assert agent_with_identity.id in agent_state_ids + assert not agent_without_identity.id in agent_state_ids + + # Delete new agents + server.agent_manager.delete_agent(agent_id=agent_with_identity.id, actor=default_user) + server.agent_manager.delete_agent(agent_id=agent_without_identity.id, actor=default_user) + + # Get the agents for identity id + agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user) + assert len(agent_states) == 2 + + # Check only initial agents are in the list + agent_state_ids = [a.id for a in agent_states] + assert sarah_agent.id in agent_state_ids + assert charles_agent.id in agent_state_ids + + server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + + +def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user): + # Create an identity + identity = server.identity_manager.create_identity( + IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]), + actor=default_user, + ) + + # Check that identity has been attached + blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + assert len(blocks) == 1 and blocks[0].id == default_block.id + + # Now attempt to delete the identity + server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + + # Verify that the identity was deleted + identities = server.identity_manager.list_identities(actor=default_user) + assert len(identities) == 0 + + # Check that block has been detached too + blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + assert len(blocks) == 0 + + +def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user): + block_manager = BlockManager() + block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user) + block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user) + identity = server.identity_manager.create_identity( + IdentityCreate( + name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id, block_with_identity.id] + ), + actor=default_user, + ) + + # Get the blocks for identity id + blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + assert len(blocks) == 2 + + # Check blocks are in the list + block_ids = [b.id for b in blocks] + assert default_block.id in block_ids + assert block_with_identity.id in block_ids + assert not block_without_identity.id in block_ids + + # Get the blocks for identifier key + blocks = server.block_manager.get_blocks(identifier_keys=[identity.identifier_key], actor=default_user) + assert len(blocks) == 2 + + # Check blocks are in the list + block_ids = [b.id for b in blocks] + assert default_block.id in block_ids + assert block_with_identity.id in block_ids + assert not block_without_identity.id in block_ids + + # Delete new agents + server.block_manager.delete_block(block_id=block_with_identity.id, actor=default_user) + server.block_manager.delete_block(block_id=block_without_identity.id, actor=default_user) + + # Get the blocks for identity id + blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + assert len(blocks) == 1 + + # Check only initial block in the list + block_ids = [b.id for b in blocks] + assert default_block.id in block_ids + assert not block_with_identity.id in block_ids + assert not block_without_identity.id in block_ids + + server.identity_manager.delete_identity(identity.id, actor=default_user) # ====================================================================================================================== diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py index 989c3775..0f12ac4b 100644 --- a/tests/test_v1_routes.py +++ b/tests/test_v1_routes.py @@ -380,6 +380,8 @@ def test_list_blocks(client, mock_sync_server): label=None, is_template=True, template_name=None, + identity_id=None, + identifier_keys=None, )