1028 lines
42 KiB
Python
1028 lines
42 KiB
Python
from datetime import datetime
|
|
from typing import List, Literal, Optional
|
|
|
|
import numpy as np
|
|
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
|
|
from sqlalchemy.sql.expression import exists
|
|
|
|
from letta import system
|
|
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, MAX_EMBEDDING_DIM, STRUCTURED_OUTPUT_MODELS
|
|
from letta.embeddings import embedding_model
|
|
from letta.helpers import ToolRulesSolver
|
|
from letta.helpers.datetime_helpers import get_local_time, get_local_time_fast
|
|
from letta.orm import AgentPassage, SourcePassage, SourcesAgents
|
|
from letta.orm.agent import Agent as AgentModel
|
|
from letta.orm.agents_tags import AgentsTags
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.orm.identity import Identity
|
|
from letta.orm.sqlite_functions import adapt_array
|
|
from letta.otel.tracing import trace_method
|
|
from letta.prompts import gpt_system
|
|
from letta.schemas.agent import AgentState, AgentType
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import MessageRole
|
|
from letta.schemas.letta_message_content import TextContent
|
|
from letta.schemas.memory import Memory
|
|
from letta.schemas.message import Message, MessageCreate
|
|
from letta.schemas.tool_rule import ToolRule
|
|
from letta.schemas.user import User
|
|
from letta.settings import settings
|
|
from letta.system import get_initial_boot_messages, get_login_event, package_function_response
|
|
|
|
|
|
# Static methods
|
|
@trace_method
def _process_relationship(
    session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
):
    """
    Generalized function to handle relationships like tools, sources, and blocks using item IDs.

    Args:
        session: The database session.
        agent: The AgentModel instance.
        relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
        model_class: The ORM class corresponding to the related items.
        item_ids: List of IDs to set or update. Duplicate IDs are tolerated.
        allow_partial: If True, allows missing items without raising errors.
        replace: If True, replaces the entire relationship; otherwise, extends it.

    Raises:
        NoResultFound: If `allow_partial` is False and some IDs are missing.
    """
    current_relationship = getattr(agent, relationship_name, [])
    if not item_ids:
        if replace:
            setattr(agent, relationship_name, [])
        return

    # Compare against the distinct requested IDs: the query returns each row once, so a
    # repeated ID in `item_ids` must not be mistaken for a missing row.
    requested_ids = set(item_ids)

    # Retrieve models for the provided IDs
    found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()

    # Validate all items are found if allow_partial is False
    if not allow_partial and len(found_items) != len(requested_ids):
        missing = requested_ids - {item.id for item in found_items}
        raise NoResultFound(f"Items not found in {relationship_name}: {missing}")

    if replace:
        # Replace the relationship
        setattr(agent, relationship_name, found_items)
    else:
        # Extend the relationship (only add new items)
        current_ids = {item.id for item in current_relationship}
        new_items = [item for item in found_items if item.id not in current_ids]
        current_relationship.extend(new_items)
|
|
|
|
|
|
@trace_method
async def _process_relationship_async(
    session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
):
    """
    Async variant of the relationship helper: handles relationships like tools, sources, and blocks using item IDs.

    Args:
        session: The async database session.
        agent: The AgentModel instance.
        relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
        model_class: The ORM class corresponding to the related items.
        item_ids: List of IDs to set or update. Duplicate IDs are tolerated.
        allow_partial: If True, allows missing items without raising errors.
        replace: If True, replaces the entire relationship; otherwise, extends it.

    Raises:
        NoResultFound: If `allow_partial` is False and some IDs are missing.
    """
    current_relationship = getattr(agent, relationship_name, [])
    if not item_ids:
        if replace:
            setattr(agent, relationship_name, [])
        return

    # Compare against the distinct requested IDs: the query returns each row once, so a
    # repeated ID in `item_ids` must not be mistaken for a missing row.
    requested_ids = set(item_ids)

    # Retrieve models for the provided IDs
    result = await session.execute(select(model_class).where(model_class.id.in_(item_ids)))
    found_items = result.scalars().all()

    # Validate all items are found if allow_partial is False
    if not allow_partial and len(found_items) != len(requested_ids):
        missing = requested_ids - {item.id for item in found_items}
        raise NoResultFound(f"Items not found in {relationship_name}: {missing}")

    if replace:
        # Replace the relationship
        setattr(agent, relationship_name, found_items)
    else:
        # Extend the relationship (only add new items)
        current_ids = {item.id for item in current_relationship}
        new_items = [item for item in found_items if item.id not in current_ids]
        current_relationship.extend(new_items)
