feat: make get blocks async (#2162)

This commit is contained in:
cthomas
2025-05-13 15:09:25 -07:00
committed by GitHub
parent d133ca248f
commit a279126fb6
3 changed files with 131 additions and 37 deletions

View File

@@ -15,19 +15,26 @@ router = APIRouter(prefix="/blocks", tags=["blocks"])
@router.get("/", response_model=List[Block], operation_id="list_blocks")
def list_blocks(
async def list_blocks(
# query parameters
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
templates_only: bool = Query(False, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
limit: Optional[int] = Query(50, description="Number of blocks to return"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.block_manager.get_blocks(
actor=actor, label=label, is_template=templates_only, template_name=name, identity_id=identity_id, identifier_keys=identifier_keys
return await server.block_manager.get_blocks_async(
actor=actor,
label=label,
is_template=templates_only,
template_name=name,
identity_id=identity_id,
identifier_keys=identifier_keys,
limit=limit,
)

View File

@@ -117,6 +117,68 @@ class BlockManager:
return [block.to_pydantic() for block in blocks]
@enforce_types
async def get_blocks_async(
self,
actor: PydanticUser,
label: Optional[str] = None,
is_template: Optional[bool] = None,
template_name: Optional[str] = None,
identity_id: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
limit: Optional[int] = 50,
) -> List[PydanticBlock]:
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
from sqlalchemy import select
from sqlalchemy.orm import noload
from letta.orm.sqlalchemy_base import AccessType
async with db_registry.async_session() as session:
# Start with a basic query
query = select(BlockModel)
# Explicitly avoid loading relationships
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
# Apply access control
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Add filters
query = query.where(BlockModel.organization_id == actor.organization_id)
if label:
query = query.where(BlockModel.label == label)
if is_template is not None:
query = query.where(BlockModel.is_template == is_template)
if template_name:
query = query.where(BlockModel.template_name == template_name)
if identifier_keys:
query = (
query.join(BlockModel.identities)
.filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
.distinct(BlockModel.id)
)
if identity_id:
query = (
query.join(BlockModel.identities)
.filter(BlockModel.identities.property.mapper.class_.id == identity_id)
.distinct(BlockModel.id)
)
# Add limit
if limit:
query = query.limit(limit)
# Execute the query
result = await session.execute(query)
blocks = result.scalars().all()
return [block.to_pydantic() for block in blocks]
@enforce_types
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
import os
import random
@@ -628,6 +629,14 @@ def letta_batch_job(server: SyncServer, default_user) -> Job:
return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user)
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -1717,7 +1726,8 @@ def test_refresh_memory(server: SyncServer, default_user):
assert len(agent.memory.blocks) == 0
async def test_refresh_memory_async(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_refresh_memory_async(server: SyncServer, default_user, event_loop):
block = server.block_manager.create_or_update_block(
PydanticBlock(
label="test",
@@ -1726,13 +1736,21 @@ async def test_refresh_memory_async(server: SyncServer, default_user):
),
actor=default_user,
)
block_human = server.block_manager.create_or_update_block(
PydanticBlock(
label="human",
value="name: caren",
limit=1000,
),
actor=default_user,
)
agent = server.agent_manager.create_agent(
CreateAgent(
name="test",
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
block_ids=[block.id],
block_ids=[block.id, block_human.id],
),
actor=default_user,
)
@@ -1743,10 +1761,10 @@ async def test_refresh_memory_async(server: SyncServer, default_user):
),
actor=default_user,
)
assert len(agent.memory.blocks) == 1
assert len(agent.memory.blocks) == 2
agent = await server.agent_manager.refresh_memory_async(agent_state=agent, actor=default_user)
assert len(agent.memory.blocks) == 1
assert agent.memory.blocks[0].value == "test2"
assert len(agent.memory.blocks) == 2
assert any([block.value == "test2" for block in agent.memory.blocks])
# ======================================================================================================================
@@ -2583,7 +2601,8 @@ def test_create_block(server: SyncServer, default_user):
assert block.organization_id == default_user.organization_id
def test_get_blocks(server, default_user):
@pytest.mark.asyncio
async def test_get_blocks(server, default_user, event_loop):
block_manager = BlockManager()
# Create blocks to retrieve later
@@ -2591,19 +2610,20 @@ def test_get_blocks(server, default_user):
block_manager.create_or_update_block(PydanticBlock(label="persona", value="Block 2"), actor=default_user)
# Retrieve blocks by different filters
all_blocks = block_manager.get_blocks(actor=default_user)
all_blocks = await block_manager.get_blocks_async(actor=default_user)
assert len(all_blocks) == 2
human_blocks = block_manager.get_blocks(actor=default_user, label="human")
human_blocks = await block_manager.get_blocks_async(actor=default_user, label="human")
assert len(human_blocks) == 1
assert human_blocks[0].label == "human"
persona_blocks = block_manager.get_blocks(actor=default_user, label="persona")
persona_blocks = await block_manager.get_blocks_async(actor=default_user, label="persona")
assert len(persona_blocks) == 1
assert persona_blocks[0].label == "persona"
def test_get_blocks_comprehensive(server, default_user, other_user_different_org):
@pytest.mark.asyncio
async def test_get_blocks_comprehensive(server, default_user, other_user_different_org, event_loop):
def random_label(prefix="label"):
return f"{prefix}_{''.join(random.choices(string.ascii_lowercase, k=6))}"
@@ -2629,7 +2649,7 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org
other_user_blocks.append((label, value))
# Check default_user sees only their blocks
retrieved_default_blocks = block_manager.get_blocks(actor=default_user)
retrieved_default_blocks = await block_manager.get_blocks_async(actor=default_user)
assert len(retrieved_default_blocks) == 10
retrieved_labels = {b.label for b in retrieved_default_blocks}
for label, value in default_user_blocks:
@@ -2637,13 +2657,13 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org
# Check individual filtering for default_user
for label, value in default_user_blocks:
filtered = block_manager.get_blocks(actor=default_user, label=label)
filtered = await block_manager.get_blocks_async(actor=default_user, label=label)
assert len(filtered) == 1
assert filtered[0].label == label
assert filtered[0].value == value
# Check other_user sees only their blocks
retrieved_other_blocks = block_manager.get_blocks(actor=other_user_different_org)
retrieved_other_blocks = await block_manager.get_blocks_async(actor=other_user_different_org)
assert len(retrieved_other_blocks) == 3
retrieved_labels = {b.label for b in retrieved_other_blocks}
for label, value in other_user_blocks:
@@ -2651,11 +2671,11 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org
# Other user shouldn't see default_user's blocks
for label, _ in default_user_blocks:
assert block_manager.get_blocks(actor=other_user_different_org, label=label) == []
assert (await block_manager.get_blocks_async(actor=other_user_different_org, label=label)) == []
# Default user shouldn't see other_user's blocks
for label, _ in other_user_blocks:
assert block_manager.get_blocks(actor=default_user, label=label) == []
assert (await block_manager.get_blocks_async(actor=default_user, label=label)) == []
def test_update_block(server: SyncServer, default_user):
@@ -2667,7 +2687,7 @@ def test_update_block(server: SyncServer, default_user):
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id)
# Assertions to verify the update
assert updated_block.value == "Updated Content"
@@ -2690,7 +2710,7 @@ def test_update_block_limit(server: SyncServer, default_user):
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block and validate the update
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id)
assert updated_block.value == "Updated Content" * 2000
assert updated_block.description == "Updated description"
@@ -2707,11 +2727,12 @@ def test_update_block_limit_does_not_reset(server: SyncServer, default_user):
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block and validate the update
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id)
assert updated_block.value == new_content
def test_delete_block(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_delete_block(server: SyncServer, default_user, event_loop):
block_manager = BlockManager()
# Create and delete a block
@@ -2719,11 +2740,12 @@ def test_delete_block(server: SyncServer, default_user):
block_manager.delete_block(block_id=block.id, actor=default_user)
# Verify that the block was deleted
blocks = block_manager.get_blocks(actor=default_user)
blocks = await block_manager.get_blocks_async(actor=default_user)
assert len(blocks) == 0
def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, default_user):
@pytest.mark.asyncio
async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, default_user, event_loop):
# Create and delete a block
block = server.block_manager.create_or_update_block(PydanticBlock(label="human", value="Sample content"), actor=default_user)
agent_state = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user)
@@ -2735,7 +2757,7 @@ def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, defau
server.block_manager.delete_block(block_id=block.id, actor=default_user)
# Verify that the block was deleted
blocks = server.block_manager.get_blocks(actor=default_user)
blocks = await server.block_manager.get_blocks_async(actor=default_user)
assert len(blocks) == 0
# Check that block has been detached too
@@ -2763,7 +2785,8 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
assert charles_agent.id in agent_state_ids
def test_batch_create_multiple_blocks(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_batch_create_multiple_blocks(server: SyncServer, default_user, event_loop):
block_manager = BlockManager()
num_blocks = 10
@@ -2788,7 +2811,7 @@ def test_batch_create_multiple_blocks(server: SyncServer, default_user):
assert blk.id is not None
# Confirm all created blocks exist in the full list from get_blocks
all_labels = {blk.label for blk in block_manager.get_blocks(actor=default_user)}
all_labels = {blk.label for blk in await block_manager.get_blocks_async(actor=default_user)}
expected_labels = {f"batch_label_{i}" for i in range(num_blocks)}
assert expected_labels.issubset(all_labels)
@@ -2819,7 +2842,7 @@ def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncS
assert "truncating" in caplog.text
# confirm the value was truncated to `limit` characters
reloaded = mgr.get_blocks(actor=default_user, id=b.id)[0]
reloaded = mgr.get_block_by_id(actor=default_user, block_id=b.id)
assert len(reloaded.value) == 5
assert reloaded.value == long_val[:5]
@@ -2864,11 +2887,11 @@ def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: Pyda
mgr.bulk_update_block_values(updates, actor=default_user)
# mine should be updated...
reloaded_mine = mgr.get_blocks(actor=default_user, id=mine.id)[0]
reloaded_mine = mgr.get_block_by_id(actor=default_user, block_id=mine.id)
assert reloaded_mine.value == "updated-mine"
# ...theirs should remain untouched
reloaded_theirs = mgr.get_blocks(actor=other_user_different_org, id=theirs.id)[0]
reloaded_theirs = mgr.get_block_by_id(actor=other_user_different_org, block_id=theirs.id)
assert reloaded_theirs.value == "theirs"
# warning should mention skipping the other-org ID
@@ -3625,7 +3648,8 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user):
@pytest.mark.asyncio
async def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user, event_loop):
# Create an identity
identity = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]),
@@ -3633,7 +3657,7 @@ def test_attach_detach_identity_from_block(server: SyncServer, default_block, de
)
# Check that identity has been attached
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
assert len(blocks) == 1 and blocks[0].id == default_block.id
# Now attempt to delete the identity
@@ -3644,11 +3668,12 @@ def test_attach_detach_identity_from_block(server: SyncServer, default_block, de
assert len(identities) == 0
# Check that block has been detached too
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
assert len(blocks) == 0
def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user):
@pytest.mark.asyncio
async def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user, event_loop):
block_manager = BlockManager()
block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)
block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user)
@@ -3660,7 +3685,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul
)
# Get the blocks for identity id
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
assert len(blocks) == 2
# Check blocks are in the list
@@ -3670,7 +3695,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul
assert not block_without_identity.id in block_ids
# Get the blocks for identifier key
blocks = server.block_manager.get_blocks(identifier_keys=[identity.identifier_key], actor=default_user)
blocks = await server.block_manager.get_blocks_async(identifier_keys=[identity.identifier_key], actor=default_user)
assert len(blocks) == 2
# Check blocks are in the list
@@ -3684,7 +3709,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul
server.block_manager.delete_block(block_id=block_without_identity.id, actor=default_user)
# Get the blocks for identity id
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
assert len(blocks) == 1
# Check only initial block in the list