perf: agent tags optimization (#2454)

This commit is contained in:
Andy Li
2025-05-27 15:14:07 -07:00
committed by GitHub
parent 7f8c2f9366
commit 9640dbf09b
2 changed files with 6 additions and 42 deletions

View File

@@ -61,8 +61,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
@@ -86,8 +84,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Text to search for
query_embedding: Vector to search for similar embeddings
ascending: Sort direction
tags: List of tags to filter by
match_all_tags: If True, return items matching all tags. If False, match any tag.
**kwargs: Additional filters to apply
"""
if start_date and end_date and start_date > end_date:
@@ -123,8 +119,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text=query_text,
query_embedding=query_embedding,
ascending=ascending,
tags=tags,
match_all_tags=match_all_tags,
actor=actor,
access=access,
access_type=access_type,
@@ -162,8 +156,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
@@ -189,8 +181,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Text to search for
query_embedding: Vector to search for similar embeddings
ascending: Sort direction
tags: List of tags to filter by
match_all_tags: If True, return items matching all tags. If False, match any tag.
**kwargs: Additional filters to apply
"""
if start_date and end_date and start_date > end_date:
@@ -226,8 +216,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text=query_text,
query_embedding=query_embedding,
ascending=ascending,
tags=tags,
match_all_tags=match_all_tags,
actor=actor,
access=access,
access_type=access_type,
@@ -263,8 +251,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
@@ -286,28 +272,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if actor:
query = cls.apply_access_predicate(query, actor, access, access_type)
# Handle tag filtering if the model has tags
if tags and hasattr(cls, "tags"):
query = select(cls)
if match_all_tags:
# Match ALL tags - use subqueries
subquery = (
select(cls.tags.property.mapper.class_.agent_id)
.where(cls.tags.property.mapper.class_.tag.in_(tags))
.group_by(cls.tags.property.mapper.class_.agent_id)
.having(func.count() == len(tags))
)
query = query.filter(cls.id.in_(subquery))
else:
# Match ANY tag - use join and filter
query = (
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
) # Deduplicate results
# select distinct primary key
query = query.distinct(cls.id).order_by(cls.id)
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

View File

@@ -2,6 +2,7 @@ import datetime
from typing import List, Literal, Optional
from sqlalchemy import and_, asc, desc, func, literal, or_, select
from sqlalchemy.sql.expression import exists
from letta import system
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
@@ -504,14 +505,13 @@ def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
Returns:
The modified query with tag filters applied.
"""
if tags:
# Build a subquery to select agent IDs that have the specified tags.
subquery = select(AgentsTags.agent_id).where(AgentsTags.tag.in_(tags)).group_by(AgentsTags.agent_id)
# If all tags must match, add a HAVING clause to ensure the count of tags equals the number provided.
if match_all_tags:
subquery = subquery.having(func.count(AgentsTags.tag) == literal(len(tags)))
# Filter the main query to include only agents present in the subquery.
query = query.where(AgentModel.id.in_(subquery))
for tag in tags:
query = query.filter(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag == tag)))
else:
query = query.where(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag.in_(tags))))
return query