feat: make get blocks async (#2162)
This commit is contained in:
@@ -15,19 +15,26 @@ router = APIRouter(prefix="/blocks", tags=["blocks"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Block], operation_id="list_blocks")
|
||||
def list_blocks(
|
||||
async def list_blocks(
|
||||
# query parameters
|
||||
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
|
||||
templates_only: bool = Query(False, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
limit: Optional[int] = Query(50, description="Number of blocks to return"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.get_blocks(
|
||||
actor=actor, label=label, is_template=templates_only, template_name=name, identity_id=identity_id, identifier_keys=identifier_keys
|
||||
return await server.block_manager.get_blocks_async(
|
||||
actor=actor,
|
||||
label=label,
|
||||
is_template=templates_only,
|
||||
template_name=name,
|
||||
identity_id=identity_id,
|
||||
identifier_keys=identifier_keys,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -117,6 +117,68 @@ class BlockManager:
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
@enforce_types
|
||||
async def get_blocks_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
label: Optional[str] = None,
|
||||
is_template: Optional[bool] = None,
|
||||
template_name: Optional[str] = None,
|
||||
identity_id: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
) -> List[PydanticBlock]:
|
||||
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Start with a basic query
|
||||
query = select(BlockModel)
|
||||
|
||||
# Explicitly avoid loading relationships
|
||||
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
|
||||
|
||||
# Apply access control
|
||||
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
|
||||
# Add filters
|
||||
query = query.where(BlockModel.organization_id == actor.organization_id)
|
||||
if label:
|
||||
query = query.where(BlockModel.label == label)
|
||||
|
||||
if is_template is not None:
|
||||
query = query.where(BlockModel.is_template == is_template)
|
||||
|
||||
if template_name:
|
||||
query = query.where(BlockModel.template_name == template_name)
|
||||
|
||||
if identifier_keys:
|
||||
query = (
|
||||
query.join(BlockModel.identities)
|
||||
.filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
.distinct(BlockModel.id)
|
||||
)
|
||||
|
||||
if identity_id:
|
||||
query = (
|
||||
query.join(BlockModel.identities)
|
||||
.filter(BlockModel.identities.property.mapper.class_.id == identity_id)
|
||||
.distinct(BlockModel.id)
|
||||
)
|
||||
|
||||
# Add limit
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute the query
|
||||
result = await session.execute(query)
|
||||
blocks = result.scalars().all()
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
@enforce_types
|
||||
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -628,6 +629,14 @@ def letta_batch_job(server: SyncServer, default_user) -> Job:
|
||||
return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop(request):
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -1717,7 +1726,8 @@ def test_refresh_memory(server: SyncServer, default_user):
|
||||
assert len(agent.memory.blocks) == 0
|
||||
|
||||
|
||||
async def test_refresh_memory_async(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_memory_async(server: SyncServer, default_user, event_loop):
|
||||
block = server.block_manager.create_or_update_block(
|
||||
PydanticBlock(
|
||||
label="test",
|
||||
@@ -1726,13 +1736,21 @@ async def test_refresh_memory_async(server: SyncServer, default_user):
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
block_human = server.block_manager.create_or_update_block(
|
||||
PydanticBlock(
|
||||
label="human",
|
||||
value="name: caren",
|
||||
limit=1000,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
agent = server.agent_manager.create_agent(
|
||||
CreateAgent(
|
||||
name="test",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
block_ids=[block.id],
|
||||
block_ids=[block.id, block_human.id],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1743,10 +1761,10 @@ async def test_refresh_memory_async(server: SyncServer, default_user):
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
assert len(agent.memory.blocks) == 1
|
||||
assert len(agent.memory.blocks) == 2
|
||||
agent = await server.agent_manager.refresh_memory_async(agent_state=agent, actor=default_user)
|
||||
assert len(agent.memory.blocks) == 1
|
||||
assert agent.memory.blocks[0].value == "test2"
|
||||
assert len(agent.memory.blocks) == 2
|
||||
assert any([block.value == "test2" for block in agent.memory.blocks])
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -2583,7 +2601,8 @@ def test_create_block(server: SyncServer, default_user):
|
||||
assert block.organization_id == default_user.organization_id
|
||||
|
||||
|
||||
def test_get_blocks(server, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_blocks(server, default_user, event_loop):
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create blocks to retrieve later
|
||||
@@ -2591,19 +2610,20 @@ def test_get_blocks(server, default_user):
|
||||
block_manager.create_or_update_block(PydanticBlock(label="persona", value="Block 2"), actor=default_user)
|
||||
|
||||
# Retrieve blocks by different filters
|
||||
all_blocks = block_manager.get_blocks(actor=default_user)
|
||||
all_blocks = await block_manager.get_blocks_async(actor=default_user)
|
||||
assert len(all_blocks) == 2
|
||||
|
||||
human_blocks = block_manager.get_blocks(actor=default_user, label="human")
|
||||
human_blocks = await block_manager.get_blocks_async(actor=default_user, label="human")
|
||||
assert len(human_blocks) == 1
|
||||
assert human_blocks[0].label == "human"
|
||||
|
||||
persona_blocks = block_manager.get_blocks(actor=default_user, label="persona")
|
||||
persona_blocks = await block_manager.get_blocks_async(actor=default_user, label="persona")
|
||||
assert len(persona_blocks) == 1
|
||||
assert persona_blocks[0].label == "persona"
|
||||
|
||||
|
||||
def test_get_blocks_comprehensive(server, default_user, other_user_different_org):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_blocks_comprehensive(server, default_user, other_user_different_org, event_loop):
|
||||
def random_label(prefix="label"):
|
||||
return f"{prefix}_{''.join(random.choices(string.ascii_lowercase, k=6))}"
|
||||
|
||||
@@ -2629,7 +2649,7 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org
|
||||
other_user_blocks.append((label, value))
|
||||
|
||||
# Check default_user sees only their blocks
|
||||
retrieved_default_blocks = block_manager.get_blocks(actor=default_user)
|
||||
retrieved_default_blocks = await block_manager.get_blocks_async(actor=default_user)
|
||||
assert len(retrieved_default_blocks) == 10
|
||||
retrieved_labels = {b.label for b in retrieved_default_blocks}
|
||||
for label, value in default_user_blocks:
|
||||
@@ -2637,13 +2657,13 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org
|
||||
|
||||
# Check individual filtering for default_user
|
||||
for label, value in default_user_blocks:
|
||||
filtered = block_manager.get_blocks(actor=default_user, label=label)
|
||||
filtered = await block_manager.get_blocks_async(actor=default_user, label=label)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0].label == label
|
||||
assert filtered[0].value == value
|
||||
|
||||
# Check other_user sees only their blocks
|
||||
retrieved_other_blocks = block_manager.get_blocks(actor=other_user_different_org)
|
||||
retrieved_other_blocks = await block_manager.get_blocks_async(actor=other_user_different_org)
|
||||
assert len(retrieved_other_blocks) == 3
|
||||
retrieved_labels = {b.label for b in retrieved_other_blocks}
|
||||
for label, value in other_user_blocks:
|
||||
@@ -2651,11 +2671,11 @@ def test_get_blocks_comprehensive(server, default_user, other_user_different_org
|
||||
|
||||
# Other user shouldn't see default_user's blocks
|
||||
for label, _ in default_user_blocks:
|
||||
assert block_manager.get_blocks(actor=other_user_different_org, label=label) == []
|
||||
assert (await block_manager.get_blocks_async(actor=other_user_different_org, label=label)) == []
|
||||
|
||||
# Default user shouldn't see other_user's blocks
|
||||
for label, _ in other_user_blocks:
|
||||
assert block_manager.get_blocks(actor=default_user, label=label) == []
|
||||
assert (await block_manager.get_blocks_async(actor=default_user, label=label)) == []
|
||||
|
||||
|
||||
def test_update_block(server: SyncServer, default_user):
|
||||
@@ -2667,7 +2687,7 @@ def test_update_block(server: SyncServer, default_user):
|
||||
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
|
||||
|
||||
# Retrieve the updated block
|
||||
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
|
||||
updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id)
|
||||
|
||||
# Assertions to verify the update
|
||||
assert updated_block.value == "Updated Content"
|
||||
@@ -2690,7 +2710,7 @@ def test_update_block_limit(server: SyncServer, default_user):
|
||||
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
|
||||
|
||||
# Retrieve the updated block and validate the update
|
||||
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
|
||||
updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id)
|
||||
|
||||
assert updated_block.value == "Updated Content" * 2000
|
||||
assert updated_block.description == "Updated description"
|
||||
@@ -2707,11 +2727,12 @@ def test_update_block_limit_does_not_reset(server: SyncServer, default_user):
|
||||
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
|
||||
|
||||
# Retrieve the updated block and validate the update
|
||||
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
|
||||
updated_block = block_manager.get_block_by_id(actor=default_user, block_id=block.id)
|
||||
assert updated_block.value == new_content
|
||||
|
||||
|
||||
def test_delete_block(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_block(server: SyncServer, default_user, event_loop):
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create and delete a block
|
||||
@@ -2719,11 +2740,12 @@ def test_delete_block(server: SyncServer, default_user):
|
||||
block_manager.delete_block(block_id=block.id, actor=default_user)
|
||||
|
||||
# Verify that the block was deleted
|
||||
blocks = block_manager.get_blocks(actor=default_user)
|
||||
blocks = await block_manager.get_blocks_async(actor=default_user)
|
||||
assert len(blocks) == 0
|
||||
|
||||
|
||||
def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
# Create and delete a block
|
||||
block = server.block_manager.create_or_update_block(PydanticBlock(label="human", value="Sample content"), actor=default_user)
|
||||
agent_state = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user)
|
||||
@@ -2735,7 +2757,7 @@ def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, defau
|
||||
server.block_manager.delete_block(block_id=block.id, actor=default_user)
|
||||
|
||||
# Verify that the block was deleted
|
||||
blocks = server.block_manager.get_blocks(actor=default_user)
|
||||
blocks = await server.block_manager.get_blocks_async(actor=default_user)
|
||||
assert len(blocks) == 0
|
||||
|
||||
# Check that block has been detached too
|
||||
@@ -2763,7 +2785,8 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
|
||||
assert charles_agent.id in agent_state_ids
|
||||
|
||||
|
||||
def test_batch_create_multiple_blocks(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_multiple_blocks(server: SyncServer, default_user, event_loop):
|
||||
block_manager = BlockManager()
|
||||
num_blocks = 10
|
||||
|
||||
@@ -2788,7 +2811,7 @@ def test_batch_create_multiple_blocks(server: SyncServer, default_user):
|
||||
assert blk.id is not None
|
||||
|
||||
# Confirm all created blocks exist in the full list from get_blocks
|
||||
all_labels = {blk.label for blk in block_manager.get_blocks(actor=default_user)}
|
||||
all_labels = {blk.label for blk in await block_manager.get_blocks_async(actor=default_user)}
|
||||
expected_labels = {f"batch_label_{i}" for i in range(num_blocks)}
|
||||
assert expected_labels.issubset(all_labels)
|
||||
|
||||
@@ -2819,7 +2842,7 @@ def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncS
|
||||
assert "truncating" in caplog.text
|
||||
|
||||
# confirm the value was truncated to `limit` characters
|
||||
reloaded = mgr.get_blocks(actor=default_user, id=b.id)[0]
|
||||
reloaded = mgr.get_block_by_id(actor=default_user, block_id=b.id)
|
||||
assert len(reloaded.value) == 5
|
||||
assert reloaded.value == long_val[:5]
|
||||
|
||||
@@ -2864,11 +2887,11 @@ def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: Pyda
|
||||
mgr.bulk_update_block_values(updates, actor=default_user)
|
||||
|
||||
# mine should be updated...
|
||||
reloaded_mine = mgr.get_blocks(actor=default_user, id=mine.id)[0]
|
||||
reloaded_mine = mgr.get_block_by_id(actor=default_user, block_id=mine.id)
|
||||
assert reloaded_mine.value == "updated-mine"
|
||||
|
||||
# ...theirs should remain untouched
|
||||
reloaded_theirs = mgr.get_blocks(actor=other_user_different_org, id=theirs.id)[0]
|
||||
reloaded_theirs = mgr.get_block_by_id(actor=other_user_different_org, block_id=theirs.id)
|
||||
assert reloaded_theirs.value == "theirs"
|
||||
|
||||
# warning should mention skipping the other-org ID
|
||||
@@ -3625,7 +3648,8 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user, event_loop):
|
||||
# Create an identity
|
||||
identity = server.identity_manager.create_identity(
|
||||
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]),
|
||||
@@ -3633,7 +3657,7 @@ def test_attach_detach_identity_from_block(server: SyncServer, default_block, de
|
||||
)
|
||||
|
||||
# Check that identity has been attached
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 1 and blocks[0].id == default_block.id
|
||||
|
||||
# Now attempt to delete the identity
|
||||
@@ -3644,11 +3668,12 @@ def test_attach_detach_identity_from_block(server: SyncServer, default_block, de
|
||||
assert len(identities) == 0
|
||||
|
||||
# Check that block has been detached too
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 0
|
||||
|
||||
|
||||
def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_set_blocks_for_identities(server: SyncServer, default_block, default_user, event_loop):
|
||||
block_manager = BlockManager()
|
||||
block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)
|
||||
block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user)
|
||||
@@ -3660,7 +3685,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul
|
||||
)
|
||||
|
||||
# Get the blocks for identity id
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 2
|
||||
|
||||
# Check blocks are in the list
|
||||
@@ -3670,7 +3695,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul
|
||||
assert not block_without_identity.id in block_ids
|
||||
|
||||
# Get the blocks for identifier key
|
||||
blocks = server.block_manager.get_blocks(identifier_keys=[identity.identifier_key], actor=default_user)
|
||||
blocks = await server.block_manager.get_blocks_async(identifier_keys=[identity.identifier_key], actor=default_user)
|
||||
assert len(blocks) == 2
|
||||
|
||||
# Check blocks are in the list
|
||||
@@ -3684,7 +3709,7 @@ def test_get_set_blocks_for_identities(server: SyncServer, default_block, defaul
|
||||
server.block_manager.delete_block(block_id=block_without_identity.id, actor=default_user)
|
||||
|
||||
# Get the blocks for identity id
|
||||
blocks = server.block_manager.get_blocks(identity_id=identity.id, actor=default_user)
|
||||
blocks = await server.block_manager.get_blocks_async(identity_id=identity.id, actor=default_user)
|
||||
assert len(blocks) == 1
|
||||
|
||||
# Check only initial block in the list
|
||||
|
||||
Reference in New Issue
Block a user