diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 3b45c6ee..61b3c9d4 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -286,7 +286,45 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): Raises: NoResultFound: if the object is not found """ - logger.debug(f"Reading {cls.__name__} with ID: {identifier} with actor={actor}") + # this is ok because read_multiple will check if the + identifiers = [] if identifier is None else [identifier] + found = cls.read_multiple(db_session, identifiers, actor, access, access_type, **kwargs) + if len(found) == 0: + # for backwards compatibility. + conditions = [] + if identifier: + conditions.append(f"id={identifier}") + if actor: + conditions.append(f"access level in {access} for {actor}") + if hasattr(cls, "is_deleted"): + conditions.append("is_deleted=False") + raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}") + return found[0] + + @classmethod + @handle_db_timeout + def read_multiple( + cls, + db_session: "Session", + identifiers: List[str] = [], + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + **kwargs, + ) -> List["SqlalchemyBase"]: + """The primary accessor for ORM record(s) + Args: + db_session: the database session to use when retrieving the record + identifiers: a list of identifiers of the records to read, can be the id string or the UUID object for backwards compatibility + actor: if specified, results will be scoped only to records the user is able to access + access: if actor is specified, records will be filtered to the minimum permission level for the actor + kwargs: additional arguments to pass to the read, used for more complex objects + Returns: + The matching object + Raises: + NoResultFound: if the object is not found + """ + logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}") # Start the query query = select(cls) @@ -294,9 +332,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_conditions = [] # If an identifier is provided, add it to the query conditions - if identifier is not None: - query = query.where(cls.id == identifier) - query_conditions.append(f"id='{identifier}'") + if len(identifiers) > 0: + query = query.where(cls.id.in_(identifiers)) + query_conditions.append(f"id='{identifiers}'") if kwargs: query = query.filter_by(**kwargs) @@ -309,12 +347,29 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if hasattr(cls, "is_deleted"): query = query.where(cls.is_deleted == False) query_conditions.append("is_deleted=False") - if found := db_session.execute(query).scalar(): - return found + + results = db_session.execute(query).scalars().all() + if results: # if empty list a.k.a. no results + if len(identifiers) > 0: + # find which identifiers were not found + # only when identifier length is greater than 0 (so it was used in the actual query) + identifier_set = set(identifiers) + results_set = set(map(lambda obj: obj.id, results)) + + # we log a warning message if any of the queried IDs were not found. + # TODO: should we error out instead? + if identifier_set != results_set: + # Construct a detailed error message based on query conditions + conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions" + logger.warning( + f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}" + ) + return results # Construct a detailed error message based on query conditions conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions" - raise NoResultFound(f"{cls.__name__} not found with {conditions_str}") + logger.warning(f"{cls.__name__} not found with {conditions_str}") + return [] @handle_db_timeout def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index ff9b8507..cc09c0b1 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -106,12 +106,14 @@ class BlockManager: @enforce_types def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: - # TODO: We can do this much more efficiently by listing, instead of executing individual queries per block_id - blocks = [] - for block_id in block_ids: - block = self.get_block_by_id(block_id, actor=actor) - blocks.append(block) - return blocks + """Retrieve blocks by their names.""" + with self.session_maker() as session: + blocks = list( + map(lambda obj: obj.to_pydantic(), BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)) + ) + # backwards compatibility. previous implementation added None for every block not found. + blocks.extend([None for _ in range(len(block_ids) - len(blocks))]) + return blocks @enforce_types def add_default_blocks(self, actor: PydanticUser): diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index a85a0e09..e21c0492 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -117,7 +117,7 @@ def test_shared_blocks(client: LettaSDKClient): ) assert ( "charles" in client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value.lower() - ), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label="human").value}" + ), f"Shared block update failed {client.agents.blocks.retrieve(agent_id=agent_state2.id, block_label='human').value}" # cleanup client.agents.delete(agent_state1.id)