perf: agent tags optimization (#2454)
This commit is contained in:
@@ -61,8 +61,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
@@ -86,8 +84,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Text to search for
|
||||
query_embedding: Vector to search for similar embeddings
|
||||
ascending: Sort direction
|
||||
tags: List of tags to filter by
|
||||
match_all_tags: If True, return items matching all tags. If False, match any tag.
|
||||
**kwargs: Additional filters to apply
|
||||
"""
|
||||
if start_date and end_date and start_date > end_date:
|
||||
@@ -123,8 +119,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text=query_text,
|
||||
query_embedding=query_embedding,
|
||||
ascending=ascending,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
actor=actor,
|
||||
access=access,
|
||||
access_type=access_type,
|
||||
@@ -162,8 +156,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
@@ -189,8 +181,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Text to search for
|
||||
query_embedding: Vector to search for similar embeddings
|
||||
ascending: Sort direction
|
||||
tags: List of tags to filter by
|
||||
match_all_tags: If True, return items matching all tags. If False, match any tag.
|
||||
**kwargs: Additional filters to apply
|
||||
"""
|
||||
if start_date and end_date and start_date > end_date:
|
||||
@@ -226,8 +216,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text=query_text,
|
||||
query_embedding=query_embedding,
|
||||
ascending=ascending,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
actor=actor,
|
||||
access=access,
|
||||
access_type=access_type,
|
||||
@@ -263,8 +251,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
@@ -286,28 +272,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if actor:
|
||||
query = cls.apply_access_predicate(query, actor, access, access_type)
|
||||
|
||||
# Handle tag filtering if the model has tags
|
||||
if tags and hasattr(cls, "tags"):
|
||||
query = select(cls)
|
||||
|
||||
if match_all_tags:
|
||||
# Match ALL tags - use subqueries
|
||||
subquery = (
|
||||
select(cls.tags.property.mapper.class_.agent_id)
|
||||
.where(cls.tags.property.mapper.class_.tag.in_(tags))
|
||||
.group_by(cls.tags.property.mapper.class_.agent_id)
|
||||
.having(func.count() == len(tags))
|
||||
)
|
||||
query = query.filter(cls.id.in_(subquery))
|
||||
else:
|
||||
# Match ANY tag - use join and filter
|
||||
query = (
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
|
||||
) # Deduplicate results
|
||||
|
||||
# select distinct primary key
|
||||
query = query.distinct(cls.id).order_by(cls.id)
|
||||
|
||||
if identifier_keys and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from sqlalchemy import and_, asc, desc, func, literal, or_, select
|
||||
from sqlalchemy.sql.expression import exists
|
||||
|
||||
from letta import system
|
||||
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
|
||||
@@ -504,14 +505,13 @@ def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
|
||||
Returns:
|
||||
The modified query with tag filters applied.
|
||||
"""
|
||||
|
||||
if tags:
|
||||
# Build a subquery to select agent IDs that have the specified tags.
|
||||
subquery = select(AgentsTags.agent_id).where(AgentsTags.tag.in_(tags)).group_by(AgentsTags.agent_id)
|
||||
# If all tags must match, add a HAVING clause to ensure the count of tags equals the number provided.
|
||||
if match_all_tags:
|
||||
subquery = subquery.having(func.count(AgentsTags.tag) == literal(len(tags)))
|
||||
# Filter the main query to include only agents present in the subquery.
|
||||
query = query.where(AgentModel.id.in_(subquery))
|
||||
for tag in tags:
|
||||
query = query.filter(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag == tag)))
|
||||
else:
|
||||
query = query.where(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag.in_(tags))))
|
||||
return query
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user