|
|
|
|
|
|
def _process_tags(agent: AgentModel, tags: List[str], replace=True):
    """
    Handles tags for an agent.

    Tags are de-duplicated while keeping the caller's order, so the resulting
    `agent.tags` list is deterministic.

    Args:
        agent: The AgentModel instance.
        tags: List of tags to set or update.
        replace: If True, replaces all tags; otherwise, extends them with tags the agent does not already have.
    """
    if not tags:
        if replace:
            agent.tags = []
        return

    # Ordered de-duplication on the tag strings themselves. Unlike a set of ORM objects,
    # this does not depend on AgentsTags being hashable and keeps a stable order.
    unique_tags = list(dict.fromkeys(tags))

    if replace:
        agent.tags = [AgentsTags(agent_id=agent.id, tag=tag) for tag in unique_tags]
    else:
        existing_tags = {t.tag for t in agent.tags}
        agent.tags.extend([AgentsTags(agent_id=agent.id, tag=tag) for tag in unique_tags if tag not in existing_tags])
|
|
|
|
|
|
def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool] = None, system: Optional[str] = None):
    """Pick the system prompt for an agent.

    A caller-supplied ``system`` prompt always wins and is returned as-is. Otherwise the
    built-in prompt is chosen from ``agent_type``:

    - voice_convo_agent                 -> "voice_chat"
    - voice_sleeptime_agent             -> "voice_sleeptime"
    - memgpt_agent / memgpt_v2_agent    -> "memgpt_v2_chat", with or without sleeptime
      (the chat prompt only says the agent "may" have the sleeptime tools)
    - sleeptime_agent                   -> "sleeptime_v2" (drops references to specific
      blocks and relies on block description injections instead)

    Raises:
        ValueError: if ``system`` is None and ``agent_type`` is not one of the above.
    """
    if system is not None:
        return system

    # TODO: don't hardcode
    if agent_type == AgentType.voice_convo_agent:
        prompt_name = "voice_chat"
    elif agent_type == AgentType.voice_sleeptime_agent:
        prompt_name = "voice_sleeptime"
    elif agent_type in (AgentType.memgpt_agent, AgentType.memgpt_v2_agent):
        # enable_sleeptime does not change the prompt for either MemGPT generation
        prompt_name = "memgpt_v2_chat"
    elif agent_type == AgentType.sleeptime_agent:
        prompt_name = "sleeptime_v2"
    else:
        raise ValueError(f"Invalid agent type: {agent_type}")

    return gpt_system.get_system_text(prompt_name)
|
|
|
|
|
|
# TODO: This code is kind of wonky and deserves a rewrite
|
|
def compile_memory_metadata_block(
    memory_edit_timestamp: datetime,
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> str:
    """Render the ``<memory_metadata>`` section appended after the core memory.

    It tells the agent the current time, when its memory blocks were last edited,
    and how many recall/archival items exist outside the context window.

    Args:
        memory_edit_timestamp: When the memory blocks were last modified; converted to local time.
        previous_message_count: Number of messages stored in recall memory.
        archival_memory_size: Number of memories stored in archival memory.

    Returns:
        The metadata block as newline-joined text.
    """
    # Local-timezone rendering (mimicking get_local_time())
    last_modified = memory_edit_timestamp.astimezone()
    last_modified_text = last_modified.strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()

    metadata_lines = (
        "<memory_metadata>",
        f"- The current time is: {get_local_time_fast()}",
        f"- Memory blocks were last modified: {last_modified_text}",
        f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
        f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)",
        "</memory_metadata>",
    )
    return "\n".join(metadata_lines)
|
|
|
|
|
|
class PreserveMapping(dict):
    """A ``dict`` for ``str.format_map`` that keeps unknown placeholders intact.

    Looking up a key that is not present yields the literal ``{key}`` text, so
    undefined variables in the system prompt pass through formatting unchanged.
    """

    def __missing__(self, name):
        # Echo the placeholder back verbatim instead of raising KeyError.
        return "{" + name + "}"
|
|
|
|
|
|
def safe_format(template: str, variables: dict) -> str:
    """
    Substitute known ``{name}`` variables in ``template``, leaving everything else alone.

    Bare ``{}`` pairs and placeholders whose names are not in ``variables`` are kept
    as literal text. A bare ``{}`` would otherwise be read by ``format_map`` as a
    positional field.
    """
    # Double bare "{}" so format_map emits them back as a literal "{}".
    protected_template = template.replace("{}", "{{}}")
    lookup = PreserveMapping(variables)
    return protected_template.format_map(lookup)
