feat: add identities to blocks (#1219)

This commit is contained in:
cthomas
2025-03-12 12:09:31 -07:00
committed by GitHub
parent eddd167f43
commit 6b4533e7cb
14 changed files with 279 additions and 46 deletions

View File

@@ -0,0 +1,38 @@
"""add identities for blocks
Revision ID: 167491cfb7a8
Revises: d211df879a5f
Create Date: 2025-03-07 17:51:24.843275
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "167491cfb7a8"
down_revision: Union[str, None] = "d211df879a5f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"identities_blocks",
sa.Column("identity_id", sa.String(), nullable=False),
sa.Column("block_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("identity_id", "block_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("identities_blocks")
# ### end Alembic commands ###

View File

@@ -5,6 +5,7 @@ from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.identities_agents import IdentitiesAgents
from letta.orm.identities_blocks import IdentitiesBlocks
from letta.orm.identity import Identity
from letta.orm.job import Job
from letta.orm.job_messages import JobMessage

View File

@@ -12,6 +12,7 @@ from letta.schemas.block import Human, Persona
if TYPE_CHECKING:
from letta.orm import Organization
from letta.orm.identity import Identity
class Block(OrganizationMixin, SqlalchemyBase):
@@ -47,6 +48,13 @@ class Block(OrganizationMixin, SqlalchemyBase):
back_populates="core_memory",
doc="Agents associated with this block.",
)
identities: Mapped[List["Identity"]] = relationship(
"Identity",
secondary="identities_blocks",
lazy="selectin",
back_populates="blocks",
passive_deletes=True,
)
def to_pydantic(self) -> Type:
match self.label:

View File

@@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class IdentitiesBlocks(Base):
"""Identities may have one or many blocks associated with them."""
__tablename__ = "identities_blocks"
identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True)
block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id", ondelete="CASCADE"), primary_key=True)

View File

@@ -40,12 +40,20 @@ class Identity(SqlalchemyBase, OrganizationMixin):
agents: Mapped[List["Agent"]] = relationship(
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
)
blocks: Mapped[List["Block"]] = relationship(
"Block", secondary="identities_blocks", lazy="selectin", passive_deletes=True, back_populates="identities"
)
@property
def agent_ids(self) -> List[str]:
"""Get just the agent IDs without loading the full agent objects"""
return [agent.id for agent in self.agents]
@property
def block_ids(self) -> List[str]:
"""Get just the block IDs without loading the full agent objects"""
return [block.id for block in self.blocks]
def to_pydantic(self) -> PydanticIdentity:
state = {
"id": self.id,
@@ -54,6 +62,7 @@ class Identity(SqlalchemyBase, OrganizationMixin):
"identity_type": self.identity_type,
"project_id": self.project_id,
"agent_ids": self.agent_ids,
"block_ids": self.block_ids,
"organization_id": self.organization_id,
"properties": self.properties,
}

View File

@@ -69,7 +69,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
join_model: Optional[Base] = None,
join_conditions: Optional[Union[Tuple, List]] = None,
identifier_keys: Optional[List[str]] = None,
identifier_id: Optional[str] = None,
identity_id: Optional[str] = None,
**kwargs,
) -> List["SqlalchemyBase"]:
"""
@@ -148,9 +148,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
# given the identifier_id, we can find within the agents table any agents that have the identifier_id in their identity_ids
if identifier_id and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identifier_id)
# given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
if identity_id and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
# Apply filtering logic from kwargs
for key, value in kwargs.items():

View File

@@ -46,6 +46,7 @@ class Identity(IdentityBase):
identity_type: IdentityType = Field(..., description="The type of the identity.")
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
agent_ids: List[str] = Field(..., description="The IDs of the agents associated with the identity.")
block_ids: List[str] = Field(..., description="The IDs of the blocks associated with the identity.")
organization_id: Optional[str] = Field(None, description="The organization id of the user")
properties: List[IdentityProperty] = Field(default_factory=list, description="List of properties associated with the identity")
@@ -56,6 +57,7 @@ class IdentityCreate(LettaBase):
identity_type: IdentityType = Field(..., description="The type of the identity.")
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.")
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
@@ -64,4 +66,5 @@ class IdentityUpdate(LettaBase):
name: Optional[str] = Field(None, description="The name of the identity.")
identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
block_ids: Optional[List[str]] = Field(None, description="The IDs of the blocks associated with the identity.")
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")

View File

@@ -53,7 +53,7 @@ def list_agents(
project_id: Optional[str] = Query(None, description="Search agents by project id"),
template_id: Optional[str] = Query(None, description="Search agents by template id"),
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
identifier_id: Optional[str] = Query(None, description="Search agents by identifier id"),
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
):
"""
@@ -84,7 +84,7 @@ def list_agents(
tags=tags,
match_all_tags=match_all_tags,
identifier_keys=identifier_keys,
identifier_id=identifier_id,
identity_id=identity_id,
**kwargs,
)
return agents

View File

@@ -20,11 +20,15 @@ def list_blocks(
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
return server.block_manager.get_blocks(
actor=actor, label=label, is_template=templates_only, template_name=name, identity_id=identity_id, identifier_keys=identifier_keys
)
@router.post("/", response_model=Block, operation_id="create_block")

View File

@@ -337,6 +337,7 @@ class AgentManager:
match_all_tags: bool = False,
query_text: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
identity_id: Optional[str] = None,
**kwargs,
) -> List[PydanticAgentState]:
"""
@@ -353,6 +354,7 @@ class AgentManager:
organization_id=actor.organization_id if actor else None,
query_text=query_text,
identifier_keys=identifier_keys,
identity_id=identity_id,
**kwargs,
)

