feat: add identities to blocks (#1219)
This commit is contained in:
38
alembic/versions/167491cfb7a8_add_identities_for_blocks.py
Normal file
38
alembic/versions/167491cfb7a8_add_identities_for_blocks.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""add identities for blocks
|
||||
|
||||
Revision ID: 167491cfb7a8
|
||||
Revises: d211df879a5f
|
||||
Create Date: 2025-03-07 17:51:24.843275
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "167491cfb7a8"
|
||||
down_revision: Union[str, None] = "d211df879a5f"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"identities_blocks",
|
||||
sa.Column("identity_id", sa.String(), nullable=False),
|
||||
sa.Column("block_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("identity_id", "block_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("identities_blocks")
|
||||
# ### end Alembic commands ###
|
||||
@@ -5,6 +5,7 @@ from letta.orm.block import Block
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.identities_agents import IdentitiesAgents
|
||||
from letta.orm.identities_blocks import IdentitiesBlocks
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.job_messages import JobMessage
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.schemas.block import Human, Persona
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import Organization
|
||||
from letta.orm.identity import Identity
|
||||
|
||||
|
||||
class Block(OrganizationMixin, SqlalchemyBase):
|
||||
@@ -47,6 +48,13 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
back_populates="core_memory",
|
||||
doc="Agents associated with this block.",
|
||||
)
|
||||
identities: Mapped[List["Identity"]] = relationship(
|
||||
"Identity",
|
||||
secondary="identities_blocks",
|
||||
lazy="selectin",
|
||||
back_populates="blocks",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> Type:
|
||||
match self.label:
|
||||
|
||||
13
letta/orm/identities_blocks.py
Normal file
13
letta/orm/identities_blocks.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class IdentitiesBlocks(Base):
|
||||
"""Identities may have one or many blocks associated with them."""
|
||||
|
||||
__tablename__ = "identities_blocks"
|
||||
|
||||
identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True)
|
||||
block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id", ondelete="CASCADE"), primary_key=True)
|
||||
@@ -40,12 +40,20 @@ class Identity(SqlalchemyBase, OrganizationMixin):
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||
)
|
||||
blocks: Mapped[List["Block"]] = relationship(
|
||||
"Block", secondary="identities_blocks", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||
)
|
||||
|
||||
@property
|
||||
def agent_ids(self) -> List[str]:
|
||||
"""Get just the agent IDs without loading the full agent objects"""
|
||||
return [agent.id for agent in self.agents]
|
||||
|
||||
@property
|
||||
def block_ids(self) -> List[str]:
|
||||
"""Get just the block IDs without loading the full agent objects"""
|
||||
return [block.id for block in self.blocks]
|
||||
|
||||
def to_pydantic(self) -> PydanticIdentity:
|
||||
state = {
|
||||
"id": self.id,
|
||||
@@ -54,6 +62,7 @@ class Identity(SqlalchemyBase, OrganizationMixin):
|
||||
"identity_type": self.identity_type,
|
||||
"project_id": self.project_id,
|
||||
"agent_ids": self.agent_ids,
|
||||
"block_ids": self.block_ids,
|
||||
"organization_id": self.organization_id,
|
||||
"properties": self.properties,
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
join_model: Optional[Base] = None,
|
||||
join_conditions: Optional[Union[Tuple, List]] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
identifier_id: Optional[str] = None,
|
||||
identity_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> List["SqlalchemyBase"]:
|
||||
"""
|
||||
@@ -148,9 +148,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if identifier_keys and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
|
||||
# given the identifier_id, we can find within the agents table any agents that have the identifier_id in their identity_ids
|
||||
if identifier_id and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identifier_id)
|
||||
# given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
|
||||
if identity_id and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
|
||||
|
||||
# Apply filtering logic from kwargs
|
||||
for key, value in kwargs.items():
|
||||
|
||||
@@ -46,6 +46,7 @@ class Identity(IdentityBase):
|
||||
identity_type: IdentityType = Field(..., description="The type of the identity.")
|
||||
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
|
||||
agent_ids: List[str] = Field(..., description="The IDs of the agents associated with the identity.")
|
||||
block_ids: List[str] = Field(..., description="The IDs of the blocks associated with the identity.")
|
||||
organization_id: Optional[str] = Field(None, description="The organization id of the user")
|
||||
properties: List[IdentityProperty] = Field(default_factory=list, description="List of properties associated with the identity")
|
||||
|
||||
@@ -56,6 +57,7 @@ class IdentityCreate(LettaBase):
|
||||
identity_type: IdentityType = Field(..., description="The type of the identity.")
|
||||
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
|
||||
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
|
||||
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.")
|
||||
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
|
||||
|
||||
|
||||
@@ -64,4 +66,5 @@ class IdentityUpdate(LettaBase):
|
||||
name: Optional[str] = Field(None, description="The name of the identity.")
|
||||
identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.")
|
||||
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
|
||||
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.")
|
||||
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
|
||||
|
||||
@@ -53,7 +53,7 @@ def list_agents(
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project id"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template id"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
|
||||
identifier_id: Optional[str] = Query(None, description="Search agents by identifier id"),
|
||||
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
):
|
||||
"""
|
||||
@@ -84,7 +84,7 @@ def list_agents(
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
identifier_keys=identifier_keys,
|
||||
identifier_id=identifier_id,
|
||||
identity_id=identity_id,
|
||||
**kwargs,
|
||||
)
|
||||
return agents
|
||||
|
||||
@@ -20,11 +20,15 @@ def list_blocks(
|
||||
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
|
||||
templates_only: bool = Query(True, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
|
||||
return server.block_manager.get_blocks(
|
||||
actor=actor, label=label, is_template=templates_only, template_name=name, identity_id=identity_id, identifier_keys=identifier_keys
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=Block, operation_id="create_block")
|
||||
|
||||
@@ -337,6 +337,7 @@ class AgentManager:
|
||||
match_all_tags: bool = False,
|
||||
query_text: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
identity_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
@@ -353,6 +354,7 @@ class AgentManager:
|
||||
organization_id=actor.organization_id if actor else None,
|
||||
query_text=query_text,
|
||||
identifier_keys=identifier_keys,
|
||||
identity_id=identity_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,6 +64,8 @@ class BlockManager:
|
||||
label: Optional[str] = None,
|
||||
is_template: Optional[bool] = None,
|
||||
template_name: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
identity_id: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
@@ -81,7 +83,14 @@ class BlockManager:
|
||||
if id:
|
||||
filters["id"] = id
|
||||
|
||||
blocks = BlockModel.list(db_session=session, after=after, limit=limit, **filters)
|
||||
blocks = BlockModel.list(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
identifier_keys=identifier_keys,
|
||||
identity_id=identity_id,
|
||||
**filters,
|
||||
)
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.identity import Identity as IdentityModel
|
||||
from letta.schemas.identity import Identity as PydanticIdentity
|
||||
from letta.schemas.identity import IdentityCreate, IdentityType, IdentityUpdate
|
||||
@@ -58,9 +59,24 @@ class IdentityManager:
|
||||
@enforce_types
|
||||
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
|
||||
with self.session_maker() as session:
|
||||
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids"}, exclude_unset=True))
|
||||
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
|
||||
new_identity.organization_id = actor.organization_id
|
||||
self._process_agent_relationship(session=session, identity=new_identity, agent_ids=identity.agent_ids, allow_partial=False)
|
||||
self._process_relationship(
|
||||
session=session,
|
||||
identity=new_identity,
|
||||
relationship_name="agents",
|
||||
model_class=AgentModel,
|
||||
item_ids=identity.agent_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
self._process_relationship(
|
||||
session=session,
|
||||
identity=new_identity,
|
||||
relationship_name="blocks",
|
||||
model_class=BlockModel,
|
||||
item_ids=identity.block_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
new_identity.create(session, actor=actor)
|
||||
return new_identity.to_pydantic()
|
||||
|
||||
@@ -124,9 +140,26 @@ class IdentityManager:
|
||||
new_properties = existing_identity.properties + [prop.model_dump() for prop in identity.properties]
|
||||
existing_identity.properties = new_properties
|
||||
|
||||
self._process_agent_relationship(
|
||||
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
|
||||
)
|
||||
if identity.agent_ids is not None:
|
||||
self._process_relationship(
|
||||
session=session,
|
||||
identity=existing_identity,
|
||||
relationship_name="agents",
|
||||
model_class=AgentModel,
|
||||
item_ids=identity.agent_ids,
|
||||
allow_partial=False,
|
||||
replace=replace,
|
||||
)
|
||||
if identity.block_ids is not None:
|
||||
self._process_relationship(
|
||||
session=session,
|
||||
identity=existing_identity,
|
||||
relationship_name="blocks",
|
||||
model_class=BlockModel,
|
||||
item_ids=identity.block_ids,
|
||||
allow_partial=False,
|
||||
replace=replace,
|
||||
)
|
||||
existing_identity.update(session, actor=actor)
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
@@ -141,26 +174,33 @@ class IdentityManager:
|
||||
session.delete(identity)
|
||||
session.commit()
|
||||
|
||||
def _process_agent_relationship(
|
||||
self, session: Session, identity: IdentityModel, agent_ids: List[str], allow_partial=False, replace=True
|
||||
def _process_relationship(
|
||||
self,
|
||||
session: Session,
|
||||
identity: PydanticIdentity,
|
||||
relationship_name: str,
|
||||
model_class,
|
||||
item_ids: List[str],
|
||||
allow_partial=False,
|
||||
replace=True,
|
||||
):
|
||||
current_relationship = getattr(identity, "agents", [])
|
||||
if not agent_ids:
|
||||
current_relationship = getattr(identity, relationship_name, [])
|
||||
if not item_ids:
|
||||
if replace:
|
||||
setattr(identity, "agents", [])
|
||||
setattr(identity, relationship_name, [])
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
|
||||
found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(agent_ids):
|
||||
missing = set(agent_ids) - {item.id for item in found_items}
|
||||
if not allow_partial and len(found_items) != len(item_ids):
|
||||
missing = set(item_ids) - {item.id for item in found_items}
|
||||
raise NoResultFound(f"Items not found in agents: {missing}")
|
||||
|
||||
if replace:
|
||||
# Replace the relationship
|
||||
setattr(identity, "agents", found_items)
|
||||
setattr(identity, relationship_name, found_items)
|
||||
else:
|
||||
# Extend the relationship (only add new items)
|
||||
current_ids = {item.id for item in current_relationship}
|
||||
|
||||
@@ -42,7 +42,6 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.settings import tool_settings
|
||||
from tests.helpers.utils import comprehensive_agent_checks
|
||||
@@ -2243,7 +2242,6 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
|
||||
|
||||
|
||||
def test_create_and_upsert_identity(server: SyncServer, default_user):
|
||||
identity_manager = IdentityManager()
|
||||
identity_create = IdentityCreate(
|
||||
identifier_key="1234",
|
||||
name="caren",
|
||||
@@ -2254,7 +2252,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
|
||||
],
|
||||
)
|
||||
|
||||
identity = identity_manager.create_identity(identity_create, actor=default_user)
|
||||
identity = server.identity_manager.create_identity(identity_create, actor=default_user)
|
||||
|
||||
# Assertions to ensure the created identity matches the expected values
|
||||
assert identity.identifier_key == identity_create.identifier_key
|
||||
@@ -2265,48 +2263,46 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
|
||||
assert identity.project_id == None
|
||||
|
||||
with pytest.raises(UniqueConstraintViolationError):
|
||||
identity_manager.create_identity(
|
||||
server.identity_manager.create_identity(
|
||||
IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))]
|
||||
|
||||
identity = identity_manager.upsert_identity(identity_create, actor=default_user)
|
||||
identity = server.identity_manager.upsert_identity(identity_create, actor=default_user)
|
||||
|
||||
identity = identity_manager.get_identity(identity_id=identity.id, actor=default_user)
|
||||
identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
|
||||
assert len(identity.properties) == 1
|
||||
assert identity.properties[0].key == "age"
|
||||
assert identity.properties[0].value == 29
|
||||
|
||||
identity_manager.delete_identity(identity.id, actor=default_user)
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_get_identities(server, default_user):
|
||||
identity_manager = IdentityManager()
|
||||
|
||||
# Create identities to retrieve later
|
||||
user = identity_manager.create_identity(
|
||||
user = server.identity_manager.create_identity(
|
||||
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
|
||||
)
|
||||
org = identity_manager.create_identity(
|
||||
org = server.identity_manager.create_identity(
|
||||
IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user
|
||||
)
|
||||
|
||||
# Retrieve identities by different filters
|
||||
all_identities = identity_manager.list_identities(actor=default_user)
|
||||
all_identities = server.identity_manager.list_identities(actor=default_user)
|
||||
assert len(all_identities) == 2
|
||||
|
||||
user_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
|
||||
user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
|
||||
assert len(user_identities) == 1
|
||||
assert user_identities[0].name == user.name
|
||||
|
||||
org_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
|
||||
org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
|
||||
assert len(org_identities) == 1
|
||||
assert org_identities[0].name == org.name
|
||||
|
||||
identity_manager.delete_identity(user.id, actor=default_user)
|
||||
identity_manager.delete_identity(org.id, actor=default_user)
|
||||
server.identity_manager.delete_identity(identity_id=user.id, actor=default_user)
|
||||
server.identity_manager.delete_identity(identity_id=org.id, actor=default_user)
|
||||
|
||||
|
||||
def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
@@ -2333,7 +2329,7 @@ def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default
|
||||
agent_state = server.agent_manager.get_agent_by_id(agent_id=charles_agent.id, actor=default_user)
|
||||
assert identity.id in agent_state.identity_ids
|
||||
|
||||
server.identity_manager.delete_identity(identity.id, actor=default_user)
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user):
|
||||
@@ -2360,29 +2356,137 @@ def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, defa
|
||||
assert not identity.id in agent_state.identity_ids
|
||||
|
||||
|
||||
def test_get_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
identity = server.identity_manager.create_identity(
|
||||
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get the agents for identity id
|
||||
agent_states = server.agent_manager.list_agents(identifier_id=identity.id, actor=default_user)
|
||||
assert len(agent_states) == 2
|
||||
agent_with_identity = server.create_agent(
|
||||
CreateAgent(
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
identity_ids=[identity.id],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
agent_without_identity = server.create_agent(
|
||||
CreateAgent(
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Check both agents are in the list
|
||||
# Get the agents for identity id
|
||||
agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user)
|
||||
assert len(agent_states) == 3
|
||||
|
||||
# Check all agents are in the list
|
||||
agent_state_ids = [a.id for a in agent_states]
|
||||
assert sarah_agent.id in agent_state_ids
|
||||
assert charles_agent.id in agent_state_ids
|
||||
assert agent_with_identity.id in agent_state_ids
|
||||
assert not agent_without_identity.id in agent_state_ids
|
||||
|
||||
# Get the agents for identifier key
|
||||
agent_states = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user)
|
||||
assert len(agent_states) == 2
|
||||
assert len(agent_states) == 3
|
||||
|
||||
# Check both agents are in the list
|
||||
# Check all agents are in the list
|
||||
agent_state_ids = [a.id for a in agent_states]
|
||||
assert sarah_agent.id in agent_state_ids
|
||||
assert charles_agent.id in agent_state_ids
|
||||
assert agent_with_identity.id in agent_state_ids
|
||||
assert not agent_without_identity.id in agent_state_ids
|
||||
|
||||
# Delete new agents
|
||||
server.agent_manager.delete_agent(agent_id=agent_with_identity.id, actor=default_user)
|
||||
server.agent_manager.delete_agent(agent_id=agent_without_identity.id, actor=default_user)
|
||||
|
||||
# Get the agents for identity id
|
||||
agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user)
|
||||
assert len(agent_states) == 2
|
||||
|
||||
# Check only initial agents are in the list
|
||||
agent_state_ids = [a.id for a in agent_states]
|
||||
assert sarah_agent.id in agent_state_ids
|
||||
assert charles_agent.id in agent_state_ids
|
||||
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user):
|
||||
# Create an identity
|
||||
identity = server.identity_manager.create_identity(
|
||||
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Check that identity has been attached
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 1 and blocks[0].id == default_block.id
|
||||
|
||||
# Now attempt to delete the identity
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
# Verify that the identity was deleted
|
||||
identities = server.identity_manager.list_identities(actor=default_user)
|
||||
assert len(identities) == 0
|
||||
|
||||
# Check that block has been detached too
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 0
|
||||
|
||||
|
||||
def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user):
|
||||
block_manager = BlockManager()
|
||||
block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)
|
||||
block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user)
|
||||
identity = server.identity_manager.create_identity(
|
||||
IdentityCreate(
|
||||
name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id, block_with_identity.id]
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get the blocks for identity id
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 2
|
||||
|
||||
# Check blocks are in the list
|
||||
block_ids = [b.id for b in blocks]
|
||||
assert default_block.id in block_ids
|
||||
assert block_with_identity.id in block_ids
|
||||
assert not block_without_identity.id in block_ids
|
||||
|
||||
# Get the blocks for identifier key
|
||||
blocks = server.block_manager.get_blocks(identifier_keys=[identity.identifier_key], actor=default_user)
|
||||
assert len(blocks) == 2
|
||||
|
||||
# Check blocks are in the list
|
||||
block_ids = [b.id for b in blocks]
|
||||
assert default_block.id in block_ids
|
||||
assert block_with_identity.id in block_ids
|
||||
assert not block_without_identity.id in block_ids
|
||||
|
||||
# Delete new agents
|
||||
server.block_manager.delete_block(block_id=block_with_identity.id, actor=default_user)
|
||||
server.block_manager.delete_block(block_id=block_without_identity.id, actor=default_user)
|
||||
|
||||
# Get the blocks for identity id
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 1
|
||||
|
||||
# Check only initial block in the list
|
||||
block_ids = [b.id for b in blocks]
|
||||
assert default_block.id in block_ids
|
||||
assert not block_with_identity.id in block_ids
|
||||
assert not block_without_identity.id in block_ids
|
||||
|
||||
server.identity_manager.delete_identity(identity.id, actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -380,6 +380,8 @@ def test_list_blocks(client, mock_sync_server):
|
||||
label=None,
|
||||
is_template=True,
|
||||
template_name=None,
|
||||
identity_id=None,
|
||||
identifier_keys=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user