[Improvements] Use single query for Block Manager get_all_blocks_by_ids (#2485)
Co-authored-by: kyuds <kyuds@everspin.co.kr>
This commit is contained in:
@@ -286,7 +286,45 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
Raises:
|
||||
NoResultFound: if the object is not found
|
||||
"""
|
||||
logger.debug(f"Reading {cls.__name__} with ID: {identifier} with actor={actor}")
|
||||
# this is ok because read_multiple will check if the
|
||||
identifiers = [] if identifier is None else [identifier]
|
||||
found = cls.read_multiple(db_session, identifiers, actor, access, access_type, **kwargs)
|
||||
if len(found) == 0:
|
||||
# for backwards compatibility.
|
||||
conditions = []
|
||||
if identifier:
|
||||
conditions.append(f"id={identifier}")
|
||||
if actor:
|
||||
conditions.append(f"access level in {access} for {actor}")
|
||||
if hasattr(cls, "is_deleted"):
|
||||
conditions.append("is_deleted=False")
|
||||
raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
|
||||
return found[0]
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def read_multiple(
|
||||
cls,
|
||||
db_session: "Session",
|
||||
identifiers: List[str] = [],
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
**kwargs,
|
||||
) -> List["SqlalchemyBase"]:
|
||||
"""The primary accessor for ORM record(s)
|
||||
Args:
|
||||
db_session: the database session to use when retrieving the record
|
||||
identifiers: a list of identifiers of the records to read, can be the id string or the UUID object for backwards compatibility
|
||||
actor: if specified, results will be scoped only to records the user is able to access
|
||||
access: if actor is specified, records will be filtered to the minimum permission level for the actor
|
||||
kwargs: additional arguments to pass to the read, used for more complex objects
|
||||
Returns:
|
||||
The matching object
|
||||
Raises:
|
||||
NoResultFound: if the object is not found
|
||||
"""
|
||||
logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
|
||||
|
||||
# Start the query
|
||||
query = select(cls)
|
||||
@@ -294,9 +332,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_conditions = []
|
||||
|
||||
# If an identifier is provided, add it to the query conditions
|
||||
if identifier is not None:
|
||||
query = query.where(cls.id == identifier)
|
||||
query_conditions.append(f"id='{identifier}'")
|
||||
if len(identifiers) > 0:
|
||||
query = query.where(cls.id.in_(identifiers))
|
||||
query_conditions.append(f"id='{identifiers}'")
|
||||
|
||||
if kwargs:
|
||||
query = query.filter_by(**kwargs)
|
||||
@@ -309,12 +347,29 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
query_conditions.append("is_deleted=False")
|
||||
if found := db_session.execute(query).scalar():
|
||||
return found
|
||||
|
||||
results = db_session.execute(query).scalars().all()
|
||||
if results: # if empty list a.k.a. no results
|
||||
if len(identifiers) > 0:
|
||||
# find which identifiers were not found
|
||||
# only when identifier length is greater than 0 (so it was used in the actual query)
|
||||
identifier_set = set(identifiers)
|
||||
results_set = set(map(lambda obj: obj.id, results))
|
||||
|
||||
# we log a warning message if any of the queried IDs were not found.
|
||||
# TODO: should we error out instead?
|
||||
if identifier_set != results_set:
|
||||
# Construct a detailed error message based on query conditions
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
logger.warning(
|
||||
f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}"
|
||||
)
|
||||
return results
|
||||
|
||||
# Construct a detailed error message based on query conditions
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
||||
logger.warning(f"{cls.__name__} not found with {conditions_str}")
|
||||
return []
|
||||
|
||||
@handle_db_timeout
|
||||
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
|
||||
@@ -106,12 +106,14 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
# TODO: We can do this much more efficiently by listing, instead of executing individual queries per block_id
|
||||
blocks = []
|
||||
for block_id in block_ids:
|
||||
block = self.get_block_by_id(block_id, actor=actor)
|
||||
blocks.append(block)
|
||||
return blocks
|
||||
"""Retrieve blocks by their names."""
|
||||
with self.session_maker() as session:
|
||||
blocks = list(
|
||||
map(lambda obj: obj.to_pydantic(), BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor))
|
||||
)
|
||||
# backwards compatibility. previous implementation added None for every block not found.
|
||||
blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
|
||||
return blocks
|
||||
|
||||
@enforce_types
|
||||
def add_default_blocks(self, actor: PydanticUser):
|
||||
|
||||
@@ -117,7 +117,7 @@ def test_shared_blocks(client: LettaSDKClient):
|
||||
)
|
||||
assert (
|
||||
"charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower()
|
||||
), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value}"
|
||||
), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label='human').value}"
|
||||
|
||||
# cleanup
|
||||
client.agents.delete(agent_state1.id)
|
||||
|
||||
Reference in New Issue
Block a user