[Improvements] Use single query for Block Manager get_all_blocks_by_ids (#2485)

Co-authored-by: kyuds <kyuds@everspin.co.kr>
This commit is contained in:
Daniel Shin
2025-03-21 02:19:13 +09:00
committed by GitHub
parent 87ea2ef711
commit fd66fc248e
3 changed files with 71 additions and 14 deletions

View File

@@ -286,7 +286,45 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
Raises:
NoResultFound: if the object is not found
"""
logger.debug(f"Reading {cls.__name__} with ID: {identifier} with actor={actor}")
# passing an empty list is ok: read_multiple only filters by id when identifiers is non-empty
identifiers = [] if identifier is None else [identifier]
found = cls.read_multiple(db_session, identifiers, actor, access, access_type, **kwargs)
if len(found) == 0:
# for backwards compatibility.
conditions = []
if identifier:
conditions.append(f"id={identifier}")
if actor:
conditions.append(f"access level in {access} for {actor}")
if hasattr(cls, "is_deleted"):
conditions.append("is_deleted=False")
raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
return found[0]
@classmethod
@handle_db_timeout
def read_multiple(
cls,
db_session: "Session",
identifiers: List[str] = [],
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
**kwargs,
) -> List["SqlalchemyBase"]:
"""The primary accessor for ORM record(s)
Args:
db_session: the database session to use when retrieving the record
identifiers: a list of identifiers of the records to read, can be the id string or the UUID object for backwards compatibility
actor: if specified, results will be scoped only to records the user is able to access
access: if actor is specified, records will be filtered to the minimum permission level for the actor
kwargs: additional arguments to pass to the read, used for more complex objects
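Returns:
A list of the matching objects; empty if no record matches
Note:
Unlike read(), this does not raise NoResultFound; a warning is logged when nothing (or only some of the requested ids) is found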
"""
logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
# Start the query
query = select(cls)
@@ -294,9 +332,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_conditions = []
# If an identifier is provided, add it to the query conditions
if identifier is not None:
query = query.where(cls.id == identifier)
query_conditions.append(f"id='{identifier}'")
if len(identifiers) > 0:
query = query.where(cls.id.in_(identifiers))
query_conditions.append(f"id='{identifiers}'")
if kwargs:
query = query.filter_by(**kwargs)
@@ -309,12 +347,29 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
query_conditions.append("is_deleted=False")
if found := db_session.execute(query).scalar():
return found
results = db_session.execute(query).scalars().all()
if results: # non-empty list, i.e. at least one record was found
if len(identifiers) > 0:
# find which identifiers were not found
# only when identifier length is greater than 0 (so it was used in the actual query)
identifier_set = set(identifiers)
results_set = set(map(lambda obj: obj.id, results))
# we log a warning message if any of the queried IDs were not found.
# TODO: should we error out instead?
if identifier_set != results_set:
# Construct a detailed error message based on query conditions
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
logger.warning(
f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}"
)
return results
# Construct a detailed error message based on query conditions
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
logger.warning(f"{cls.__name__} not found with {conditions_str}")
return []
@handle_db_timeout
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":

View File

@@ -106,12 +106,14 @@ class BlockManager:
@enforce_types
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
    """Retrieve blocks by their IDs with a single query.

    Args:
        block_ids: IDs of the blocks to fetch.
        actor: forwarded to ``BlockModel.read_multiple`` to scope the lookup.

    Returns:
        One entry per requested ID, in the same order as ``block_ids``
        (duplicates included). An entry is ``None`` when the query returned
        no block for that ID.
    """
    # read_multiple only applies the ID filter when the identifier list is
    # non-empty, so an empty request would otherwise match every block.
    if not block_ids:
        return []

    with self.session_maker() as session:
        found = BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)
        # Convert while the session is still open, keyed by ID for re-ordering.
        blocks_by_id = {obj.id: obj.to_pydantic() for obj in found}

    # Backwards compatibility: the previous per-ID loop yielded one result per
    # requested ID, with None for every block not found. Rebuilding the list
    # from block_ids keeps each None at the position of its missing ID.
    return [blocks_by_id.get(block_id) for block_id in block_ids]
@enforce_types
def add_default_blocks(self, actor: PydanticUser):

View File

@@ -117,7 +117,7 @@ def test_shared_blocks(client: LettaSDKClient):
)
assert (
"charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower()
), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value}"
), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label='human').value}"
# cleanup
client.agents.delete(agent_state1.id)