View File

@@ -64,6 +64,8 @@ class BlockManager:
label: Optional[str] = None,
is_template: Optional[bool] = None,
template_name: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
identity_id: Optional[str] = None,
id: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
@@ -81,7 +83,14 @@ class BlockManager:
if id:
filters["id"] = id
blocks = BlockModel.list(db_session=session, after=after, limit=limit, **filters)
blocks = BlockModel.list(
db_session=session,
after=after,
limit=limit,
identifier_keys=identifier_keys,
identity_id=identity_id,
**filters,
)
return [block.to_pydantic() for block in blocks]

View File

@@ -5,6 +5,7 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block as BlockModel
from letta.orm.identity import Identity as IdentityModel
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityCreate, IdentityType, IdentityUpdate
@@ -58,9 +59,24 @@ class IdentityManager:
@enforce_types
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids"}, exclude_unset=True))
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
self._process_agent_relationship(session=session, identity=new_identity, agent_ids=identity.agent_ids, allow_partial=False)
self._process_relationship(
session=session,
identity=new_identity,
relationship_name="agents",
model_class=AgentModel,
item_ids=identity.agent_ids,
allow_partial=False,
)
self._process_relationship(
session=session,
identity=new_identity,
relationship_name="blocks",
model_class=BlockModel,
item_ids=identity.block_ids,
allow_partial=False,
)
new_identity.create(session, actor=actor)
return new_identity.to_pydantic()
@@ -124,9 +140,26 @@ class IdentityManager:
new_properties = existing_identity.properties + [prop.model_dump() for prop in identity.properties]
existing_identity.properties = new_properties
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
)
if identity.agent_ids is not None:
self._process_relationship(
session=session,
identity=existing_identity,
relationship_name="agents",
model_class=AgentModel,
item_ids=identity.agent_ids,
allow_partial=False,
replace=replace,
)
if identity.block_ids is not None:
self._process_relationship(
session=session,
identity=existing_identity,
relationship_name="blocks",
model_class=BlockModel,
item_ids=identity.block_ids,
allow_partial=False,
replace=replace,
)
existing_identity.update(session, actor=actor)
return existing_identity.to_pydantic()
@@ -141,26 +174,33 @@ class IdentityManager:
session.delete(identity)
session.commit()
def _process_agent_relationship(
self, session: Session, identity: IdentityModel, agent_ids: List[str], allow_partial=False, replace=True
def _process_relationship(
self,
session: Session,
identity: PydanticIdentity,
relationship_name: str,
model_class,
item_ids: List[str],
allow_partial=False,
replace=True,
):
current_relationship = getattr(identity, "agents", [])
if not agent_ids:
current_relationship = getattr(identity, relationship_name, [])
if not item_ids:
if replace:
setattr(identity, "agents", [])
setattr(identity, relationship_name, [])
return
# Retrieve models for the provided IDs
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
# Validate all items are found if allow_partial is False
if not allow_partial and len(found_items) != len(agent_ids):
missing = set(agent_ids) - {item.id for item in found_items}
if not allow_partial and len(found_items) != len(item_ids):
missing = set(item_ids) - {item.id for item in found_items}
raise NoResultFound(f"Items not found in agents: {missing}")
if replace:
# Replace the relationship
setattr(identity, "agents", found_items)
setattr(identity, relationship_name, found_items)
else:
# Extend the relationship (only add new items)
current_ids = {item.id for item in current_relationship}

View File

@@ -42,7 +42,6 @@ from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.identity_manager import IdentityManager
from letta.services.organization_manager import OrganizationManager
from letta.settings import tool_settings
from tests.helpers.utils import comprehensive_agent_checks
@@ -2243,7 +2242,6 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
def test_create_and_upsert_identity(server: SyncServer, default_user):
identity_manager = IdentityManager()
identity_create = IdentityCreate(
identifier_key="1234",
name="caren",
@@ -2254,7 +2252,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
],
)
identity = identity_manager.create_identity(identity_create, actor=default_user)
identity = server.identity_manager.create_identity(identity_create, actor=default_user)
# Assertions to ensure the created identity matches the expected values
assert identity.identifier_key == identity_create.identifier_key
@@ -2265,48 +2263,46 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
assert identity.project_id == None
with pytest.raises(UniqueConstraintViolationError):
identity_manager.create_identity(
server.identity_manager.create_identity(
IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user),
actor=default_user,
)
identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))]
identity = identity_manager.upsert_identity(identity_create, actor=default_user)
identity = server.identity_manager.upsert_identity(identity_create, actor=default_user)
identity = identity_manager.get_identity(identity_id=identity.id, actor=default_user)
identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
assert len(identity.properties) == 1
assert identity.properties[0].key == "age"
assert identity.properties[0].value == 29
identity_manager.delete_identity(identity.id, actor=default_user)
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
def test_get_identities(server, default_user):
identity_manager = IdentityManager()
# Create identities to retrieve later
user = identity_manager.create_identity(
user = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
)
org = identity_manager.create_identity(
org = server.identity_manager.create_identity(
IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user
)
# Retrieve identities by different filters
all_identities = identity_manager.list_identities(actor=default_user)
all_identities = server.identity_manager.list_identities(actor=default_user)
assert len(all_identities) == 2
user_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
assert len(user_identities) == 1
assert user_identities[0].name == user.name
org_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
assert len(org_identities) == 1
assert org_identities[0].name == org.name
identity_manager.delete_identity(user.id, actor=default_user)
identity_manager.delete_identity(org.id, actor=default_user)
server.identity_manager.delete_identity(identity_id=user.id, actor=default_user)
server.identity_manager.delete_identity(identity_id=org.id, actor=default_user)
def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user):
@@ -2333,7 +2329,7 @@ def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default
agent_state = server.agent_manager.get_agent_by_id(agent_id=charles_agent.id, actor=default_user)
assert identity.id in agent_state.identity_ids
server.identity_manager.delete_identity(identity.id, actor=default_user)
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user):
@@ -2360,29 +2356,137 @@ def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, defa
assert not identity.id in agent_state.identity_ids
def test_get_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user):
def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user):
identity = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]),
actor=default_user,
)
# Get the agents for identity id
agent_states = server.agent_manager.list_agents(identifier_id=identity.id, actor=default_user)
assert len(agent_states) == 2
agent_with_identity = server.create_agent(
CreateAgent(
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
identity_ids=[identity.id],
),
actor=default_user,
)
agent_without_identity = server.create_agent(
CreateAgent(
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=default_user,
)
# Check both agents are in the list
# Get the agents for identity id
agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user)
assert len(agent_states) == 3
# Check all agents are in the list
agent_state_ids = [a.id for a in agent_states]
assert sarah_agent.id in agent_state_ids
assert charles_agent.id in agent_state_ids
assert agent_with_identity.id in agent_state_ids
assert not agent_without_identity.id in agent_state_ids
# Get the agents for identifier key
agent_states = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user)
assert len(agent_states) == 2
assert len(agent_states) == 3
# Check both agents are in the list
# Check all agents are in the list
agent_state_ids = [a.id for a in agent_states]
assert sarah_agent.id in agent_state_ids
assert charles_agent.id in agent_state_ids
assert agent_with_identity.id in agent_state_ids
assert not agent_without_identity.id in agent_state_ids
# Delete new agents
server.agent_manager.delete_agent(agent_id=agent_with_identity.id, actor=default_user)
server.agent_manager.delete_agent(agent_id=agent_without_identity.id, actor=default_user)
# Get the agents for identity id
agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user)
assert len(agent_states) == 2
# Check only initial agents are in the list
agent_state_ids = [a.id for a in agent_states]
assert sarah_agent.id in agent_state_ids
assert charles_agent.id in agent_state_ids
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user):
# Create an identity
identity = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]),
actor=default_user,
)
# Check that identity has been attached
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
assert len(blocks) == 1 and blocks[0].id == default_block.id
# Now attempt to delete the identity
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
# Verify that the identity was deleted
identities = server.identity_manager.list_identities(actor=default_user)
assert len(identities) == 0
# Check that block has been detached too
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
assert len(blocks) == 0
def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user):
block_manager = BlockManager()
block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)
block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user)
identity = server.identity_manager.create_identity(
IdentityCreate(
name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id, block_with_identity.id]
),
actor=default_user,
)
# Get the blocks for identity id
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
assert len(blocks) == 2
# Check blocks are in the list
block_ids = [b.id for b in blocks]
assert default_block.id in block_ids
assert block_with_identity.id in block_ids
assert not block_without_identity.id in block_ids
# Get the blocks for identifier key
blocks = server.block_manager.get_blocks(identifier_keys=[identity.identifier_key], actor=default_user)
assert len(blocks) == 2
# Check blocks are in the list
block_ids = [b.id for b in blocks]
assert default_block.id in block_ids
assert block_with_identity.id in block_ids
assert not block_without_identity.id in block_ids
# Delete new agents
server.block_manager.delete_block(block_id=block_with_identity.id, actor=default_user)
server.block_manager.delete_block(block_id=block_without_identity.id, actor=default_user)
# Get the blocks for identity id
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
assert len(blocks) == 1
# Check only initial block in the list
block_ids = [b.id for b in blocks]
assert default_block.id in block_ids
assert not block_with_identity.id in block_ids
assert not block_without_identity.id in block_ids
server.identity_manager.delete_identity(identity.id, actor=default_user)
# ======================================================================================================================

View File

@@ -380,6 +380,8 @@ def test_list_blocks(client, mock_sync_server):
label=None,
is_template=True,
template_name=None,
identity_id=None,
identifier_keys=None,
)