Files
letta-server/letta/services/helpers/agent_manager_helper.py
Christina Tong c8ae02a1fb feat(core): sort agents by updated_at [LET-7771] (#9730)
feat(core): sort agents by last_updated_at
2026-03-03 18:34:02 -08:00

1320 lines
53 KiB
Python

import uuid
from datetime import datetime
from typing import List, Literal, Optional, Set
from letta.log import get_logger
from letta.schemas.letta_stop_reason import StopReasonType
logger = get_logger(__name__)
import numpy as np
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
from sqlalchemy.orm import noload
from sqlalchemy.sql.expression import exists
from letta import system
from letta.constants import (
BASE_MEMORY_TOOLS,
BASE_MEMORY_TOOLS_V2,
BASE_TOOLS,
DEPRECATED_LETTA_TOOLS,
IN_CONTEXT_MEMORY_KEYWORD,
LOCAL_ONLY_MULTI_AGENT_TOOLS,
MAX_EMBEDDING_DIM,
MULTI_AGENT_TOOLS,
STRUCTURED_OUTPUT_MODELS,
)
from letta.errors import LettaAgentNotFoundError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_local_time
from letta.llm_api.llm_client import LLMClient
from letta.orm.agent import Agent as AgentModel
from letta.orm.agents_tags import AgentsTags
from letta.orm.archives_agents import ArchivesAgents
from letta.orm.errors import NoResultFound
from letta.orm.identity import Identity
from letta.orm.passage import ArchivalPassage, SourcePassage
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.prompts import gpt_system
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType, MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.tool_rule import ToolRule
from letta.schemas.user import User
from letta.settings import DatabaseChoice, settings
from letta.system import get_initial_boot_messages, get_login_event, package_function_response
# Static methods
@trace_method
def _process_relationship(
    session, agent: "AgentModel", relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
):
    """
    Generalized function to handle relationships like tools, sources, and blocks using item IDs.

    Args:
        session: The (synchronous) database session.
        agent: The AgentModel instance.
        relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
        model_class: The ORM class corresponding to the related items.
        item_ids: List of IDs to set or update. Duplicate IDs are tolerated.
        allow_partial: If True, IDs that do not exist are silently skipped.
        replace: If True, replaces the entire relationship; otherwise, only appends items that are
            not already attached.

    Raises:
        NoResultFound: If `allow_partial` is False and some IDs do not exist.
    """
    current_relationship = getattr(agent, relationship_name, [])
    if not item_ids:
        if replace:
            setattr(agent, relationship_name, [])
        return

    # Retrieve models for the provided IDs
    found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()

    # Validate all items are found if allow_partial is False. Compare as sets so that duplicate
    # IDs in `item_ids` do not trigger a spurious "not found" error with an empty missing set.
    if not allow_partial:
        missing = set(item_ids) - {item.id for item in found_items}
        if missing:
            raise NoResultFound(f"Items not found in {relationship_name}: {missing}")

    if replace:
        # Replace the relationship
        setattr(agent, relationship_name, found_items)
    else:
        # Extend the relationship (only add new items)
        current_ids = {item.id for item in current_relationship}
        new_items = [item for item in found_items if item.id not in current_ids]
        current_relationship.extend(new_items)
@trace_method
async def _process_relationship_async(
    session, agent: "AgentModel", relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
):
    """
    Async variant of `_process_relationship`: set or extend an agent relationship from item IDs.

    Args:
        session: The async database session.
        agent: The AgentModel instance.
        relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
        model_class: The ORM class corresponding to the related items.
        item_ids: List of IDs to set or update. Duplicate IDs are tolerated.
        allow_partial: If True, IDs that do not exist are silently skipped.
        replace: If True, replaces the entire relationship; otherwise, only appends items that are
            not already attached.

    Raises:
        NoResultFound: If `allow_partial` is False and some IDs do not exist.
    """
    current_relationship = getattr(agent, relationship_name, [])
    if not item_ids:
        if replace:
            setattr(agent, relationship_name, [])
        return

    # Retrieve models for the provided IDs
    result = await session.execute(select(model_class).where(model_class.id.in_(item_ids)))
    found_items = result.scalars().all()

    # Validate all items are found if allow_partial is False. Compare as sets so that duplicate
    # IDs in `item_ids` do not trigger a spurious "not found" error with an empty missing set.
    if not allow_partial:
        missing = set(item_ids) - {item.id for item in found_items}
        if missing:
            raise NoResultFound(f"Items not found in {relationship_name}: {missing}")

    if replace:
        # Replace the relationship
        setattr(agent, relationship_name, found_items)
    else:
        # Extend the relationship (only add new items)
        current_ids = {item.id for item in current_relationship}
        new_items = [item for item in found_items if item.id not in current_ids]
        current_relationship.extend(new_items)
def _process_tags(agent: "AgentModel", tags: List[str], replace=True):
    """
    Set or extend the tags attached to an agent.

    Args:
        agent: The AgentModel instance.
        tags: Tag strings to set or add; duplicate strings are collapsed.
        replace: If True, the agent's tags become exactly `tags`; otherwise only tags not already
            present on the agent are appended.
    """
    if not tags:
        if replace:
            agent.tags = []
        return

    # dict.fromkeys de-duplicates while keeping the caller's order, so the resulting tag list is
    # deterministic (iterating a set of freshly built ORM objects yields an arbitrary order).
    unique_tags = list(dict.fromkeys(tags))

    if replace:
        agent.tags = [AgentsTags(agent_id=agent.id, tag=tag) for tag in unique_tags]
    else:
        existing_tags = {t.tag for t in agent.tags}
        agent.tags.extend([AgentsTags(agent_id=agent.id, tag=tag) for tag in unique_tags if tag not in existing_tags])
def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool] = None, system: Optional[str] = None) -> str:
    """
    Pick the system prompt for an agent.

    A caller-supplied `system` string (anything other than None) is returned unchanged.
    Otherwise the built-in prompt template registered for `agent_type` is loaded through
    `gpt_system.get_system_text`.

    Args:
        agent_type: The type of agent (e.g., memgpt_agent, sleeptime_agent, react_agent).
        enable_sleeptime: Accepted for compatibility; it does not currently change the chosen
            prompt (MemGPT agents use the same chat prompt with or without sleeptime).
        system: Optional custom system message that bypasses template selection.

    Returns:
        The system message string.

    Raises:
        ValueError: If no template is registered for `agent_type`.
    """
    if system is not None:
        return system

    # TODO: don't hardcode
    # Checked in order using `==`. Both MemGPT variants use the v2 chat prompt regardless of
    # sleeptime, because that prompt says the agent "may" have the sleeptime tools.
    # The sleeptime template is v2, which relies on block-description injections instead of
    # naming specific blocks.
    template_table = (
        (AgentType.voice_convo_agent, "voice_chat"),
        (AgentType.voice_sleeptime_agent, "voice_sleeptime"),
        (AgentType.memgpt_agent, "memgpt_v2_chat"),
        (AgentType.memgpt_v2_agent, "memgpt_v2_chat"),
        (AgentType.sleeptime_agent, "sleeptime_v2"),
        (AgentType.react_agent, "react"),
        (AgentType.letta_v1_agent, "letta_v1"),
        (AgentType.workflow_agent, "workflow"),
    )
    for known_type, template_name in template_table:
        if agent_type == known_type:
            return gpt_system.get_system_text(template_name)

    raise ValueError(f"Invalid agent type: {agent_type}")
