From 825ad3d9d71b2b92bfbdef7d49ac32c7d18138d4 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 5 Mar 2025 16:22:20 -0800 Subject: [PATCH] chore: add identities tests (#1204) --- letta/services/identity_manager.py | 8 +- tests/test_managers.py | 150 +++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 45e77d26..42efa191 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -78,7 +78,13 @@ class IdentityManager: if existing_identity is None: return self.create_identity(identity=identity, actor=actor) else: - identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids) + identity_update = IdentityUpdate( + name=identity.name, + identifier_key=identity.identifier_key, + identity_type=identity.identity_type, + agent_ids=identity.agent_ids, + properties=identity.properties, + ) return self._update_identity( session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True ) diff --git a/tests/test_managers.py b/tests/test_managers.py index 334e72ed..ba4f0f93 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -21,6 +21,7 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus, MessageRole from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.file import FileMetadata as PydanticFileMetadata +from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig from letta.schemas.llm_config import LLMConfig @@ -40,6 +41,7 @@ from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate from letta.server.server import SyncServer from letta.services.block_manager import BlockManager +from letta.services.identity_manager import IdentityManager from letta.services.organization_manager import OrganizationManager from letta.settings import tool_settings from tests.helpers.utils import comprehensive_agent_checks @@ -2075,6 +2077,154 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de assert charles_agent.id in agent_state_ids +# ====================================================================================================================== +# Identity Manager Tests +# ====================================================================================================================== + + +def test_create_and_upsert_identity(server: SyncServer, default_user): + identity_manager = IdentityManager() + identity_create = IdentityCreate( + identifier_key="1234", + name="caren", + identity_type=IdentityType.user, + properties=[ + IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string), + IdentityProperty(key="age", value=28, type=IdentityPropertyType.number), + ], + ) + + identity = identity_manager.create_identity(identity_create, actor=default_user) + + # Assertions to ensure the created identity matches the expected values + assert identity.identifier_key == identity_create.identifier_key + assert identity.name == identity_create.name + assert identity.identity_type == identity_create.identity_type + assert identity.properties == identity_create.properties + assert identity.agent_ids == [] + assert identity.project_id == None + + with pytest.raises(UniqueConstraintViolationError): + identity_manager.create_identity( + IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user), + actor=default_user, + ) + + identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))] + + identity = identity_manager.upsert_identity(identity_create, actor=default_user) + + identity = identity_manager.get_identity(identity_id=identity.id, actor=default_user) + assert len(identity.properties) == 1 + assert identity.properties[0].key == "age" + assert identity.properties[0].value == 29 + + identity_manager.delete_identity(identity.id, actor=default_user) + + +def test_get_identities(server, default_user): + identity_manager = IdentityManager() + + # Create identities to retrieve later + user = identity_manager.create_identity( + IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user + ) + org = identity_manager.create_identity( + IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user + ) + + # Retrieve identities by different filters + all_identities = identity_manager.list_identities(actor=default_user) + assert len(all_identities) == 2 + + user_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user) + assert len(user_identities) == 1 + assert user_identities[0].name == user.name + + org_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org) + assert len(org_identities) == 1 + assert org_identities[0].name == org.name + + identity_manager.delete_identity(user.id, actor=default_user) + identity_manager.delete_identity(org.id, actor=default_user) + + +def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user): + identity = server.identity_manager.create_identity( + IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user + ) + + # Update identity fields + update_data = IdentityUpdate( + agent_ids=[sarah_agent.id, charles_agent.id], + properties=[IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string)], + ) + server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user) + + # Retrieve the updated identity + updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user) + + # Assertions to verify the update + assert updated_identity.agent_ids.sort() == update_data.agent_ids.sort() + assert updated_identity.properties == update_data.properties + + agent_state = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user) + assert identity.id in agent_state.identity_ids + agent_state = server.agent_manager.get_agent_by_id(agent_id=charles_agent.id, actor=default_user) + assert identity.id in agent_state.identity_ids + + server.identity_manager.delete_identity(identity.id, actor=default_user) + + +def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user): + # Create an identity + identity = server.identity_manager.create_identity( + IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user + ) + agent_state = server.agent_manager.update_agent( + agent_id=sarah_agent.id, agent_update=UpdateAgent(identity_ids=[identity.id]), actor=default_user + ) + + # Check that identity has been attached + assert identity.id in agent_state.identity_ids + + # Now attempt to delete the identity + server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + + # Verify that the identity was deleted + identities = server.identity_manager.list_identities(actor=default_user) + assert len(identities) == 0 + + # Check that block has been detached too + agent_state = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user) + assert not identity.id in agent_state.identity_ids + + +def test_get_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user): + identity = server.identity_manager.create_identity( + IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]), + actor=default_user, + ) + + # Get the agents for identity id + agent_states = server.agent_manager.list_agents(identifier_id=identity.id, actor=default_user) + assert len(agent_states) == 2 + + # Check both agents are in the list + agent_state_ids = [a.id for a in agent_states] + assert sarah_agent.id in agent_state_ids + assert charles_agent.id in agent_state_ids + + # Get the agents for identifier key + agent_states = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user) + assert len(agent_states) == 2 + + # Check both agents are in the list + agent_state_ids = [a.id for a in agent_states] + assert sarah_agent.id in agent_state_ids + assert charles_agent.id in agent_state_ids + + # ====================================================================================================================== # SourceManager Tests - Sources # ======================================================================================================================