diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index ef43d9b6..4a9ea8da 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -15,19 +15,26 @@ router = APIRouter(prefix="/blocks", tags=["blocks"]) @router.get("/", response_model=List[Block], operation_id="list_blocks") -def list_blocks( +async def list_blocks( # query parameters label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"), templates_only: bool = Query(False, description="Whether to include only templates"), name: Optional[str] = Query(None, description="Name of the block"), identity_id: Optional[str] = Query(None, description="Search agents by identifier id"), identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), + limit: Optional[int] = Query(50, description="Number of blocks to return"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.block_manager.get_blocks( - actor=actor, label=label, is_template=templates_only, template_name=name, identity_id=identity_id, identifier_keys=identifier_keys + return await server.block_manager.get_blocks_async( + actor=actor, + label=label, + is_template=templates_only, + template_name=name, + identity_id=identity_id, + identifier_keys=identifier_keys, + limit=limit, ) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 63e95060..30450f01 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -117,6 +117,68 @@ class BlockManager: return [block.to_pydantic() for block in blocks] + @enforce_types + async def get_blocks_async( + self, + actor: PydanticUser, + label: Optional[str] = None, + is_template: Optional[bool] = None, + template_name: Optional[str] = None, + identity_id: Optional[str] = None, + identifier_keys: Optional[List[str]] = None, + limit: Optional[int] = 50, + ) -> List[PydanticBlock]: + """Async version of get_blocks method. Retrieve blocks based on various optional filters.""" + from sqlalchemy import select + from sqlalchemy.orm import noload + + from letta.orm.sqlalchemy_base import AccessType + + async with db_registry.async_session() as session: + # Start with a basic query + query = select(BlockModel) + + # Explicitly avoid loading relationships + query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups)) + + # Apply access control + query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) + + # Add filters + query = query.where(BlockModel.organization_id == actor.organization_id) + if label: + query = query.where(BlockModel.label == label) + + if is_template is not None: + query = query.where(BlockModel.is_template == is_template) + + if template_name: + query = query.where(BlockModel.template_name == template_name) + + if identifier_keys: + query = ( + query.join(BlockModel.identities) + .filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) + .distinct(BlockModel.id) + ) + + if identity_id: + query = ( + query.join(BlockModel.identities) + .filter(BlockModel.identities.property.mapper.class_.id == identity_id) + .distinct(BlockModel.id) + ) + + # Add limit + if limit: + query = query.limit(limit) + + # Execute the query + result = await session.execute(query) + blocks = result.scalars().all() + + return [block.to_pydantic() for block in blocks] + @enforce_types def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: """Retrieve a block by its name.""" diff --git a/tests/test_managers.py b/tests/test_managers.py index bff28191..afb5353c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1,3 +1,4 @@ +import asyncio import logging import os import random @@ -628,6 +629,14 @@ def letta_batch_job(server: SyncServer, default_user) -> Job: return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user) +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + # ====================================================================================================================== # AgentManager Tests - Basic # ====================================================================================================================== @@ -1717,7 +1726,8 @@ def test_refresh_memory(server: SyncServer, default_user): assert len(agent.memory.blocks) == 0 -async def test_refresh_memory_async(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_refresh_memory_async(server: SyncServer, default_user, event_loop): block = server.block_manager.create_or_update_block( PydanticBlock( label="test", @@ -1726,13 +1736,21 @@ async def test_refresh_memory_async(server: SyncServer, default_user): ), actor=default_user, ) + block_human = server.block_manager.create_or_update_block( + PydanticBlock( + label="human", + value="name: caren", + limit=1000, + ), + actor=default_user, + ) agent = server.agent_manager.create_agent( CreateAgent( name="test", llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), include_base_tools=False, - block_ids=[block.id], + block_ids=[block.id, block_human.id], ), actor=default_user, ) @@ -1743,10 +1761,10 @@ async def test_refresh_memory_async(server: SyncServer, default_user): ), actor=default_user, ) - assert len(agent.memory.blocks) == 1 + assert len(agent.memory.blocks) == 2 agent = await server.agent_manager.refresh_memory_async(agent_state=agent, actor=default_user) - assert len(agent.memory.blocks) == 1 - assert agent.memory.blocks[0].value == "test2" + assert len(agent.memory.blocks) == 2 + assert any([block.value == "test2" for block in agent.memory.blocks]) # ====================================================================================================================== @@ -2583,7 +2601,8 @@ def test_create_block(server: SyncServer, default_user): assert block.organization_id == default_user.organization_id -def test_get_blocks(server, default_user): +@pytest.mark.asyncio +async def test_get_blocks(server, default_user, event_loop): block_manager = BlockManager() # Create blocks to retrieve later @@ -2591,19 +2610,20 @@ def test_get_blocks(server, default_user): block_manager.create_or_update_block(PydanticBlock(label="persona", value="Block 2"), actor=default_user) # Retrieve blocks by different filters - all_blocks = block_manager.get_blocks(actor=default_user) + all_blocks = await block_manager.get_blocks_async(actor=default_user) assert len(all_blocks) == 2 - human_blocks = block_manager.get_blocks(actor=default_user, label="human") + human_blocks = await block_manager.get_blocks_async(actor=default_user, label="human") assert len(human_blocks) == 1 assert human_blocks[0].label == "human" - persona_blocks = block_manager.get_blocks(actor=default_user, label="persona") + persona_blocks = await block_manager.get_blocks_async(actor=default_user, label="persona") assert len(persona_blocks) == 1 assert persona_blocks[0].label == "persona" -def test_get_blocks_comprehensive(server, default_user, other_user_different_org): +@pytest.mark.asyncio +async def test_get_blocks_comprehensive(server, default_user, other_user_different_org, event_loop): def random_label(prefix="label"): return f"{prefix}_{''.join(random.choices(string.ascii_lowercase, k=6))}" @@ -2629,7 +2649,7 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org other_user_blocks.append((label, value)) # Check default_user sees only their blocks - retrieved_default_blocks = block_manager.get_blocks(actor=default_user) + retrieved_default_blocks = await block_manager.get_blocks_async(actor=default_user) assert len(retrieved_default_blocks) == 10 retrieved_labels = {b.label for b in retrieved_default_blocks} for label, value in default_user_blocks: @@ -2637,13 +2657,13 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org # Check individual filtering for default_user for label, value in default_user_blocks: - filtered = block_manager.get_blocks(actor=default_user, label=label) + filtered = await block_manager.get_blocks_async(actor=default_user, label=label) assert len(filtered) == 1 assert filtered[0].label == label assert filtered[0].value == value # Check other_user sees only their blocks - retrieved_other_blocks = block_manager.get_blocks(actor=other_user_different_org) + retrieved_other_blocks = await block_manager.get_blocks_async(actor=other_user_different_org) assert len(retrieved_other_blocks) == 3 retrieved_labels = {b.label for b in retrieved_other_blocks} for label, value in other_user_blocks: @@ -2651,11 +2671,11 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org # Other user shouldn't see default_user's blocks for label, _ in default_user_blocks: - assert block_manager.get_blocks(actor=other_user_different_org, label=label) == [] + assert (await block_manager.get_blocks_async(actor=other_user_different_org, label=label)) == [] # Default user shouldn't see other_user's blocks for label, _ in other_user_blocks: - assert block_manager.get_blocks(actor=default_user, label=label) == [] + assert (await block_manager.get_blocks_async(actor=default_user, label=label)) == [] def test_update_block(server: SyncServer, default_user): @@ -2667,7 +2687,7 @@ def test_update_block(server: SyncServer, default_user): block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user) # Retrieve the updated block - updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0] + updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id) # Assertions to verify the update assert updated_block.value == "Updated Content" @@ -2690,7 +2710,7 @@ def test_update_block_limit(server: SyncServer, default_user): block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user) # Retrieve the updated block and validate the update - updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0] + updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id) assert updated_block.value == "Updated Content" * 2000 assert updated_block.description == "Updated description" @@ -2707,11 +2727,12 @@ def test_update_block_limit_does_not_reset(server: SyncServer, default_user): block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user) # Retrieve the updated block and validate the update - updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0] + updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id) assert updated_block.value == new_content -def test_delete_block(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_delete_block(server: SyncServer, default_user, event_loop): block_manager = BlockManager() # Create and delete a block @@ -2719,11 +2740,12 @@ def test_delete_block(server: SyncServer, default_user): block_manager.delete_block(block_id=block.id, actor=default_user) # Verify that the block was deleted - blocks = block_manager.get_blocks(actor=default_user) + blocks = await block_manager.get_blocks_async(actor=default_user) assert len(blocks) == 0 -def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, default_user, event_loop): # Create and delete a block block = server.block_manager.create_or_update_block(PydanticBlock(label="human", value="Sample content"), actor=default_user) agent_state = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user) @@ -2735,7 +2757,7 @@ def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, defau server.block_manager.delete_block(block_id=block.id, actor=default_user) # Verify that the block was deleted - blocks = server.block_manager.get_blocks(actor=default_user) + blocks = await server.block_manager.get_blocks_async(actor=default_user) assert len(blocks) == 0 # Check that block has been detached too @@ -2763,7 +2785,8 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de assert charles_agent.id in agent_state_ids -def test_batch_create_multiple_blocks(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_batch_create_multiple_blocks(server: SyncServer, default_user, event_loop): block_manager = BlockManager() num_blocks = 10 @@ -2788,7 +2811,7 @@ def test_batch_create_multiple_blocks(server: SyncServer, default_user): assert blk.id is not None # Confirm all created blocks exist in the full list from get_blocks - all_labels = {blk.label for blk in block_manager.get_blocks(actor=default_user)} + all_labels = {blk.label for blk in await block_manager.get_blocks_async(actor=default_user)} expected_labels = {f"batch_label_{i}" for i in range(num_blocks)} assert expected_labels.issubset(all_labels) @@ -2819,7 +2842,7 @@ def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncS assert "truncating" in caplog.text # confirm the value was truncated to `limit` characters - reloaded = mgr.get_blocks(actor=default_user, id=b.id)[0] + reloaded = mgr.get_block_by_id(actor=default_user, block_id=b.id) assert len(reloaded.value) == 5 assert reloaded.value == long_val[:5] @@ -2864,11 +2887,11 @@ def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: Pyda mgr.bulk_update_block_values(updates, actor=default_user) # mine should be updated... - reloaded_mine = mgr.get_blocks(actor=default_user, id=mine.id)[0] + reloaded_mine = mgr.get_block_by_id(actor=default_user, block_id=mine.id) assert reloaded_mine.value == "updated-mine" # ...theirs should remain untouched - reloaded_theirs = mgr.get_blocks(actor=other_user_different_org, id=theirs.id)[0] + reloaded_theirs = mgr.get_block_by_id(actor=other_user_different_org, block_id=theirs.id) assert reloaded_theirs.value == "theirs" # warning should mention skipping the other-org ID @@ -3625,7 +3648,8 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_ server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) -def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user): +@pytest.mark.asyncio +async def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user, event_loop): # Create an identity identity = server.identity_manager.create_identity( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]), @@ -3633,7 +3657,7 @@ def test_attach_detach_identity_from_block(server: SyncServer, default_block, de ) # Check that identity has been attached - blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user) assert len(blocks) == 1 and blocks[0].id == default_block.id # Now attempt to delete the identity @@ -3644,11 +3668,12 @@ def test_attach_detach_identity_from_block(server: SyncServer, default_block, de assert len(identities) == 0 # Check that block has been detached too - blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user) assert len(blocks) == 0 -def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user): +@pytest.mark.asyncio +async def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user, event_loop): block_manager = BlockManager() block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user) block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user) @@ -3660,7 +3685,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul ) # Get the blocks for identity id - blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user) assert len(blocks) == 2 # Check blocks are in the list @@ -3670,7 +3695,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul assert not block_without_identity.id in block_ids # Get the blocks for identifier key - blocks = server.block_manager.get_blocks(identifier_keys=[identity.identifier_key], actor=default_user) + blocks = await server.block_manager.get_blocks_async(identifier_keys=[identity.identifier_key], actor=default_user) assert len(blocks) == 2 # Check blocks are in the list @@ -3684,7 +3709,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul server.block_manager.delete_block(block_id=block_without_identity.id, actor=default_user) # Get the blocks for identity id - blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user) + blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user) assert len(blocks) == 1 # Check only initial block in the list