class PreserveMapping(dict):
    """Mapping for `str.format_map` that leaves unknown placeholders untouched.

    A lookup of a missing key yields the placeholder text itself (`"{key}"`), so variables that
    were not supplied survive formatting verbatim.
    """

    def __missing__(self, key):
        return "".join(("{", key, "}"))


def safe_format(template: str, variables: dict) -> str:
    """
    Substitute `variables` into `template` while keeping `{}` and unknown `{name}` fields as-is.

    `str.format_map` would read a bare `{}` as a positional field, so every `{}` is first
    escaped to `{{}}`; formatting then turns it back into a literal `{}`.
    """
    protected_template = template.replace("{}", "{{}}")
    lookup = PreserveMapping(variables)
    return protected_template.format_map(lookup)
@trace_method
def compile_system_message(
    system_prompt: str,
    in_context_memory: Memory,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    timezone: str,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int | None = 0,
    tool_rules_solver: Optional[ToolRulesSolver] = None,
    sources: Optional[List] = None,
    max_files_open: Optional[int] = None,
    llm_config: Optional[object] = None,
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API.

    The base system message may be templated, in which case we need to render the variables.
    The following are reserved variables:
    - CORE_MEMORY: the in-context memory of the LLM (placeholder `{IN_CONTEXT_MEMORY_KEYWORD}`)

    The memory string substituted for the placeholder is the compiled in-context memory
    (including tool-rule prompts and sources), followed by a blank line and the memory
    metadata block.

    Args:
        system_prompt: Base system prompt, possibly containing the memory placeholder.
        in_context_memory: Memory whose `compile(...)` output is injected.
        in_context_memory_last_edit: Timestamp reported in the memory metadata block.
        timezone: Timezone used when rendering the metadata block.
        user_defined_variables: Not supported yet; must be None.
        append_icm_if_missing: If True and the placeholder is absent, append it to the prompt.
        template_format: Only "f-string" is supported.
        previous_message_count: Reported in the memory metadata block.
        archival_memory_size: Reported in the memory metadata block (None is treated as 0).
        tool_rules_solver: If given, its compiled tool-rule prompts are passed to memory compile.
        sources: Passed through to `in_context_memory.compile`.
        max_files_open: Passed through to `in_context_memory.compile`.
        llm_config: Passed through to `in_context_memory.compile`.

    Returns:
        The rendered system prompt.

    Raises:
        NotImplementedError: If `user_defined_variables` is not None, or `template_format` is
            not "f-string".
        ValueError: If the prompt cannot be rendered.
    """
    # Add tool rule constraints if available
    tool_constraint_block = None
    if tool_rules_solver is not None:
        tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()

    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject
        raise NotImplementedError
    variables = {}

    # Guard the protected memory variable against user-defined overrides
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")

    # TODO should this all put into the memory.__repr__ function?
    memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
        memory_edit_timestamp=in_context_memory_last_edit,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size or 0,
        timezone=timezone,
    )
    memory_with_sources = in_context_memory.compile(
        tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open, llm_config=llm_config
    )
    full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
    # Add to the variables list to inject
    variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

    if template_format != "f-string":
        # TODO support for mustache
        raise NotImplementedError(template_format)

    memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
    # Special case: the system prompt has no memory placeholder. Append one so memory is
    # still injected.
    if append_icm_if_missing and memory_variable_string not in system_prompt:
        system_prompt += "\n\n" + memory_variable_string

    # Render the variables using the built-in templater
    try:
        if user_defined_variables:
            formatted_prompt = safe_format(system_prompt, variables)
        else:
            formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") from e

    return formatted_prompt
@trace_method
def initialize_message_sequence(
    agent_state: AgentState,
    memory_edit_timestamp: Optional[datetime] = None,
    include_initial_boot_message: bool = True,
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> List[dict]:
    """
    Build the opening message list (role/content dicts) for an agent's conversation.

    The list always starts with the compiled system message and contains a login event as a
    user message. When `include_initial_boot_message` is True, boot messages from
    `get_initial_boot_messages` are added. Sleeptime agents get none. Models whose name
    contains "gpt-3.5" use the gpt35 startup variant.

    Boot messages go before the login event. For the "lmstudio_openai" provider they go
    after it, and the tool call ID is truncated to 9 characters.
    """
    edit_timestamp = memory_edit_timestamp if memory_edit_timestamp is not None else get_local_time()

    system_content = compile_system_message(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=edit_timestamp,
        timezone=agent_state.timezone,
        user_defined_variables=None,
        append_icm_if_missing=True,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
        sources=agent_state.sources,
        max_files_open=agent_state.max_files_open,
    )
    login_event = get_login_event(agent_state.timezone)  # tells Letta the user just logged in

    system_entry = {"role": "system", "content": system_content}
    login_entry = {"role": "user", "content": login_event}

    if not include_initial_boot_message:
        return [system_entry, login_entry]

    model_config = agent_state.llm_config
    is_lmstudio = model_config.provider_name == "lmstudio_openai"
    full_call_id = str(uuid.uuid4())
    # Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters
    call_id = full_call_id[:9] if is_lmstudio else full_call_id

    if agent_state.agent_type == AgentType.sleeptime_agent:
        boot_messages = []
    else:
        is_gpt35 = model_config.model is not None and "gpt-3.5" in model_config.model
        startup_kind = "startup_with_send_message_gpt35" if is_gpt35 else "startup_with_send_message"
        boot_messages = get_initial_boot_messages(startup_kind, agent_state.timezone, call_id)

    # Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls
    if is_lmstudio:
        return [system_entry, login_entry, *boot_messages]
    return [system_entry, *boot_messages, login_entry]
@trace_method
async def initialize_message_sequence_async(
    agent_state: AgentState,
    memory_edit_timestamp: Optional[datetime] = None,
    include_initial_boot_message: bool = True,
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> List[dict]:
    """
    Async builder for an agent's opening message list (role/content dicts).

    letta_v1 agents receive only the compiled system message. Other agents also receive a
    login event as a user message. When `include_initial_boot_message` is True they get boot
    messages too: none for sleeptime agents, and the gpt35 variant when the model name
    contains "gpt-3.5". Boot messages precede the login event, except for the
    "lmstudio_openai" provider, where they follow it and the tool call ID is truncated to 9
    characters.
    """
    if memory_edit_timestamp is None:
        memory_edit_timestamp = get_local_time()

    full_system_message = await PromptGenerator.compile_system_message_async(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        timezone=agent_state.timezone,
        user_defined_variables=None,
        append_icm_if_missing=True,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
        sources=agent_state.sources,
        max_files_open=agent_state.max_files_open,
    )

    # letta_v1 agents start with just the system message. Return before building a login
    # event that would be discarded.
    if agent_state.agent_type == AgentType.letta_v1_agent:
        return [{"role": "system", "content": full_system_message}]

    first_user_message = get_login_event(agent_state.timezone)  # event letting Letta know the user just logged in

    if not include_initial_boot_message:
        return [
            {"role": "system", "content": full_system_message},
            {"role": "user", "content": first_user_message},
        ]

    llm_config = agent_state.llm_config
    uuid_str = str(uuid.uuid4())
    # Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters
    tool_call_id = uuid_str[:9] if llm_config.provider_name == "lmstudio_openai" else uuid_str

    # letta_v1 agents were already handled above, so only sleeptime agents skip boot messages here.
    if agent_state.agent_type == AgentType.sleeptime_agent:
        initial_boot_messages = []
    elif llm_config.model is not None and "gpt-3.5" in llm_config.model:
        initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone, tool_call_id)
    else:
        initial_boot_messages = get_initial_boot_messages("startup_with_send_message", agent_state.timezone, tool_call_id)

    # Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls
    if llm_config.provider_name == "lmstudio_openai":
        return [
            {"role": "system", "content": full_system_message},
            {"role": "user", "content": first_user_message},
            *initial_boot_messages,
        ]
    return [
        {"role": "system", "content": full_system_message},
        *initial_boot_messages,
        {"role": "user", "content": first_user_message},
    ]
def package_initial_message_sequence(
    agent_id: str, initial_message_sequence: List[MessageCreate], model: str, timezone: str, actor: User
) -> List[Message]:
    """
    Convert caller-supplied `MessageCreate` objects into `Message` objects for a new agent.

    - user / system messages are wrapped with `system.package_user_message` /
      `system.package_system_message` and stored as a single text part.
    - assistant messages become a `DEFAULT_MESSAGE_TOOL` tool call carrying the text as its
      `message` argument, followed by a matching successful tool-return message.

    `actor` is accepted for interface compatibility but is not used here.

    Raises:
        ValueError: For any other message role.
    """
    packaged: List[Message] = []
    for msg in initial_message_sequence:
        if msg.role == MessageRole.user:
            wrapped_text = system.package_user_message(
                user_message=msg.content,
                timezone=timezone,
            )
        elif msg.role == MessageRole.system:
            wrapped_text = system.package_system_message(
                system_message=msg.content,
                timezone=timezone,
            )
        elif msg.role == MessageRole.assistant:
            # Replay assistant text as a send-message tool call plus its successful tool return.
            import json
            import uuid

            from openai.types.chat.chat_completion_message_tool_call import (
                ChatCompletionMessageToolCall as OpenAIToolCall,
                Function as OpenAIFunction,
            )

            from letta.constants import DEFAULT_MESSAGE_TOOL

            call_id = str(uuid.uuid4())
            call_message = Message(
                role=MessageRole.assistant,
                content=None,
                name=msg.name,
                agent_id=agent_id,
                model=model,
                tool_calls=[
                    OpenAIToolCall(
                        id=call_id,
                        type="function",
                        function=OpenAIFunction(name=DEFAULT_MESSAGE_TOOL, arguments=json.dumps({"message": msg.content})),
                    )
                ],
            )
            packaged.append(call_message)

            response_text = package_function_response(True, "None", timezone)
            return_message = Message(
                role=MessageRole.tool,
                content=[TextContent(text=response_text)],
                name=msg.name,
                agent_id=agent_id,
                model=model,
                tool_call_id=call_id,
                tool_returns=[
                    ToolReturn(
                        tool_call_id=call_id,
                        status="success",
                        func_response=response_text,
                    )
                ],
            )
            packaged.append(return_message)
            continue
        else:
            # TODO: add tool call and tool return
            raise ValueError(f"Invalid message role: {msg.role}")

        # Shared tail for user and system messages.
        packaged.append(
            Message(
                role=msg.role,
                content=[TextContent(text=wrapped_text)],
                name=msg.name,
                agent_id=agent_id,
                model=model,
            )
        )
    return packaged
def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool:
    """
    Report whether `model` supports structured output.

    Models listed in `STRUCTURED_OUTPUT_MODELS` return True. Any other model returns False,
    but only after checking that the tool rules define at most one initial tool.

    Raises:
        ValueError: If a non-structured model is configured with more than one initial tool rule.
    """
    if model in STRUCTURED_OUTPUT_MODELS:
        return True

    initial_rules = ToolRulesSolver(tool_rules=tool_rules).init_tool_rules
    if len(initial_rules) > 1:
        raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
    return False
def _cursor_filter(
    sort_col, id_col, ref_sort_col, ref_id, forward: bool, nulls_last: bool = False, nulls_greatest: bool = False
):
    """
    Return a SQLAlchemy filter expression for keyset (cursor) pagination on `(sort_col, id_col)`.

    `forward=True` keeps rows strictly after the reference row when values are compared
    ascending: a larger `sort_col`, or an equal one with a larger `id_col`. `forward=False`
    keeps rows strictly before it.

    Args:
        sort_col: Column the results are ordered by.
        id_col: Tie-breaker column.
        ref_sort_col: `sort_col` value of the cursor row (may be None when `nulls_last`).
        ref_id: `id_col` value of the cursor row.
        forward: Direction, see above.
        nulls_last: If True, `sort_col` may be NULL and NULLs are positioned explicitly.
        nulls_greatest: Only used when `nulls_last` is True.
            - True: NULLs rank above every non-NULL value, matching `ORDER BY col ASC NULLS LAST`.
            - False (the historical default): NULLs rank below every non-NULL value, matching
              `ORDER BY col DESC NULLS LAST`.
    """

    def _past_ref_value():
        # Strictly past the reference among non-NULL values, tie-broken by id.
        if forward:
            return or_(sort_col > ref_sort_col, and_(sort_col == ref_sort_col, id_col > ref_id))
        return or_(sort_col < ref_sort_col, and_(sort_col == ref_sort_col, id_col < ref_id))

    def _past_ref_among_nulls():
        # Inside the NULL group only the id orders rows.
        return and_(sort_col.is_(None), id_col > ref_id if forward else id_col < ref_id)

    if not nulls_last:
        # Simple case: no special NULL handling needed
        return _past_ref_value()

    # True when the requested direction moves toward the end of the order where NULLs sit.
    toward_nulls = forward == nulls_greatest

    if ref_sort_col is None:
        if toward_nulls:
            # Deeper into the NULL group: only NULLs past the reference id.
            return _past_ref_among_nulls()
        # Leaving the NULL group: remaining NULLs plus every non-NULL row.
        return or_(_past_ref_among_nulls(), sort_col.isnot(None))

    if toward_nulls:
        # From a non-NULL value toward the NULL end: further non-NULL values and all NULLs.
        return or_(sort_col.is_(None), _past_ref_value())
    # From a non-NULL value away from the NULL end: NULLs can never qualify.
    return and_(sort_col.isnot(None), _past_ref_value())


def _resolve_agent_sort_column(sort_by: str):
    """Return `(column, nulls_last)` for an agent `sort_by` key; unknown keys fall back to `created_at`."""
    if sort_by == "last_run_completion":
        return AgentModel.last_run_completion, True  # TODO: handle this as a query param eventually
    if sort_by == "updated_at":
        return AgentModel.updated_at, False
    return AgentModel.created_at, False


def _order_agents_query(query, sort_column, sort_nulls_last: bool, ascending: bool):
    """Order the agent query by `sort_column` then `AgentModel.id` (with NULLS LAST when requested)."""
    order_fn = asc if ascending else desc
    primary = nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column)
    return query.order_by(primary, order_fn(AgentModel.id))


def _apply_pagination(
    query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
) -> any:
    """
    Apply cursor pagination and ordering to an agent query (synchronous session).

    `after` / `before` are agent IDs. Rows strictly after / before that agent in the requested
    order are kept, and a cursor ID that matches no agent is ignored.
    """
    sort_column, sort_nulls_last = _resolve_agent_sort_column(sort_by)
    # Under NULLS LAST, NULLs follow the largest value when ascending and the smallest when descending.
    nulls_greatest = ascending

    if after:
        result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after)).first()
        if result:
            after_sort_value, after_id = result
            query = query.where(
                _cursor_filter(
                    sort_column,
                    AgentModel.id,
                    after_sort_value,
                    after_id,
                    forward=ascending,
                    nulls_last=sort_nulls_last,
                    nulls_greatest=nulls_greatest,
                )
            )

    if before:
        result = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before)).first()
        if result:
            before_sort_value, before_id = result
            query = query.where(
                _cursor_filter(
                    sort_column,
                    AgentModel.id,
                    before_sort_value,
                    before_id,
                    forward=not ascending,
                    nulls_last=sort_nulls_last,
                    nulls_greatest=nulls_greatest,
                )
            )

    return _order_agents_query(query, sort_column, sort_nulls_last, ascending)


async def _apply_pagination_async(
    query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
) -> any:
    """
    Apply cursor pagination and ordering to an agent query (async session).

    Same semantics as `_apply_pagination`. On SQLite, datetime cursor values are first
    reduced to second precision (`"%Y-%m-%d %H:%M:%S"` strings).
    """
    sort_column, sort_nulls_last = _resolve_agent_sort_column(sort_by)
    # Under NULLS LAST, NULLs follow the largest value when ascending and the smallest when descending.
    nulls_greatest = ascending

    if after:
        result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after))).first()
        if result:
            after_sort_value, after_id = result
            # SQLite does not support as granular timestamping, so we need to round the timestamp
            if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
                after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
            query = query.where(
                _cursor_filter(
                    sort_column,
                    AgentModel.id,
                    after_sort_value,
                    after_id,
                    forward=ascending,
                    nulls_last=sort_nulls_last,
                    nulls_greatest=nulls_greatest,
                )
            )

    if before:
        result = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before))).first()
        if result:
            before_sort_value, before_id = result
            # SQLite does not support as granular timestamping, so we need to round the timestamp
            if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
                before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
            query = query.where(
                _cursor_filter(
                    sort_column,
                    AgentModel.id,
                    before_sort_value,
                    before_id,
                    forward=not ascending,
                    nulls_last=sort_nulls_last,
                    nulls_greatest=nulls_greatest,
                )
            )

    return _order_agents_query(query, sort_column, sort_nulls_last, ascending)
def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
    """
    Restrict an agent query to agents carrying the given tags.

    With `match_all_tags`, every tag must be present on the agent, enforced by one correlated
    EXISTS clause per tag. Otherwise a single EXISTS clause accepts agents that have at least
    one of the tags. Empty or None `tags` leaves the query unchanged.

    Args:
        query: The SQLAlchemy query object to be modified.
        tags: Tags to filter agents by.
        match_all_tags: If True, require all tags; otherwise any tag suffices.

    Returns:
        The (possibly) filtered query.
    """
    if not tags:
        return query

    if not match_all_tags:
        return query.where(exists().where(and_(AgentsTags.agent_id == AgentModel.id, AgentsTags.tag.in_(tags))))

    for required_tag in tags:
        query = query.filter(exists().where(and_(AgentsTags.agent_id == AgentModel.id, AgentsTags.tag == required_tag)))
    return query
def _apply_identity_filters(query, identity_id: Optional[str], identifier_keys: Optional[List[str]]):
    """
    Apply identity-related filters to the agent query.

    Each filter is a correlated EXISTS on the `AgentModel.identities` relationship (via
    `.any()`) rather than a JOIN. This avoids two problems:
    - joining the same relationship twice when both filters are given;
    - returning one agent row per matching identity when several of an agent's identities
      match `identifier_keys`.

    When both filters are given, an agent is kept if it is linked to the identity
    `identity_id` and is also linked to some identity (possibly a different one) whose key is
    in `identifier_keys`.

    Args:
        query: The SQLAlchemy query object to be modified.
        identity_id (Optional[str]): The identity ID to filter by.
        identifier_keys (Optional[List[str]]): A list of identifier keys to filter agents.

    Returns:
        The modified query with identity filters applied.
    """
    if identity_id:
        query = query.where(AgentModel.identities.any(Identity.id == identity_id))

    if identifier_keys:
        query = query.where(AgentModel.identities.any(Identity.identifier_key.in_(identifier_keys)))

    return query
def _apply_filters(
    query,
    name: Optional[str],
    query_text: Optional[str],
    project_id: Optional[str],
    template_id: Optional[str],
    base_template_id: Optional[str],
    last_stop_reason: Optional[StopReasonType] = None,
):
    """
    Apply basic filtering criteria to the agent query.

    This helper adds WHERE clauses for each parameter that is provided:
    - exact name;
    - case-insensitive partial name match on `query_text` (ILIKE on PostgreSQL, LOWER + LIKE
      elsewhere). LIKE wildcards (`%`, `_`) and backslashes in `query_text` are escaped, so
      the text is matched literally;
    - project ID, template ID and base template ID;
    - last stop reason.

    Args:
        query: The SQLAlchemy query object to be modified.
        name (Optional[str]): Exact name to filter by.
        query_text (Optional[str]): Partial text to search in the agent's name (case-insensitive).
        project_id (Optional[str]): Filter for agents belonging to a specific project.
        template_id (Optional[str]): Filter for agents using a specific template.
        base_template_id (Optional[str]): Filter for agents using a specific base template.
        last_stop_reason (Optional[StopReasonType]): Filter for agents by their last stop reason (e.g., 'requires_approval', 'error').

    Returns:
        The modified query with the applied filters.
    """
    # Filter by exact agent name if provided.
    if name:
        query = query.where(AgentModel.name == name)

    # Apply a case-insensitive partial match for the agent's name.
    if query_text:
        # Escape LIKE metacharacters so text such as "50%" or "my_agent" is not treated as a pattern.
        escaped_text = query_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped_text}%"
        if settings.database_engine is DatabaseChoice.POSTGRES:
            # PostgreSQL: Use ILIKE for case-insensitive search
            query = query.where(AgentModel.name.ilike(pattern, escape="\\"))
        else:
            # SQLite: Use LIKE with LOWER for case-insensitive search
            query = query.where(func.lower(AgentModel.name).like(func.lower(pattern), escape="\\"))

    # Filter agents by project ID.
    if project_id:
        query = query.where(AgentModel.project_id == project_id)

    # Filter agents by template ID.
    if template_id:
        query = query.where(AgentModel.template_id == template_id)

    # Filter agents by base template ID.
    if base_template_id:
        query = query.where(AgentModel.base_template_id == base_template_id)

    # Filter agents by last stop reason.
    if last_stop_reason:
        query = query.where(AgentModel.last_stop_reason == last_stop_reason)

    return query
def _apply_relationship_filters(
    query,
    include_relationships: Optional[List[str]] = None,
    include: Optional[List[str]] = None,
):
    """
    Disable loading (`noload`) of agent relationships that the caller did not ask for.

    If `include_relationships` (legacy) is None and `include` is empty, the query is returned
    unchanged. Otherwise the requested names are mapped to relationship attributes via
    `get_column_names_from_includes_params`, and every optional relationship outside that
    set gets a `noload` option.
    """
    # legacy include_relationships: nothing requested through either parameter
    if not include and include_relationships is None:
        return query

    requested = get_column_names_from_includes_params(include_relationships, include)
    optional_relationships = (
        "core_memory",
        "file_agents",
        "identities",
        "tool_exec_environment_variables",
        "tools",
        "sources",
        "tags",
        "multi_agent_group",
    )
    for relationship_attr in optional_relationships:
        if relationship_attr in requested:
            continue
        query = query.options(noload(getattr(AgentModel, relationship_attr)))
    return query
def get_column_names_from_includes_params(
    include_relationships: Optional[List[str]] = None, includes: Optional[List[str]] = None
) -> Set[str]:
    """
    Translate include names into AgentModel relationship attribute names.

    `includes` (e.g. "agent.tools") takes precedence: when it is non-empty,
    `include_relationships` (legacy names such as "tools") is ignored. Unknown names are
    skipped. If neither argument is provided, an empty set is returned.

    Args:
        include_relationships: Legacy include names.
        includes: Current-style include names.

    Returns:
        The set of relationship attribute names to load.
    """
    include_mapping = {
        "agent.blocks": ["core_memory", "file_agents", "tags"],
        "agent.identities": ["identities"],
        "agent.managed_group": ["multi_agent_group"],
        "agent.secrets": ["tool_exec_environment_variables"],
        "agent.sources": ["sources"],
        "agent.tags": ["tags"],
        "agent.tools": ["tools"],
        # legacy
        "memory": ["core_memory", "file_agents", "tags"],
        "identity_ids": ["identities"],
        "multi_agent_group": ["multi_agent_group"],
        "tool_exec_environment_variables": ["tool_exec_environment_variables"],
        "secrets": ["tool_exec_environment_variables"],
        "sources": ["sources"],
        "tags": ["tags"],
        "tools": ["tools"],
    }
    # Fall back to an empty list so calling with neither argument does not iterate over None.
    requested = includes if includes else (include_relationships or [])
    column_names: Set[str] = set()
    for include_name in requested:
        column_names.update(include_mapping.get(include_name, []))
    return column_names
async def build_passage_query(
actor: User,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
before: Optional[str] = None,
after: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> Select:
"""Helper function to build the base passage query with all filters applied.
Supports both before and after pagination across merged source and agent passages.
Returns the query before any limit or count operations are applied.
"""
embedded_text = None
if embed_query:
assert embedding_config is not None, "embedding_config must be specified for vector search"
assert query_text is not None, "query_text must be specified for vector search"
# Use the new LLMClient for embeddings
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
embedded_text = np.array(embeddings[0])
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
# Start with base query for source passages
source_passages = None
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(
SourcePassage.file_name,
SourcePassage.id,
SourcePassage.text,
SourcePassage.embedding_config,
SourcePassage.metadata_,
SourcePassage.embedding,
SourcePassage.created_at,
SourcePassage.updated_at,
SourcePassage.is_deleted,
SourcePassage._created_by_id,
SourcePassage._last_updated_by_id,
SourcePassage.organization_id,
SourcePassage.file_id,
SourcePassage.source_id,
literal(None).label("archive_id"),
)
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = select(
SourcePassage.file_name,
SourcePassage.id,
SourcePassage.text,
SourcePassage.embedding_config,
SourcePassage.metadata_,
SourcePassage.embedding,
SourcePassage.created_at,
SourcePassage.updated_at,
SourcePassage.is_deleted,
SourcePassage._created_by_id,
SourcePassage._last_updated_by_id,
SourcePassage.organization_id,
SourcePassage.file_id,
SourcePassage.source_id,
literal(None).label("archive_id"),
).where(SourcePassage.organization_id == actor.organization_id)
if source_id:
source_passages = source_passages.where(SourcePassage.source_id == source_id)
if file_id:
source_passages = source_passages.where(SourcePassage.file_id == file_id)
# Add agent passages query
agent_passages = None
if agent_id is not None:
agent_passages = (
select(
literal(None).label("file_name"),
ArchivalPassage.id,
ArchivalPassage.text,
ArchivalPassage.embedding_config,
ArchivalPassage.metadata_,
ArchivalPassage.embedding,
ArchivalPassage.created_at,
ArchivalPassage.updated_at,
ArchivalPassage.is_deleted,
ArchivalPassage._created_by_id,
ArchivalPassage._last_updated_by_id,
ArchivalPassage.organization_id,
literal(None).label("file_id"),
literal(None).label("source_id"),
ArchivalPassage.archive_id,
)
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
.where(ArchivesAgents.agent_id == agent_id)
.where(ArchivalPassage.organization_id == actor.organization_id)
)
# Combine queries
if source_passages is not None and agent_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
elif agent_passages is not None:
combined_query = agent_passages.cte("combined_passages")
elif source_passages is not None:
combined_query = source_passages.cte("combined_passages")
else:
raise ValueError("No passages found")
# Build main query from combined CTE
main_query = select(combined_query)
# Apply filters
if start_date:
main_query = main_query.where(combined_query.c.created_at >= start_date)
if end_date:
main_query = main_query.where(combined_query.c.created_at <= end_date)
if source_id:
main_query = main_query.where(combined_query.c.source_id == source_id)
if file_id:
main_query = main_query.where(combined_query.c.file_id == file_id)
# Vector search
if embedded_text:
if settings.database_engine is DatabaseChoice.POSTGRES:
# PostgreSQL with pgvector
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
else:
# SQLite with custom vector type
from letta.orm.sqlite_functions import adapt_array
query_embedding_binary = adapt_array(embedded_text)
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc() if ascending else combined_query.c.created_at.desc(),
combined_query.c.id.asc(),
)
else:
if query_text:
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
# Handle pagination
if before or after:
# Create reference CTEs
if before:
before_ref = select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == before).cte("before_ref")
if after:
after_ref = select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == after).cte("after_ref")
if before and after:
# Window-based query (get records between before and after)
main_query = main_query.where(
or_(
combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(),
and_(
combined_query.c.created_at == select(before_ref.c.created_at).scalar_subquery(),
combined_query.c.id < select(before_ref.c.id).scalar_subquery(),
),
)
)
main_query = main_query.where(
or_(
combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(),
and_(
combined_query.c.created_at == select(after_ref.c.created_at).scalar_subquery(),
combined_query.c.id > select(after_ref.c.id).scalar_subquery(),
),
)
)
else:
# Pure pagination (only before or only after)
if before:
main_query = main_query.where(
or_(
combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(),
and_(
combined_query.c.created_at == select(before_ref.c.created_at).scalar_subquery(),
combined_query.c.id < select(before_ref.c.id).scalar_subquery(),
),
)
)
if after:
main_query = main_query.where(
or_(
combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(),
and_(
combined_query.c.created_at == select(after_ref.c.created_at).scalar_subquery(),
combined_query.c.id > select(after_ref.c.id).scalar_subquery(),
),
)
)
# Add ordering if not already ordered by similarity
if not embed_query:
if ascending:
main_query = main_query.order_by(
combined_query.c.created_at.asc(),
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
combined_query.c.created_at.desc(),
combined_query.c.id.asc(),
)
return main_query
async def build_source_passage_query(
    actor: User,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
) -> Select:
    """Return an unexecuted SELECT over ``SourcePassage`` rows in the actor's organization.

    Args:
        actor: User whose ``organization_id`` scopes every result.
        agent_id: If given, keep only passages whose source is attached to this agent
            (joined through ``SourcesAgents``).
        file_id: Keep only passages from this file.
        query_text: Case-insensitive substring filter when no vector ranking is applied;
            the text that gets embedded when ``embed_query`` is set.
        start_date: Inclusive lower bound on ``created_at``.
        end_date: Inclusive upper bound on ``created_at``.
        before: Passage id; keep rows strictly earlier on ``(created_at, id)``.
        after: Passage id; keep rows strictly later on ``(created_at, id)``.
        source_id: Keep only passages from this source.
        embed_query: Rank by cosine distance to the embedded ``query_text`` instead of
            filtering by substring.
        ascending: Direction of the ``created_at`` ordering.
        embedding_config: Embedding settings used to embed ``query_text``.

    Returns:
        A ``Select`` ordered by cosine distance for vector search, otherwise by
        ``created_at`` (per ``ascending``) and then ``id`` ascending.

    Raises:
        AssertionError: if ``embed_query`` is set without ``embedding_config`` or ``query_text``.
    """
    query_vector = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        client = LLMClient.create(
            provider_type=embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        vectors = await client.request_embeddings([query_text], embedding_config)
        raw_vector = np.array(vectors[0])
        # Right-pad with zeros up to MAX_EMBEDDING_DIM (presumably the width of the stored
        # embedding column — TODO confirm against the ORM definition).
        query_vector = np.pad(raw_vector, (0, MAX_EMBEDDING_DIM - raw_vector.shape[0]), mode="constant").tolist()

    # noload keeps the organization relationship from lazy-loading, which can block the event loop.
    stmt = select(SourcePassage).options(noload(SourcePassage.organization)).where(SourcePassage.organization_id == actor.organization_id)
    if agent_id is not None:
        # Only passages whose source has been attached to this agent.
        stmt = stmt.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id).where(SourcesAgents.agent_id == agent_id)

    def _cursor_clause(cursor_id: str, earlier: bool):
        """Keyset condition for rows strictly before/after ``cursor_id`` on (created_at, id)."""
        ref = select(SourcePassage.created_at, SourcePassage.id).where(SourcePassage.id == cursor_id).subquery()
        if earlier:
            return or_(
                SourcePassage.created_at < ref.c.created_at,
                and_(SourcePassage.created_at == ref.c.created_at, SourcePassage.id < ref.c.id),
            )
        return or_(
            SourcePassage.created_at > ref.c.created_at,
            and_(SourcePassage.created_at == ref.c.created_at, SourcePassage.id > ref.c.id),
        )

    clauses = []
    if source_id:
        clauses.append(SourcePassage.source_id == source_id)
    if file_id:
        clauses.append(SourcePassage.file_id == file_id)
    if start_date:
        clauses.append(SourcePassage.created_at >= start_date)
    if end_date:
        clauses.append(SourcePassage.created_at <= end_date)
    # Substring matching only applies when results are not ranked by vector distance.
    if not query_vector and query_text:
        clauses.append(func.lower(SourcePassage.text).contains(func.lower(query_text)))
    if before:
        clauses.append(_cursor_clause(before, earlier=True))
    if after:
        clauses.append(_cursor_clause(after, earlier=False))
    if clauses:
        stmt = stmt.where(*clauses)

    if query_vector:
        if settings.database_engine is DatabaseChoice.POSTGRES:
            # pgvector exposes cosine_distance directly on the column.
            stmt = stmt.order_by(SourcePassage.embedding.cosine_distance(query_vector).asc())
        else:
            # SQLite path: a custom cosine_distance SQL function over the binary-encoded vector.
            from letta.orm.sqlite_functions import adapt_array

            stmt = stmt.order_by(
                func.cosine_distance(SourcePassage.embedding, adapt_array(query_vector)).asc(),
                SourcePassage.created_at.asc() if ascending else SourcePassage.created_at.desc(),
                SourcePassage.id.asc(),
            )
    if not embed_query:
        stmt = stmt.order_by(
            SourcePassage.created_at.asc() if ascending else SourcePassage.created_at.desc(),
            SourcePassage.id.asc(),
        )
    return stmt
async def build_agent_passage_query(
    actor: User,
    agent_id: Optional[str] = None,
    archive_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
) -> Select:
    """Build query for agent/archive passages with all filters applied.

    Can provide agent_id, archive_id, both, or neither (org-wide search).
    If both are provided, agent_id takes precedence.

    Args:
        actor: User whose ``organization_id`` scopes every result.
        agent_id: Keep only passages in archives attached to this agent (via ``ArchivesAgents``).
        archive_id: Keep only passages in this archive; ignored when ``agent_id`` is given.
        query_text: Case-insensitive substring filter when no vector ranking is applied;
            also the text that is embedded for vector search.
        start_date: Inclusive lower bound on ``created_at``.
        end_date: Inclusive upper bound on ``created_at``.
        before: Passage id; keep rows strictly earlier on ``(created_at, id)``.
        after: Passage id; keep rows strictly later on ``(created_at, id)``.
        embed_query: Request vector search. Only honored when ``embedding_config`` is also
            given; otherwise the query falls back to substring search on ``query_text``.
        ascending: Direction of the ``created_at`` ordering.
        embedding_config: Embedding settings used to embed ``query_text``.

    Returns:
        An unexecuted ``Select`` over ``ArchivalPassage``. Vector searches are ordered by
        cosine distance; every other query (including the embed_query fallback) is ordered
        by ``created_at`` (per ``ascending``) and then ``id`` ascending.

    Raises:
        AssertionError: if vector search is performed and ``query_text`` is None.
    """
    # Handle embedding for vector search.
    # If embed_query is True but no embedding_config, fall through to text search.
    embedded_text = None
    if embed_query and embedding_config is not None:
        assert query_text is not None, "query_text must be specified for vector search"
        # Use the new LLMClient for embeddings
        embedding_client = LLMClient.create(
            provider_type=embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
        embedded_text = np.array(embeddings[0])
        # Zero-pad to MAX_EMBEDDING_DIM (presumably the stored column width — TODO confirm).
        embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

    # Base query - use noload to prevent lazy loading which can block the event loop.
    query = select(ArchivalPassage).options(noload(ArchivalPassage.organization), noload(ArchivalPassage.passage_tags))
    if agent_id:
        query = query.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id).where(
            ArchivesAgents.agent_id == agent_id
        )
    elif archive_id:
        query = query.where(ArchivalPassage.archive_id == archive_id)
    query = query.where(ArchivalPassage.organization_id == actor.organization_id)

    # Apply filters
    if start_date:
        query = query.where(ArchivalPassage.created_at >= start_date)
    if end_date:
        query = query.where(ArchivalPassage.created_at <= end_date)

    # Handle text search or vector search
    if embedded_text is not None:
        if settings.database_engine is DatabaseChoice.POSTGRES:
            # PostgreSQL with pgvector
            query = query.order_by(ArchivalPassage.embedding.cosine_distance(embedded_text).asc())
        else:
            # SQLite with custom vector type
            from letta.orm.sqlite_functions import adapt_array

            query_embedding_binary = adapt_array(embedded_text)
            query = query.order_by(
                func.cosine_distance(ArchivalPassage.embedding, query_embedding_binary).asc(),
                ArchivalPassage.created_at.asc() if ascending else ArchivalPassage.created_at.desc(),
                ArchivalPassage.id.asc(),
            )
    elif query_text:
        query = query.where(func.lower(ArchivalPassage.text).contains(func.lower(query_text)))

    # Handle pagination (keyset on (created_at, id))
    if before:
        before_subq = select(ArchivalPassage.created_at, ArchivalPassage.id).where(ArchivalPassage.id == before).subquery()
        query = query.where(
            or_(
                ArchivalPassage.created_at < before_subq.c.created_at,
                and_(
                    ArchivalPassage.created_at == before_subq.c.created_at,
                    ArchivalPassage.id < before_subq.c.id,
                ),
            )
        )
    if after:
        after_subq = select(ArchivalPassage.created_at, ArchivalPassage.id).where(ArchivalPassage.id == after).subquery()
        query = query.where(
            or_(
                ArchivalPassage.created_at > after_subq.c.created_at,
                and_(
                    ArchivalPassage.created_at == after_subq.c.created_at,
                    ArchivalPassage.id > after_subq.c.id,
                ),
            )
        )

    # Chronological ordering whenever no similarity ordering was applied. This is keyed on
    # embedded_text (not embed_query) so the embed_query-without-config fallback is still ordered.
    if embedded_text is None:
        if ascending:
            query = query.order_by(ArchivalPassage.created_at.asc(), ArchivalPassage.id.asc())
        else:
            query = query.order_by(ArchivalPassage.created_at.desc(), ArchivalPassage.id.asc())
    return query
def calculate_base_tools(is_v2: bool) -> Set[str]:
    """Return the default base tool names.

    Takes BASE_TOOLS minus DEPRECATED_LETTA_TOOLS and adds the memory tools:
    BASE_MEMORY_TOOLS_V2 when ``is_v2`` is truthy, otherwise BASE_MEMORY_TOOLS.
    """
    memory_tools = BASE_MEMORY_TOOLS_V2 if is_v2 else BASE_MEMORY_TOOLS
    active_base = set(BASE_TOOLS) - set(DEPRECATED_LETTA_TOOLS)
    return active_base | set(memory_tools)
def calculate_multi_agent_tools() -> Set[str]:
    """Return a fresh set of multi-agent tool names.

    When ``settings.environment`` is ``"prod"``, LOCAL_ONLY_MULTI_AGENT_TOOLS are removed.
    """
    tools = set(MULTI_AGENT_TOOLS)
    if settings.environment == "prod":
        tools -= set(LOCAL_ONLY_MULTI_AGENT_TOOLS)
    return tools
@trace_method
async def validate_agent_exists_async(session, agent_id: str, actor: User) -> None:
    """
    Validate that a non-deleted agent with ``agent_id`` exists in the actor's organization.

    Issues a single ``SELECT EXISTS(...)`` built with SQLAlchemy Core, so the agent row
    itself is never loaded.

    Args:
        session: Async database session (``execute`` is awaited).
        agent_id: ID of the agent to validate.
        actor: User performing the action; only agents in ``actor.organization_id`` match.

    Raises:
        LettaAgentNotFoundError: If no agent matches the ID, the agent's ``is_deleted`` flag
            is set, or the agent belongs to a different organization.
    """
    agent_exists_query = select(
        exists().where(
            and_(
                AgentModel.id == agent_id,
                AgentModel.organization_id == actor.organization_id,
                # Column comparison that builds SQL, not a Python bool test.
                AgentModel.is_deleted == False,  # noqa: E712
            )
        )
    )
    result = await session.execute(agent_exists_query)
    if not result.scalar():
        raise LettaAgentNotFoundError(f"Agent with ID {agent_id} not found")