|
|
|
|
|
|
def compile_system_message(
    system_prompt: str,
    in_context_memory: Memory,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
    tool_rules_solver: Optional[ToolRulesSolver] = None,
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API

    The base system message may be templated, in which case we need to render the variables.

    The following are reserved variables:
        - CORE_MEMORY: the in-context memory of the LLM

    Args:
        system_prompt: Base system prompt, possibly containing the in-context-memory placeholder.
        in_context_memory: Memory whose ``compile()`` output is injected.
        in_context_memory_last_edit: Timestamp shown in the memory metadata block.
        user_defined_variables: Not supported yet; any non-None value raises NotImplementedError.
        append_icm_if_missing: If the placeholder is absent, append it to the end of the prompt.
        template_format: Only "f-string" is implemented.
        previous_message_count: Shown in the metadata block (recall memory size).
        archival_memory_size: Shown in the metadata block (archival memory size).
        tool_rules_solver: If given, its compiled tool-rule prompt block is appended to the memory blocks.

    Raises:
        NotImplementedError: for user-defined variables or non f-string templates.
        ValueError: if the reserved keyword is user-defined, or formatting fails.
    """
    # Add tool rule constraints if available
    if tool_rules_solver is not None:
        tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
        if tool_constraint_block:  # There may not be any depending on if there are tool rules attached
            # NOTE(review): this appends to the caller's Memory object in place, so calling this
            # function repeatedly on the same Memory appends the block again each time — confirm
            # callers pass a fresh/copied Memory.
            in_context_memory.blocks.append(tool_constraint_block)

    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject
        raise NotImplementedError
    variables = {}

    # Add the protected memory variable
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")

    # TODO should this all put into the memory.__repr__ function?
    memory_metadata_string = compile_memory_metadata_block(
        memory_edit_timestamp=in_context_memory_last_edit,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
    )
    full_memory_string = in_context_memory.compile() + "\n\n" + memory_metadata_string

    # Add to the variables list to inject
    variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

    if template_format != "f-string":
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
    # Catch the special case where the system prompt is unformatted
    if append_icm_if_missing and memory_variable_string not in system_prompt:
        # In this case, append it to the end to make sure memory is still injected
        system_prompt += "\n\n" + memory_variable_string

    # render the variables using the built-in templater
    try:
        if user_defined_variables:
            formatted_prompt = safe_format(system_prompt, variables)
        else:
            formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") from e

    return formatted_prompt
|
|
|
|
|
|
def initialize_message_sequence(
    agent_state: AgentState,
    memory_edit_timestamp: Optional[datetime] = None,
    include_initial_boot_message: bool = True,
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> List[dict]:
    """Build the starting message list (as role/content dicts) for an agent.

    The list is the compiled system message, optionally followed by the initial boot
    messages (none for sleeptime agents, a GPT-3.5 variant for gpt-3.5 models), and
    finally the login event as a user message.

    Args:
        agent_state: Agent whose system prompt, memory, type and LLM config are used.
        memory_edit_timestamp: Last memory edit time; defaults to "now" in the local timezone.
        include_initial_boot_message: Whether to insert the boot messages.
        previous_message_count: Forwarded to the memory metadata block.
        archival_memory_size: Forwarded to the memory metadata block.
    """
    if memory_edit_timestamp is None:
        # compile_memory_metadata_block() calls .astimezone() on this value, so build a
        # timezone-aware datetime directly, matching the declared Optional[datetime] type.
        memory_edit_timestamp = datetime.now().astimezone()

    full_system_message = compile_system_message(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        user_defined_variables=None,
        append_icm_if_missing=True,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
    )
    first_user_message = get_login_event()  # event letting Letta know the user just logged in

    if not include_initial_boot_message or agent_state.agent_type == AgentType.sleeptime_agent:
        initial_boot_messages = []
    elif agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
        initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
    else:
        initial_boot_messages = get_initial_boot_messages("startup_with_send_message")

    return (
        [{"role": "system", "content": full_system_message}]
        + initial_boot_messages
        + [{"role": "user", "content": first_user_message}]
    )
|
|
|
|
|
|
def package_initial_message_sequence(
    agent_id: str, initial_message_sequence: List[MessageCreate], model: str, actor: User
) -> List[Message]:
    """Turn a caller-provided initial message sequence into ``Message`` objects.

    - user / system: content is wrapped by ``system.package_user_message`` /
      ``system.package_system_message`` and stored as a single TextContent.
    - assistant: becomes a ``DEFAULT_MESSAGE_TOOL`` tool call whose JSON arguments are
      ``{"message": content}``, followed by a matching tool-return message built
      with ``package_function_response(True, "None")``.
    - any other role raises ValueError.

    Every message is stamped with the actor's organization, ``agent_id`` and ``model``.
    """

    def _build(**fields) -> Message:
        # Ownership/model fields are identical for every packaged message.
        return Message(organization_id=actor.organization_id, agent_id=agent_id, model=model, **fields)

    packaged: List[Message] = []
    for msg in initial_message_sequence:
        role = msg.role

        if role == MessageRole.user or role == MessageRole.system:
            if role == MessageRole.user:
                wrapped_text = system.package_user_message(user_message=msg.content)
            else:
                wrapped_text = system.package_system_message(system_message=msg.content)
            packaged.append(_build(role=role, content=[TextContent(text=wrapped_text)], name=msg.name))
            continue

        if role != MessageRole.assistant:
            # TODO: add tool call and tool return
            raise ValueError(f"Invalid message role: {msg.role}")

        # Assistant text is represented as a send_message-style tool call plus its return.
        import json
        import uuid

        from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
        from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction

        from letta.constants import DEFAULT_MESSAGE_TOOL

        call_id = str(uuid.uuid4())
        send_call = OpenAIToolCall(
            id=call_id,
            type="function",
            function=OpenAIFunction(name=DEFAULT_MESSAGE_TOOL, arguments=json.dumps({"message": msg.content})),
        )
        packaged.append(_build(role=MessageRole.assistant, content=None, name=msg.name, tool_calls=[send_call]))

        tool_return_text = package_function_response(True, "None")
        packaged.append(
            _build(
                role=MessageRole.tool,
                content=[TextContent(text=tool_return_text)],
                name=msg.name,
                tool_call_id=call_id,
            )
        )

    return packaged
|
|
|
|
|
|
def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool:
    """Return whether ``model`` supports structured output.

    Models outside ``STRUCTURED_OUTPUT_MODELS`` may declare at most one initial tool rule.

    Raises:
        ValueError: if a non-structured model is given more than one initial tool rule.
    """
    if model in STRUCTURED_OUTPUT_MODELS:
        return True

    initial_rules = ToolRulesSolver(tool_rules=tool_rules).init_tool_rules
    if len(initial_rules) > 1:
        raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
    return False
|
|
|
|
|
|
def _cursor_filter(sort_col, id_col, ref_sort_col, ref_id, forward: bool, nulls_last: bool = False, nulls_sort_high: bool = False):
    """
    Returns a SQLAlchemy filter expression for cursor-based (keyset) pagination on ``(sort_col, id_col)``.

    If `forward` is True, returns records that compare greater than the reference ``(ref_sort_col, ref_id)``.
    If `forward` is False, returns records that compare smaller than the reference.

    When ``nulls_last`` is True, NULL values in the sort column are handled explicitly, because
    SQL comparisons against NULL never match. ``nulls_sort_high`` states where NULL falls
    relative to non-NULL values:

    - False (default, previous behavior): NULL sorts below every value. This matches
      ``ORDER BY sort_col DESC NULLS LAST``.
    - True: NULL sorts above every value. This matches ``ORDER BY sort_col ASC NULLS LAST``.

    Within the NULL group rows are ordered by ``id_col``.
    """

    def _beyond_reference():
        # Plain tuple comparison of (sort_col, id_col) against the reference row.
        if forward:
            return or_(sort_col > ref_sort_col, and_(sort_col == ref_sort_col, id_col > ref_id))
        return or_(sort_col < ref_sort_col, and_(sort_col == ref_sort_col, id_col < ref_id))

    if not nulls_last:
        # Simple case: no special NULL handling needed
        return _beyond_reference()

    # True when, relative to non-NULL rows, NULL rows lie in the direction we are moving.
    nulls_on_requested_side = forward == nulls_sort_high

    if ref_sort_col is None:
        # Reference cursor is at a NULL value: other NULLs further along by id qualify, and
        # every non-NULL row qualifies when non-NULLs lie in the requested direction.
        same_null_group = and_(sort_col.is_(None), id_col > ref_id if forward else id_col < ref_id)
        if nulls_on_requested_side:
            return same_null_group
        return or_(same_null_group, sort_col.isnot(None))

    # Reference cursor is at a non-NULL value
    if nulls_on_requested_side:
        return or_(sort_col.is_(None), _beyond_reference())
    return and_(sort_col.isnot(None), _beyond_reference())


def _agent_sort_spec(sort_by: str):
    """Return ``(sort_column, sort_nulls_last)`` for an agent-listing ``sort_by`` key."""
    if sort_by == "last_run_completion":
        return AgentModel.last_run_completion, True  # TODO: handle this as a query param eventually
    return AgentModel.created_at, False


def _apply_agent_cursors_and_order(query, sort_column, sort_nulls_last: bool, ascending: bool, after_row, before_row):
    """Apply the keyset filters for already-resolved cursor rows, then the matching ORDER BY."""
    # ORDER BY ... NULLS LAST puts NULL above every value when ascending and below every
    # value when descending; the cursor filter must use the same convention.
    nulls_sort_high = ascending

    if after_row:
        after_sort_value, after_id = after_row
        query = query.where(
            _cursor_filter(
                sort_column,
                AgentModel.id,
                after_sort_value,
                after_id,
                forward=ascending,
                nulls_last=sort_nulls_last,
                nulls_sort_high=nulls_sort_high,
            )
        )

    if before_row:
        before_sort_value, before_id = before_row
        query = query.where(
            _cursor_filter(
                sort_column,
                AgentModel.id,
                before_sort_value,
                before_id,
                forward=not ascending,
                nulls_last=sort_nulls_last,
                nulls_sort_high=nulls_sort_high,
            )
        )

    # Apply ordering; id is the tie-breaker in the same direction as the sort column.
    order_fn = asc if ascending else desc
    primary_order = nulls_last(order_fn(sort_column)) if sort_nulls_last else order_fn(sort_column)
    return query.order_by(primary_order, order_fn(AgentModel.id))


def _apply_pagination(
    query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
) -> any:
    """Apply before/after cursor pagination and ordering to an agent query (sync session).

    A cursor ID that does not exist is ignored.
    """
    sort_column, sort_nulls_last = _agent_sort_spec(sort_by)

    after_row = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after)).first() if after else None
    before_row = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before)).first() if before else None

    return _apply_agent_cursors_and_order(query, sort_column, sort_nulls_last, ascending, after_row, before_row)


async def _apply_pagination_async(
    query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
) -> any:
    """Async counterpart of `_apply_pagination` (awaits the cursor lookups on an async session)."""
    sort_column, sort_nulls_last = _agent_sort_spec(sort_by)

    after_row = None
    if after:
        after_row = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == after))).first()

    before_row = None
    if before:
        before_row = (await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == before))).first()

    return _apply_agent_cursors_and_order(query, sort_column, sort_nulls_last, ascending, after_row, before_row)
|
|
|
|
|
|
def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
|
|
"""
|
|
Apply tag-based filtering to the agent query.
|
|
|
|
This helper function creates a subquery that groups agent IDs by their tags.
|
|
If `match_all_tags` is True, it filters agents that have all of the specified tags.
|
|
Otherwise, it filters agents that have any of the tags.
|
|
|
|
Args:
|
|
query: The SQLAlchemy query object to be modified.
|
|
tags (Optional[List[str]]): A list of tags to filter agents.
|
|
match_all_tags (bool): If True, only return agents that match all provided tags.
|
|
|
|
Returns:
|
|
The modified query with tag filters applied.
|
|
"""
|
|
|
|
if tags:
|
|
if match_all_tags:
|
|
for tag in tags:
|
|
query = query.filter(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag == tag)))
|
|
else:
|
|
query = query.where(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag.in_(tags))))
|
|
return query
|
|
|
|
|
|
def _apply_identity_filters(query, identity_id: Optional[str], identifier_keys: Optional[List[str]]):
|
|
"""
|
|
Apply identity-related filters to the agent query.
|
|
|
|
This helper function joins the identities relationship and filters the agents based on
|
|
a specific identity ID and/or a list of identifier keys.
|
|
|
|
Args:
|
|
query: The SQLAlchemy query object to be modified.
|
|
identity_id (Optional[str]): The identity ID to filter by.
|
|
identifier_keys (Optional[List[str]]): A list of identifier keys to filter agents.
|
|
|
|
Returns:
|
|
The modified query with identity filters applied.
|
|
"""
|
|
# Join the identities relationship and filter by a specific identity ID.
|
|
if identity_id:
|
|
query = query.join(AgentModel.identities).where(Identity.id == identity_id)
|
|
# Join the identities relationship and filter by a set of identifier keys.
|
|
if identifier_keys:
|
|
query = query.join(AgentModel.identities).where(Identity.identifier_key.in_(identifier_keys))
|
|
return query
|
|
|
|
|
|
def _apply_filters(
|
|
query,
|
|
name: Optional[str],
|
|
query_text: Optional[str],
|
|
project_id: Optional[str],
|
|
template_id: Optional[str],
|
|
base_template_id: Optional[str],
|
|
):
|
|
"""
|
|
Apply basic filtering criteria to the agent query.
|
|
|
|
This helper function adds WHERE clauses based on provided parameters such as
|
|
exact name, partial name match (using ILIKE), project ID, template ID, and base template ID.
|
|
|
|
Args:
|
|
query: The SQLAlchemy query object to be modified.
|
|
name (Optional[str]): Exact name to filter by.
|
|
query_text (Optional[str]): Partial text to search in the agent's name (case-insensitive).
|
|
project_id (Optional[str]): Filter for agents belonging to a specific project.
|
|
template_id (Optional[str]): Filter for agents using a specific template.
|
|
base_template_id (Optional[str]): Filter for agents using a specific base template.
|
|
|
|
Returns:
|
|
The modified query with the applied filters.
|
|
"""
|
|
# Filter by exact agent name if provided.
|
|
if name:
|
|
query = query.where(AgentModel.name == name)
|
|
# Apply a case-insensitive partial match for the agent's name.
|
|
if query_text:
|
|
query = query.where(AgentModel.name.ilike(f"%{query_text}%"))
|
|
# Filter agents by project ID.
|
|
if project_id:
|
|
query = query.where(AgentModel.project_id == project_id)
|
|
# Filter agents by template ID.
|
|
if template_id:
|
|
query = query.where(AgentModel.template_id == template_id)
|
|
# Filter agents by base template ID.
|
|
if base_template_id:
|
|
query = query.where(AgentModel.base_template_id == base_template_id)
|
|
return query
|
|
|
|
|
|
def build_passage_query(
    actor: User,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> Select:
    """Helper function to build the base passage query with all filters applied.

    Source passages (restricted to sources attached to ``agent_id`` when given) and, when
    ``agent_id`` is given, the agent's own passages are merged with UNION ALL into a
    ``combined_passages`` CTE. Then the following are applied:

    - date, source, and file filters;
    - a vector-similarity ordering or a case-insensitive substring filter;
    - ``before``/``after`` keyset pagination on ``(created_at, id)``.

    Supports both before and after pagination across merged source and agent passages.
    The ``id`` tie-breaker follows the same direction as ``created_at``, so keyset cursors
    page consistently through rows with equal timestamps.

    Returns the query before any limit or count operations are applied.

    Raises:
        AssertionError: if ``embed_query`` is set without ``embedding_config`` or ``query_text``.
        ValueError: if neither source nor agent passages are selected (``agent_only`` without ``agent_id``).
    """
    embedded_text = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
        embedded_text = np.array(embedded_text)
        # Zero-pad to the fixed stored embedding width.
        embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

    # Source passages (columns aligned with the agent-passage select for UNION ALL)
    source_passages = None
    if not agent_only:  # Include source passages
        source_passages = select(
            SourcePassage.file_name,
            SourcePassage.id,
            SourcePassage.text,
            SourcePassage.embedding_config,
            SourcePassage.metadata_,
            SourcePassage.embedding,
            SourcePassage.created_at,
            SourcePassage.updated_at,
            SourcePassage.is_deleted,
            SourcePassage._created_by_id,
            SourcePassage._last_updated_by_id,
            SourcePassage.organization_id,
            SourcePassage.file_id,
            SourcePassage.source_id,
            literal(None).label("agent_id"),
        )
        if agent_id is not None:
            # Only passages belonging to sources attached to this agent
            source_passages = source_passages.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id).where(
                SourcesAgents.agent_id == agent_id
            )
        source_passages = source_passages.where(SourcePassage.organization_id == actor.organization_id)

        if source_id:
            source_passages = source_passages.where(SourcePassage.source_id == source_id)
        if file_id:
            source_passages = source_passages.where(SourcePassage.file_id == file_id)

    # Agent passages
    agent_passages = None
    if agent_id is not None:
        agent_passages = (
            select(
                literal(None).label("file_name"),
                AgentPassage.id,
                AgentPassage.text,
                AgentPassage.embedding_config,
                AgentPassage.metadata_,
                AgentPassage.embedding,
                AgentPassage.created_at,
                AgentPassage.updated_at,
                AgentPassage.is_deleted,
                AgentPassage._created_by_id,
                AgentPassage._last_updated_by_id,
                AgentPassage.organization_id,
                literal(None).label("file_id"),
                literal(None).label("source_id"),
                AgentPassage.agent_id,
            )
            .where(AgentPassage.agent_id == agent_id)
            .where(AgentPassage.organization_id == actor.organization_id)
        )

    # Combine queries
    if source_passages is not None and agent_passages is not None:
        combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
    elif agent_passages is not None:
        combined_query = agent_passages.cte("combined_passages")
    elif source_passages is not None:
        combined_query = source_passages.cte("combined_passages")
    else:
        raise ValueError("No passages found")

    # Build main query from combined CTE
    main_query = select(combined_query)

    # Apply filters
    if start_date:
        main_query = main_query.where(combined_query.c.created_at >= start_date)
    if end_date:
        main_query = main_query.where(combined_query.c.created_at <= end_date)
    if source_id:
        main_query = main_query.where(combined_query.c.source_id == source_id)
    if file_id:
        main_query = main_query.where(combined_query.c.file_id == file_id)

    # Tie-breaking order shared by the similarity and chronological orderings; id follows
    # the created_at direction so it agrees with the (created_at, id) keyset cursors below.
    created_order = combined_query.c.created_at.asc() if ascending else combined_query.c.created_at.desc()
    id_order = combined_query.c.id.asc() if ascending else combined_query.c.id.desc()

    # Vector search
    if embedded_text:
        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            distance = combined_query.c.embedding.cosine_distance(embedded_text)
        else:
            # SQLite with custom vector type
            query_embedding_binary = adapt_array(embedded_text)
            distance = func.cosine_distance(combined_query.c.embedding, query_embedding_binary)
        main_query = main_query.order_by(distance.asc(), created_order, id_order)
    elif query_text:
        main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))

    def _keyset_bound(ref_id: str, cte_name: str, newer: bool):
        """Rows strictly after (newer=True) or before (newer=False) the reference row on (created_at, id)."""
        ref = select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == ref_id).cte(cte_name)
        ref_created_at = select(ref.c.created_at).scalar_subquery()
        ref_row_id = select(ref.c.id).scalar_subquery()
        if newer:
            return or_(
                combined_query.c.created_at > ref_created_at,
                and_(combined_query.c.created_at == ref_created_at, combined_query.c.id > ref_row_id),
            )
        return or_(
            combined_query.c.created_at < ref_created_at,
            and_(combined_query.c.created_at == ref_created_at, combined_query.c.id < ref_row_id),
        )

    # Handle pagination; giving both cursors yields the window between them.
    if before:
        main_query = main_query.where(_keyset_bound(before, "before_ref", newer=False))
    if after:
        main_query = main_query.where(_keyset_bound(after, "after_ref", newer=True))

    # Add ordering if not already ordered by similarity
    if not embed_query:
        main_query = main_query.order_by(created_order, id_order)

    return main_query
|
|
|
|
|
|
def build_source_passage_query(
    actor: User,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
) -> Select:
    """Build a query for source passages with all filters applied.

    Args:
        actor: User whose ``organization_id`` scopes the query.
        agent_id: If given, keep only passages whose source is linked to this
            agent through ``SourcesAgents``.
        file_id: If given, keep only passages from this file.
        query_text: Case-insensitive substring filter on the passage text or,
            when ``embed_query`` is true, the text embedded for similarity
            ranking (no substring filter is applied in that case).
        start_date: Inclusive lower bound on ``created_at``.
        end_date: Inclusive upper bound on ``created_at``.
        before: Passage id cursor; keep only passages strictly before it in
            ``(created_at, id)`` order.
        after: Passage id cursor; keep only passages strictly after it in
            ``(created_at, id)`` order.
        source_id: If given, keep only passages from this source.
        embed_query: Order by cosine distance to the embedding of
            ``query_text`` instead of chronologically.
        ascending: Chronological direction. It is the primary sort for plain
            listings and the tie-breaker for similarity search.
        embedding_config: Embedding model config; required when
            ``embed_query`` is true.

    Returns:
        An unexecuted SQLAlchemy ``Select`` over ``SourcePassage``.

    Raises:
        AssertionError: If ``embed_query`` is true but ``embedding_config`` or
            ``query_text`` is missing.
    """
    # Chronological sort keys. The id tie-breaker runs in the same direction
    # as created_at. The result is then a strict total order that agrees with
    # the (created_at, id) comparisons the before/after cursors use below.
    # Otherwise rows sharing a created_at value would be skipped or repeated
    # across pages in descending mode.
    if ascending:
        chronological_order = (SourcePassage.created_at.asc(), SourcePassage.id.asc())
    else:
        chronological_order = (SourcePassage.created_at.desc(), SourcePassage.id.desc())

    # Embed the query text for vector search. The vector is zero-padded to
    # MAX_EMBEDDING_DIM, presumably to match the fixed width of the stored
    # embedding column (confirm against the ORM model).
    query_embedding = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        query_embedding = np.array(embedding_model(embedding_config).get_text_embedding(query_text))
        query_embedding = np.pad(query_embedding, (0, MAX_EMBEDDING_DIM - query_embedding.shape[0]), mode="constant").tolist()

    # Base query: source passages in the actor's organization.
    query = select(SourcePassage).where(SourcePassage.organization_id == actor.organization_id)

    # Restrict to passages whose source is attached to the given agent.
    if agent_id is not None:
        query = query.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
        query = query.where(SourcesAgents.agent_id == agent_id)

    if source_id:
        query = query.where(SourcePassage.source_id == source_id)
    if file_id:
        query = query.where(SourcePassage.file_id == file_id)
    if start_date:
        query = query.where(SourcePassage.created_at >= start_date)
    if end_date:
        query = query.where(SourcePassage.created_at <= end_date)

    # Keyset pagination against a reference passage. If the cursor id does not
    # exist, the reference subquery is empty and the query returns no rows.
    if before:
        before_ref = select(SourcePassage.created_at, SourcePassage.id).where(SourcePassage.id == before).subquery()
        query = query.where(
            or_(
                SourcePassage.created_at < before_ref.c.created_at,
                and_(
                    SourcePassage.created_at == before_ref.c.created_at,
                    SourcePassage.id < before_ref.c.id,
                ),
            )
        )
    if after:
        after_ref = select(SourcePassage.created_at, SourcePassage.id).where(SourcePassage.id == after).subquery()
        query = query.where(
            or_(
                SourcePassage.created_at > after_ref.c.created_at,
                and_(
                    SourcePassage.created_at == after_ref.c.created_at,
                    SourcePassage.id > after_ref.c.id,
                ),
            )
        )

    if query_embedding is not None:
        # Similarity search: nearest first. Both backends break distance ties
        # chronologically so the ordering is deterministic.
        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            distance = SourcePassage.embedding.cosine_distance(query_embedding)
        else:
            # SQLite with the custom cosine_distance function over a binary-encoded vector
            distance = func.cosine_distance(SourcePassage.embedding, adapt_array(query_embedding))
        query = query.order_by(distance.asc(), *chronological_order)
    else:
        if query_text:
            query = query.where(func.lower(SourcePassage.text).contains(func.lower(query_text)))
        query = query.order_by(*chronological_order)

    return query
|
|
|
|
|
|
def build_agent_passage_query(
    actor: User,
    agent_id: str,  # Required for agent passages
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
) -> Select:
    """Build a query for agent passages with all filters applied.

    Args:
        actor: User whose ``organization_id`` scopes the query.
        agent_id: Agent whose passages are returned (required).
        query_text: Case-insensitive substring filter on the passage text or,
            when ``embed_query`` is true, the text embedded for similarity
            ranking (no substring filter is applied in that case).
        start_date: Inclusive lower bound on ``created_at``.
        end_date: Inclusive upper bound on ``created_at``.
        before: Passage id cursor; keep only passages strictly before it in
            ``(created_at, id)`` order.
        after: Passage id cursor; keep only passages strictly after it in
            ``(created_at, id)`` order.
        embed_query: Order by cosine distance to the embedding of
            ``query_text`` instead of chronologically.
        ascending: Chronological direction. It is the primary sort for plain
            listings and the tie-breaker for similarity search.
        embedding_config: Embedding model config; required when
            ``embed_query`` is true.

    Returns:
        An unexecuted SQLAlchemy ``Select`` over ``AgentPassage``.

    Raises:
        AssertionError: If ``embed_query`` is true but ``embedding_config`` or
            ``query_text`` is missing.
    """
    # Chronological sort keys. The id tie-breaker runs in the same direction
    # as created_at. The result is then a strict total order that agrees with
    # the (created_at, id) comparisons the before/after cursors use below.
    # Otherwise rows sharing a created_at value would be skipped or repeated
    # across pages in descending mode.
    if ascending:
        chronological_order = (AgentPassage.created_at.asc(), AgentPassage.id.asc())
    else:
        chronological_order = (AgentPassage.created_at.desc(), AgentPassage.id.desc())

    # Embed the query text for vector search. The vector is zero-padded to
    # MAX_EMBEDDING_DIM, presumably to match the fixed width of the stored
    # embedding column (confirm against the ORM model).
    query_embedding = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        query_embedding = np.array(embedding_model(embedding_config).get_text_embedding(query_text))
        query_embedding = np.pad(query_embedding, (0, MAX_EMBEDDING_DIM - query_embedding.shape[0]), mode="constant").tolist()

    # Base query: this agent's passages within the actor's organization.
    query = select(AgentPassage).where(AgentPassage.agent_id == agent_id, AgentPassage.organization_id == actor.organization_id)

    if start_date:
        query = query.where(AgentPassage.created_at >= start_date)
    if end_date:
        query = query.where(AgentPassage.created_at <= end_date)

    # Keyset pagination against a reference passage. If the cursor id does not
    # exist, the reference subquery is empty and the query returns no rows.
    if before:
        before_ref = select(AgentPassage.created_at, AgentPassage.id).where(AgentPassage.id == before).subquery()
        query = query.where(
            or_(
                AgentPassage.created_at < before_ref.c.created_at,
                and_(
                    AgentPassage.created_at == before_ref.c.created_at,
                    AgentPassage.id < before_ref.c.id,
                ),
            )
        )
    if after:
        after_ref = select(AgentPassage.created_at, AgentPassage.id).where(AgentPassage.id == after).subquery()
        query = query.where(
            or_(
                AgentPassage.created_at > after_ref.c.created_at,
                and_(
                    AgentPassage.created_at == after_ref.c.created_at,
                    AgentPassage.id > after_ref.c.id,
                ),
            )
        )

    if query_embedding is not None:
        # Similarity search: nearest first. Both backends break distance ties
        # chronologically so the ordering is deterministic.
        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            distance = AgentPassage.embedding.cosine_distance(query_embedding)
        else:
            # SQLite with the custom cosine_distance function over a binary-encoded vector
            distance = func.cosine_distance(AgentPassage.embedding, adapt_array(query_embedding))
        query = query.order_by(distance.asc(), *chronological_order)
    else:
        if query_text:
            query = query.where(func.lower(AgentPassage.text).contains(func.lower(query_text)))
        query = query.order_by(*chronological_order)

    return query
|