diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 3df9bb5f..c629d283 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -61,8 +61,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text: Optional[str] = None, query_embedding: Optional[List[float]] = None, ascending: bool = True, - tags: Optional[List[str]] = None, - match_all_tags: bool = False, actor: Optional["User"] = None, access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], access_type: AccessType = AccessType.ORGANIZATION, @@ -86,8 +84,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text: Text to search for query_embedding: Vector to search for similar embeddings ascending: Sort direction - tags: List of tags to filter by - match_all_tags: If True, return items matching all tags. If False, match any tag. **kwargs: Additional filters to apply """ if start_date and end_date and start_date > end_date: @@ -123,8 +119,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text=query_text, query_embedding=query_embedding, ascending=ascending, - tags=tags, - match_all_tags=match_all_tags, actor=actor, access=access, access_type=access_type, @@ -162,8 +156,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text: Optional[str] = None, query_embedding: Optional[List[float]] = None, ascending: bool = True, - tags: Optional[List[str]] = None, - match_all_tags: bool = False, actor: Optional["User"] = None, access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], access_type: AccessType = AccessType.ORGANIZATION, @@ -189,8 +181,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text: Text to search for query_embedding: Vector to search for similar embeddings ascending: Sort direction - tags: List of tags to filter by - match_all_tags: If True, return items matching all tags. If False, match any tag. **kwargs: Additional filters to apply """ if start_date and end_date and start_date > end_date: @@ -226,8 +216,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text=query_text, query_embedding=query_embedding, ascending=ascending, - tags=tags, - match_all_tags=match_all_tags, actor=actor, access=access, access_type=access_type, @@ -263,8 +251,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query_text: Optional[str] = None, query_embedding: Optional[List[float]] = None, ascending: bool = True, - tags: Optional[List[str]] = None, - match_all_tags: bool = False, actor: Optional["User"] = None, access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], access_type: AccessType = AccessType.ORGANIZATION, @@ -286,28 +272,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if actor: query = cls.apply_access_predicate(query, actor, access, access_type) - # Handle tag filtering if the model has tags - if tags and hasattr(cls, "tags"): - query = select(cls) - - if match_all_tags: - # Match ALL tags - use subqueries - subquery = ( - select(cls.tags.property.mapper.class_.agent_id) - .where(cls.tags.property.mapper.class_.tag.in_(tags)) - .group_by(cls.tags.property.mapper.class_.agent_id) - .having(func.count() == len(tags)) - ) - query = query.filter(cls.id.in_(subquery)) - else: - # Match ANY tag - use join and filter - query = ( - query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id) - ) # Deduplicate results - - # select distinct primary key - query = query.distinct(cls.id).order_by(cls.id) - if identifier_keys and hasattr(cls, "identities"): query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index fd4058ba..0b20e31b 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -2,6 +2,7 @@ import datetime from typing import List, Literal, Optional from sqlalchemy import and_, asc, desc, func, literal, or_, select +from sqlalchemy.sql.expression import exists from letta import system from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS @@ -504,14 +505,13 @@ def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool): Returns: The modified query with tag filters applied. """ + if tags: - # Build a subquery to select agent IDs that have the specified tags. - subquery = select(AgentsTags.agent_id).where(AgentsTags.tag.in_(tags)).group_by(AgentsTags.agent_id) - # If all tags must match, add a HAVING clause to ensure the count of tags equals the number provided. if match_all_tags: - subquery = subquery.having(func.count(AgentsTags.tag) == literal(len(tags))) - # Filter the main query to include only agents present in the subquery. - query = query.where(AgentModel.id.in_(subquery)) + for tag in tags: + query = query.filter(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag == tag))) + else: + query = query.where(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag.in_(tags)))) return query