3375 lines
152 KiB
Python
3375 lines
152 KiB
Python
import asyncio
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple
|
||
from zoneinfo import ZoneInfo
|
||
|
||
import sqlalchemy as sa
|
||
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
|
||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||
|
||
from letta.constants import (
|
||
BASE_MEMORY_TOOLS,
|
||
BASE_MEMORY_TOOLS_V2,
|
||
BASE_MEMORY_TOOLS_V3,
|
||
BASE_SLEEPTIME_CHAT_TOOLS,
|
||
BASE_SLEEPTIME_TOOLS,
|
||
BASE_TOOLS,
|
||
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
|
||
BASE_VOICE_SLEEPTIME_TOOLS,
|
||
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||
DEFAULT_MAX_FILES_OPEN,
|
||
DEFAULT_TIMEZONE,
|
||
DEPRECATED_LETTA_TOOLS,
|
||
EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES,
|
||
FILES_TOOLS,
|
||
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES,
|
||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||
)
|
||
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
|
||
from letta.helpers import ToolRulesSolver
|
||
from letta.helpers.datetime_helpers import get_utc_time
|
||
from letta.llm_api.llm_client import LLMClient
|
||
from letta.log import get_logger
|
||
from letta.orm import (
|
||
Agent as AgentModel,
|
||
AgentsTags,
|
||
ArchivalPassage,
|
||
Block as BlockModel,
|
||
BlocksAgents,
|
||
Group as GroupModel,
|
||
GroupsAgents,
|
||
IdentitiesAgents,
|
||
Source as SourceModel,
|
||
SourcePassage,
|
||
SourcesAgents,
|
||
Tool as ToolModel,
|
||
ToolsAgents,
|
||
)
|
||
from letta.orm.errors import NoResultFound
|
||
from letta.orm.sandbox_config import AgentEnvironmentVariable, AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||
from letta.orm.sqlalchemy_base import AccessType
|
||
from letta.otel.tracing import trace_method
|
||
from letta.prompts.prompt_generator import PromptGenerator
|
||
from letta.schemas.agent import (
|
||
AgentRelationships,
|
||
AgentState as PydanticAgentState,
|
||
CreateAgent,
|
||
InternalTemplateAgentCreate,
|
||
UpdateAgent,
|
||
)
|
||
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
|
||
from letta.schemas.embedding_config import EmbeddingConfig
|
||
from letta.schemas.enums import AgentType, PrimitiveType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||
from letta.schemas.group import Group as PydanticGroup, ManagerType
|
||
from letta.schemas.letta_stop_reason import StopReasonType
|
||
from letta.schemas.llm_config import LLMConfig
|
||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||
from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate
|
||
from letta.schemas.passage import Passage as PydanticPassage
|
||
from letta.schemas.secret import Secret
|
||
from letta.schemas.source import Source as PydanticSource
|
||
from letta.schemas.tool import Tool as PydanticTool
|
||
from letta.schemas.tool_rule import ContinueToolRule, RequiresApprovalToolRule, TerminalToolRule
|
||
from letta.schemas.user import User as PydanticUser
|
||
from letta.serialize_schemas import MarshmallowAgentSchema
|
||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||
from letta.server.db import db_registry
|
||
from letta.services.archive_manager import ArchiveManager
|
||
from letta.services.block_manager import BlockManager, validate_block_limit_constraint
|
||
from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
|
||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||
from letta.services.files_agents_manager import FileAgentManager
|
||
from letta.services.helpers.agent_manager_helper import (
|
||
_apply_filters,
|
||
_apply_identity_filters,
|
||
_apply_pagination,
|
||
_apply_pagination_async,
|
||
_apply_relationship_filters,
|
||
_apply_tag_filter,
|
||
_process_relationship,
|
||
_process_relationship_async,
|
||
build_agent_passage_query,
|
||
build_passage_query,
|
||
build_source_passage_query,
|
||
calculate_base_tools,
|
||
calculate_multi_agent_tools,
|
||
check_supports_structured_output,
|
||
compile_system_message,
|
||
derive_system_message,
|
||
initialize_message_sequence,
|
||
initialize_message_sequence_async,
|
||
package_initial_message_sequence,
|
||
validate_agent_exists_async,
|
||
)
|
||
from letta.services.identity_manager import IdentityManager
|
||
from letta.services.message_manager import MessageManager
|
||
from letta.services.passage_manager import PassageManager
|
||
from letta.services.source_manager import SourceManager
|
||
from letta.services.tool_manager import ToolManager
|
||
from letta.settings import DatabaseChoice, model_settings, settings
|
||
from letta.utils import (
|
||
bounded_gather,
|
||
calculate_file_defaults_based_on_context_window,
|
||
decrypt_agent_secrets,
|
||
enforce_types,
|
||
united_diff,
|
||
)
|
||
from letta.validators import raise_on_invalid_id
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class AgentManager:
|
||
"""Manager class to handle business logic related to Agents."""
|
||
|
||
def __init__(self):
|
||
self.block_manager = BlockManager()
|
||
self.tool_manager = ToolManager()
|
||
self.source_manager = SourceManager()
|
||
self.message_manager = MessageManager()
|
||
self.passage_manager = PassageManager()
|
||
self.identity_manager = IdentityManager()
|
||
self.file_agent_manager = FileAgentManager()
|
||
self.archive_manager = ArchiveManager()
|
||
|
||
@staticmethod
|
||
def _should_exclude_model_from_base_tool_rules(model: str) -> bool:
|
||
"""Check if a model should be excluded from base tool rules based on model keywords."""
|
||
# First check if model contains any include keywords (overrides exclusion)
|
||
for include_keyword in INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES:
|
||
if include_keyword in model:
|
||
return False
|
||
|
||
# Then check if model contains any exclude keywords
|
||
for exclude_keyword in EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES:
|
||
if exclude_keyword in model:
|
||
return True
|
||
|
||
return False
|
||
|
||
@staticmethod
|
||
def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
|
||
"""
|
||
Bulk‑fetch all ToolModel rows matching either name ∈ names or id ∈ ids
|
||
(and scoped to this organization), and return two maps:
|
||
name_to_id, id_to_name.
|
||
Raises if any requested name or id was not found.
|
||
"""
|
||
stmt = select(ToolModel.id, ToolModel.name).where(
|
||
ToolModel.organization_id == org_id,
|
||
or_(
|
||
ToolModel.name.in_(names),
|
||
ToolModel.id.in_(ids),
|
||
),
|
||
)
|
||
rows = session.execute(stmt).all()
|
||
name_to_id = {name: tid for tid, name in rows}
|
||
id_to_name = {tid: name for tid, name in rows}
|
||
|
||
missing_names = names - set(name_to_id.keys())
|
||
missing_ids = ids - set(id_to_name.keys())
|
||
if missing_names:
|
||
raise ValueError(f"Tools not found by name: {missing_names}")
|
||
if missing_ids:
|
||
raise ValueError(f"Tools not found by id: {missing_ids}")
|
||
|
||
return name_to_id, id_to_name
|
||
|
||
@staticmethod
|
||
async def _resolve_tools_async(
|
||
session, names: Set[str], ids: Set[str], org_id: str, ignore_invalid_tools: bool = False
|
||
) -> Tuple[Dict[str, str], Dict[str, str], List[str]]:
|
||
"""
|
||
Bulk‑fetch all ToolModel rows matching either name ∈ names or id ∈ ids
|
||
(and scoped to this organization), and return two maps:
|
||
name_to_id, id_to_name.
|
||
Raises if any requested name or id was not found (unless ignore_invalid_tools is True).
|
||
|
||
Args:
|
||
session: Database session
|
||
names: Set of tool names to resolve
|
||
ids: Set of tool IDs to resolve
|
||
org_id: Organization ID for scoping
|
||
ignore_invalid_tools: If True, silently filters out missing tools instead of raising an error
|
||
"""
|
||
stmt = select(ToolModel.id, ToolModel.name, ToolModel.default_requires_approval).where(
|
||
ToolModel.organization_id == org_id,
|
||
or_(
|
||
ToolModel.name.in_(names),
|
||
ToolModel.id.in_(ids),
|
||
),
|
||
)
|
||
result = await session.execute(stmt)
|
||
rows = result.fetchall() # Use fetchall()
|
||
name_to_id = {row[1]: row[0] for row in rows} # row[1] is name, row[0] is id
|
||
id_to_name = {row[0]: row[1] for row in rows} # row[0] is id, row[1] is name
|
||
requires_approval = [row[1] for row in rows if row[2]] # row[1] is name, row[2] is default_requires_approval
|
||
|
||
missing_names = names - set(name_to_id.keys())
|
||
missing_ids = ids - set(id_to_name.keys())
|
||
|
||
if not ignore_invalid_tools:
|
||
# Original behavior: raise errors for missing tools
|
||
if missing_names:
|
||
raise ValueError(f"Tools not found by name: {missing_names}")
|
||
if missing_ids:
|
||
raise ValueError(f"Tools not found by id: {missing_ids}")
|
||
else:
|
||
# New behavior: log missing tools but don't raise errors
|
||
if missing_names or missing_ids:
|
||
logger = get_logger(__name__)
|
||
if missing_names:
|
||
logger.warning(f"Ignoring tools not found by name: {missing_names}")
|
||
if missing_ids:
|
||
logger.warning(f"Ignoring tools not found by id: {missing_ids}")
|
||
|
||
return name_to_id, id_to_name, requires_approval
|
||
|
||
@staticmethod
|
||
def _bulk_insert_pivot(session, table, rows: list[dict]):
|
||
if not rows:
|
||
return
|
||
|
||
dialect = session.bind.dialect.name
|
||
if dialect == "postgresql":
|
||
stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
|
||
elif dialect == "sqlite":
|
||
stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
|
||
else:
|
||
# fallback: filter out exact-duplicate dicts in Python
|
||
seen = set()
|
||
filtered = []
|
||
for row in rows:
|
||
key = tuple(sorted(row.items()))
|
||
if key not in seen:
|
||
seen.add(key)
|
||
filtered.append(row)
|
||
stmt = sa.insert(table).values(filtered)
|
||
|
||
session.execute(stmt)
|
||
|
||
@staticmethod
|
||
async def _bulk_insert_pivot_async(session, table, rows: list[dict]):
|
||
if not rows:
|
||
return
|
||
|
||
dialect = session.bind.dialect.name
|
||
if dialect == "postgresql":
|
||
stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
|
||
elif dialect == "sqlite":
|
||
stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
|
||
else:
|
||
# fallback: filter out exact-duplicate dicts in Python
|
||
seen = set()
|
||
filtered = []
|
||
for row in rows:
|
||
key = tuple(sorted(row.items()))
|
||
if key not in seen:
|
||
seen.add(key)
|
||
filtered.append(row)
|
||
stmt = sa.insert(table).values(filtered)
|
||
|
||
await session.execute(stmt)
|
||
|
||
@staticmethod
|
||
def _replace_pivot_rows(session, table, agent_id: str, rows: list[dict]):
|
||
"""
|
||
Replace all pivot rows for an agent with *exactly* the provided list.
|
||
Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
|
||
"""
|
||
# delete all existing rows for this agent
|
||
session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||
if rows:
|
||
AgentManager._bulk_insert_pivot(session, table, rows)
|
||
|
||
@staticmethod
|
||
async def _replace_pivot_rows_async(session, table, agent_id: str, rows: list[dict]):
|
||
"""
|
||
Replace all pivot rows for an agent atomically using MERGE pattern.
|
||
"""
|
||
dialect = session.bind.dialect.name
|
||
|
||
if dialect == "postgresql":
|
||
if rows:
|
||
# separate upsert and delete operations
|
||
stmt = pg_insert(table).values(rows)
|
||
stmt = stmt.on_conflict_do_nothing()
|
||
await session.execute(stmt)
|
||
|
||
# delete rows not in new set
|
||
pk_names = [c.name for c in table.primary_key.columns]
|
||
new_keys = [tuple(r[c] for c in pk_names) for r in rows]
|
||
await session.execute(
|
||
delete(table).where(table.c.agent_id == agent_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys))
|
||
)
|
||
else:
|
||
# if no rows to insert, just delete all
|
||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||
|
||
elif dialect == "sqlite":
|
||
if rows:
|
||
stmt = sa.insert(table).values(rows).prefix_with("OR REPLACE")
|
||
await session.execute(stmt)
|
||
|
||
if rows:
|
||
primary_key_cols = [table.c[c.name] for c in table.primary_key.columns]
|
||
new_keys = [tuple(r[c.name] for c in table.primary_key.columns) for r in rows]
|
||
await session.execute(delete(table).where(table.c.agent_id == agent_id, ~tuple_(*primary_key_cols).in_(new_keys)))
|
||
else:
|
||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||
|
||
else:
|
||
# fallback: use original DELETE + INSERT pattern
|
||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||
if rows:
|
||
await AgentManager._bulk_insert_pivot_async(session, table, rows)
|
||
|
||
# ======================================================================================================================
|
||
# Basic CRUD operations
|
||
# ======================================================================================================================
|
||
|
||
@trace_method
|
||
async def create_agent_async(
|
||
self,
|
||
agent_create: CreateAgent,
|
||
actor: PydanticUser,
|
||
_test_only_force_id: Optional[str] = None,
|
||
_init_with_no_messages: bool = False,
|
||
ignore_invalid_tools: bool = False,
|
||
) -> PydanticAgentState:
|
||
# validate required configs
|
||
if not agent_create.llm_config:
|
||
raise ValueError("llm_config is required")
|
||
|
||
# For v1 agents, enforce sane defaults even when reasoning is omitted
|
||
if agent_create.agent_type == AgentType.letta_v1_agent:
|
||
# Claude 3.7/4 or OpenAI o1/o3/o4/gpt-5
|
||
default_reasoning = LLMConfig.is_anthropic_reasoning_model(agent_create.llm_config) or LLMConfig.is_openai_reasoning_model(
|
||
agent_create.llm_config
|
||
)
|
||
agent_create.llm_config = LLMConfig.apply_reasoning_setting_to_config(
|
||
agent_create.llm_config,
|
||
agent_create.reasoning if agent_create.reasoning is not None else default_reasoning,
|
||
agent_create.agent_type,
|
||
)
|
||
else:
|
||
if agent_create.reasoning is not None:
|
||
agent_create.llm_config = LLMConfig.apply_reasoning_setting_to_config(
|
||
agent_create.llm_config,
|
||
agent_create.reasoning,
|
||
agent_create.agent_type,
|
||
)
|
||
|
||
# blocks
|
||
block_ids = list(agent_create.block_ids or [])
|
||
if agent_create.memory_blocks:
|
||
pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]
|
||
|
||
# Inject a description for the default blocks if the user didn't specify them
|
||
# Used for `persona`, `human`, etc
|
||
default_blocks = {block.label: block for block in DEFAULT_BLOCKS}
|
||
for block in pydantic_blocks:
|
||
if block.label in default_blocks:
|
||
if block.description is None:
|
||
block.description = default_blocks[block.label].description
|
||
|
||
# Actually create the blocks
|
||
created_blocks = await self.block_manager.batch_create_blocks_async(
|
||
pydantic_blocks,
|
||
actor=actor,
|
||
)
|
||
block_ids.extend([blk.id for blk in created_blocks])
|
||
|
||
# tools
|
||
tool_names = set(agent_create.tools or [])
|
||
if agent_create.include_base_tools:
|
||
if agent_create.agent_type == AgentType.voice_sleeptime_agent:
|
||
tool_names |= set(BASE_VOICE_SLEEPTIME_TOOLS)
|
||
# NOTE: also overwrite initial message sequence to empty by default
|
||
if agent_create.initial_message_sequence is None:
|
||
agent_create.initial_message_sequence = []
|
||
elif agent_create.agent_type == AgentType.voice_convo_agent:
|
||
tool_names |= set(BASE_VOICE_SLEEPTIME_CHAT_TOOLS)
|
||
elif agent_create.agent_type == AgentType.sleeptime_agent:
|
||
tool_names |= set(BASE_SLEEPTIME_TOOLS)
|
||
# NOTE: also overwrite initial message sequence to empty by default
|
||
if agent_create.initial_message_sequence is None:
|
||
agent_create.initial_message_sequence = []
|
||
elif agent_create.enable_sleeptime:
|
||
tool_names |= set(BASE_SLEEPTIME_CHAT_TOOLS)
|
||
elif agent_create.agent_type == AgentType.memgpt_v2_agent:
|
||
tool_names |= calculate_base_tools(is_v2=True)
|
||
elif agent_create.agent_type == AgentType.react_agent:
|
||
pass # no default tools
|
||
elif agent_create.agent_type == AgentType.letta_v1_agent:
|
||
tool_names |= calculate_base_tools(is_v2=True)
|
||
# Remove `send_message` if it exists
|
||
tool_names.discard("send_message")
|
||
# NOTE: also overwriting inner_thoughts_in_kwargs to force False
|
||
agent_create.llm_config.put_inner_thoughts_in_kwargs = False
|
||
# NOTE: also overwrite initial message sequence to empty by default
|
||
if agent_create.initial_message_sequence is None:
|
||
agent_create.initial_message_sequence = []
|
||
# NOTE: default to no base tool rules unless explicitly provided
|
||
if not agent_create.tool_rules and agent_create.include_base_tool_rules is None:
|
||
agent_create.include_base_tool_rules = False
|
||
elif agent_create.agent_type == AgentType.workflow_agent:
|
||
pass # no default tools
|
||
else:
|
||
tool_names |= calculate_base_tools(is_v2=False)
|
||
if agent_create.include_multi_agent_tools:
|
||
tool_names |= calculate_multi_agent_tools()
|
||
|
||
supplied_ids = set(agent_create.tool_ids or [])
|
||
|
||
# Use folder_ids if provided, otherwise fall back to deprecated source_ids for backwards compatibility
|
||
source_ids = agent_create.folder_ids if agent_create.folder_ids else (agent_create.source_ids or [])
|
||
|
||
# Create default source if requested
|
||
if agent_create.include_default_source:
|
||
default_source = PydanticSource(
|
||
name=f"{agent_create.name} External Data Source",
|
||
embedding_config=agent_create.embedding_config,
|
||
)
|
||
created_source = await self.source_manager.create_source(default_source, actor)
|
||
source_ids.append(created_source.id)
|
||
|
||
identity_ids = agent_create.identity_ids or []
|
||
tag_values = agent_create.tags or []
|
||
|
||
# if the agent type is workflow, we set the autoclear to forced true
|
||
if agent_create.agent_type == AgentType.workflow_agent:
|
||
agent_create.message_buffer_autoclear = True
|
||
|
||
async with db_registry.async_session() as session:
|
||
async with session.begin():
|
||
# Note: This will need to be modified if _resolve_tools needs an async version
|
||
name_to_id, id_to_name, requires_approval = await self._resolve_tools_async(
|
||
session,
|
||
tool_names,
|
||
supplied_ids,
|
||
actor.organization_id,
|
||
ignore_invalid_tools=ignore_invalid_tools,
|
||
)
|
||
|
||
tool_ids = set(name_to_id.values()) | set(id_to_name.keys())
|
||
tool_names = set(name_to_id.keys()) # now canonical
|
||
tool_rules = list(agent_create.tool_rules or [])
|
||
|
||
# Override include_base_tool_rules to False if model matches exclusion keywords and include_base_tool_rules is not explicitly set to True
|
||
if (
|
||
(
|
||
self._should_exclude_model_from_base_tool_rules(agent_create.llm_config.model)
|
||
and agent_create.include_base_tool_rules is None
|
||
)
|
||
and agent_create.agent_type != AgentType.sleeptime_agent
|
||
) or agent_create.include_base_tool_rules is False:
|
||
agent_create.include_base_tool_rules = False
|
||
logger.info(f"Overriding include_base_tool_rules to False for model: {agent_create.llm_config.model}")
|
||
else:
|
||
agent_create.include_base_tool_rules = True
|
||
|
||
should_add_base_tool_rules = agent_create.include_base_tool_rules
|
||
if should_add_base_tool_rules:
|
||
for tn in tool_names:
|
||
if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}:
|
||
tool_rules.append(TerminalToolRule(tool_name=tn))
|
||
elif tn in (BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_MEMORY_TOOLS_V3 + BASE_SLEEPTIME_TOOLS):
|
||
tool_rules.append(ContinueToolRule(tool_name=tn))
|
||
|
||
for tool_with_requires_approval in requires_approval:
|
||
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_with_requires_approval))
|
||
|
||
if tool_rules:
|
||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules)
|
||
|
||
new_agent = AgentModel(
|
||
name=agent_create.name,
|
||
system=derive_system_message(
|
||
agent_type=agent_create.agent_type,
|
||
enable_sleeptime=agent_create.enable_sleeptime,
|
||
system=agent_create.system,
|
||
),
|
||
agent_type=agent_create.agent_type,
|
||
llm_config=agent_create.llm_config,
|
||
embedding_config=agent_create.embedding_config,
|
||
compaction_settings=agent_create.compaction_settings,
|
||
organization_id=actor.organization_id,
|
||
description=agent_create.description,
|
||
metadata_=agent_create.metadata,
|
||
tool_rules=tool_rules,
|
||
hidden=agent_create.hidden,
|
||
project_id=agent_create.project_id,
|
||
template_id=agent_create.template_id,
|
||
base_template_id=agent_create.base_template_id,
|
||
message_buffer_autoclear=agent_create.message_buffer_autoclear,
|
||
enable_sleeptime=agent_create.enable_sleeptime,
|
||
response_format=agent_create.response_format,
|
||
created_by_id=actor.id,
|
||
last_updated_by_id=actor.id,
|
||
timezone=agent_create.timezone if agent_create.timezone else DEFAULT_TIMEZONE,
|
||
max_files_open=agent_create.max_files_open,
|
||
per_file_view_window_char_limit=agent_create.per_file_view_window_char_limit,
|
||
)
|
||
|
||
# Set template fields for InternalTemplateAgentCreate (similar to group creation)
|
||
if isinstance(agent_create, InternalTemplateAgentCreate):
|
||
new_agent.base_template_id = agent_create.base_template_id
|
||
new_agent.template_id = agent_create.template_id
|
||
new_agent.deployment_id = agent_create.deployment_id
|
||
new_agent.entity_id = agent_create.entity_id
|
||
|
||
if _test_only_force_id:
|
||
new_agent.id = _test_only_force_id
|
||
|
||
session.add(new_agent)
|
||
await session.flush()
|
||
aid = new_agent.id
|
||
|
||
# Note: These methods may need async versions if they perform database operations
|
||
await self._bulk_insert_pivot_async(
|
||
session,
|
||
ToolsAgents.__table__,
|
||
[{"agent_id": aid, "tool_id": tid} for tid in tool_ids],
|
||
)
|
||
|
||
if block_ids:
|
||
result = await session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(block_ids)))
|
||
rows = [{"agent_id": aid, "block_id": bid, "block_label": lbl} for bid, lbl in result.all()]
|
||
await self._bulk_insert_pivot_async(session, BlocksAgents.__table__, rows)
|
||
|
||
await self._bulk_insert_pivot_async(
|
||
session,
|
||
SourcesAgents.__table__,
|
||
[{"agent_id": aid, "source_id": sid} for sid in source_ids],
|
||
)
|
||
await self._bulk_insert_pivot_async(
|
||
session,
|
||
AgentsTags.__table__,
|
||
[{"agent_id": aid, "tag": tag} for tag in tag_values],
|
||
)
|
||
await self._bulk_insert_pivot_async(
|
||
session,
|
||
IdentitiesAgents.__table__,
|
||
[{"agent_id": aid, "identity_id": iid} for iid in identity_ids],
|
||
)
|
||
|
||
env_rows = []
|
||
agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables
|
||
|
||
if agent_secrets:
|
||
# Encrypt environment variable values concurrently (async to avoid blocking event loop)
|
||
secrets_dict = await Secret.from_plaintexts_async(agent_secrets)
|
||
env_rows = [
|
||
{
|
||
"agent_id": aid,
|
||
"key": key,
|
||
"value": "", # Empty string for NOT NULL constraint (deprecated, use value_enc)
|
||
"value_enc": secret.get_encrypted(),
|
||
"organization_id": actor.organization_id,
|
||
}
|
||
for key, secret in secrets_dict.items()
|
||
]
|
||
|
||
result = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id))
|
||
env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, result.scalars().all())]
|
||
|
||
include_relationships = []
|
||
if tool_ids:
|
||
include_relationships.append("tools")
|
||
if source_ids:
|
||
include_relationships.append("sources")
|
||
if block_ids:
|
||
include_relationships.append("memory")
|
||
if identity_ids:
|
||
include_relationships.append("identity_ids")
|
||
if tag_values:
|
||
include_relationships.append("tags")
|
||
|
||
result = await new_agent.to_pydantic_async(include_relationships=include_relationships)
|
||
|
||
if agent_secrets and env_rows:
|
||
result.tool_exec_environment_variables = [AgentEnvironmentVariable(**row) for row in env_rows]
|
||
result.secrets = [AgentEnvironmentVariable(**row) for row in env_rows]
|
||
|
||
# initial message sequence (skip if _init_with_no_messages is True)
|
||
if not _init_with_no_messages:
|
||
init_messages = await self._generate_initial_message_sequence_async(
|
||
actor,
|
||
agent_state=result,
|
||
supplied_initial_message_sequence=agent_create.initial_message_sequence,
|
||
)
|
||
result.message_ids = [msg.id for msg in init_messages]
|
||
new_agent.message_ids = [msg.id for msg in init_messages]
|
||
await new_agent.update_async(session, no_refresh=True)
|
||
else:
|
||
init_messages = []
|
||
|
||
# Only create messages if we initialized with messages
|
||
if not _init_with_no_messages:
|
||
await self.message_manager.create_many_messages_async(
|
||
pydantic_msgs=init_messages, actor=actor, project_id=result.project_id, template_id=result.template_id
|
||
)
|
||
|
||
# Attach files from sources if this is a template-based creation
|
||
# Use the new agent's sources (already copied from template via source_ids)
|
||
if isinstance(agent_create, InternalTemplateAgentCreate) and source_ids:
|
||
try:
|
||
from letta.services.file_manager import FileManager
|
||
|
||
file_manager = FileManager()
|
||
|
||
# Get all files from the new agent's sources
|
||
all_files_metadata = []
|
||
for source_id in source_ids:
|
||
try:
|
||
files_in_source = await file_manager.list_files(
|
||
source_id=source_id,
|
||
actor=actor,
|
||
limit=1000,
|
||
)
|
||
all_files_metadata.extend(files_in_source)
|
||
except Exception as e:
|
||
logger.warning(f"Failed to get files from source {source_id}: {e}")
|
||
|
||
if all_files_metadata:
|
||
try:
|
||
await self.file_agent_manager.attach_files_bulk(
|
||
agent_id=result.id,
|
||
files_metadata=all_files_metadata,
|
||
visible_content_map={}, # Empty map - content generated on-demand
|
||
actor=actor,
|
||
max_files_open=result.max_files_open or DEFAULT_MAX_FILES_OPEN,
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Failed to attach files: {e}")
|
||
except Exception as e:
|
||
logger.error(f"Failed to attach files from sources: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
|
||
return result
|
||
|
||
@enforce_types
|
||
def _generate_initial_message_sequence(
|
||
self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
|
||
) -> List[Message]:
|
||
init_messages = initialize_message_sequence(
|
||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||
)
|
||
if supplied_initial_message_sequence is not None:
|
||
# We always need the system prompt up front
|
||
system_message_obj = PydanticMessage.dict_to_message(
|
||
agent_id=agent_state.id,
|
||
model=agent_state.llm_config.model,
|
||
openai_message_dict=init_messages[0],
|
||
)
|
||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||
init_messages = [system_message_obj]
|
||
init_messages.extend(
|
||
package_initial_message_sequence(
|
||
agent_state.id, supplied_initial_message_sequence, agent_state.llm_config.model, agent_state.timezone, actor
|
||
)
|
||
)
|
||
else:
|
||
init_messages = [
|
||
PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg)
|
||
for msg in init_messages
|
||
]
|
||
|
||
return init_messages
|
||
|
||
@enforce_types
|
||
async def _generate_initial_message_sequence_async(
|
||
self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
|
||
) -> List[Message]:
|
||
init_messages = await initialize_message_sequence_async(
|
||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||
)
|
||
if supplied_initial_message_sequence is not None:
|
||
# We always need the system prompt up front
|
||
system_message_obj = PydanticMessage.dict_to_message(
|
||
agent_id=agent_state.id,
|
||
model=agent_state.llm_config.model,
|
||
openai_message_dict=init_messages[0],
|
||
)
|
||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||
init_messages = [system_message_obj]
|
||
init_messages.extend(
|
||
package_initial_message_sequence(
|
||
agent_state.id, supplied_initial_message_sequence, agent_state.llm_config.model, agent_state.timezone, actor
|
||
)
|
||
)
|
||
else:
|
||
init_messages = [
|
||
PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg)
|
||
for msg in init_messages
|
||
]
|
||
|
||
return init_messages
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def append_initial_message_sequence_to_in_context_messages_async(
|
||
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
|
||
) -> PydanticAgentState:
|
||
init_messages = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence)
|
||
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def update_agent_async(
|
||
self,
|
||
agent_id: str,
|
||
agent_update: UpdateAgent,
|
||
actor: PydanticUser,
|
||
) -> PydanticAgentState:
|
||
new_tools = set(agent_update.tool_ids or [])
|
||
# Use folder_ids if provided, otherwise fall back to deprecated source_ids for backwards compatibility
|
||
folder_ids_to_update = agent_update.folder_ids if agent_update.folder_ids is not None else agent_update.source_ids
|
||
new_sources = set(folder_ids_to_update or [])
|
||
new_blocks = set(agent_update.block_ids or [])
|
||
new_idents = set(agent_update.identity_ids or [])
|
||
new_tags = set(agent_update.tags or [])
|
||
|
||
async with db_registry.async_session() as session, session.begin():
|
||
agent: AgentModel = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
agent.updated_at = datetime.now(timezone.utc)
|
||
agent.last_updated_by_id = actor.id
|
||
|
||
if agent_update.reasoning is not None:
|
||
llm_config = agent_update.llm_config or agent.llm_config
|
||
agent_update.llm_config = LLMConfig.apply_reasoning_setting_to_config(
|
||
llm_config,
|
||
agent_update.reasoning,
|
||
agent.agent_type,
|
||
)
|
||
|
||
scalar_updates = {
|
||
"name": agent_update.name,
|
||
"system": agent_update.system,
|
||
"llm_config": agent_update.llm_config,
|
||
"embedding_config": agent_update.embedding_config,
|
||
"compaction_settings": agent_update.compaction_settings,
|
||
"message_ids": agent_update.message_ids,
|
||
"tool_rules": agent_update.tool_rules,
|
||
"description": agent_update.description,
|
||
"project_id": agent_update.project_id,
|
||
"template_id": agent_update.template_id,
|
||
"base_template_id": agent_update.base_template_id,
|
||
"message_buffer_autoclear": agent_update.message_buffer_autoclear,
|
||
"enable_sleeptime": agent_update.enable_sleeptime,
|
||
"response_format": agent_update.response_format,
|
||
"last_run_completion": agent_update.last_run_completion,
|
||
"last_run_duration_ms": agent_update.last_run_duration_ms,
|
||
"last_stop_reason": agent_update.last_stop_reason,
|
||
"timezone": agent_update.timezone,
|
||
"max_files_open": agent_update.max_files_open,
|
||
"per_file_view_window_char_limit": agent_update.per_file_view_window_char_limit,
|
||
}
|
||
for col, val in scalar_updates.items():
|
||
if val is not None:
|
||
setattr(agent, col, val)
|
||
|
||
if agent_update.metadata is not None:
|
||
agent.metadata_ = agent_update.metadata
|
||
|
||
aid = agent.id
|
||
|
||
if agent_update.tool_ids is not None:
|
||
await self._replace_pivot_rows_async(
|
||
session,
|
||
ToolsAgents.__table__,
|
||
aid,
|
||
[{"agent_id": aid, "tool_id": tid} for tid in new_tools],
|
||
)
|
||
session.expire(agent, ["tools"])
|
||
|
||
# Update sources if either folder_ids or source_ids (deprecated) is provided
|
||
if agent_update.folder_ids is not None or agent_update.source_ids is not None:
|
||
await self._replace_pivot_rows_async(
|
||
session,
|
||
SourcesAgents.__table__,
|
||
aid,
|
||
[{"agent_id": aid, "source_id": sid} for sid in new_sources],
|
||
)
|
||
session.expire(agent, ["sources"])
|
||
|
||
if agent_update.block_ids is not None:
|
||
rows = []
|
||
if new_blocks:
|
||
result = await session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(new_blocks)))
|
||
label_map = {bid: lbl for bid, lbl in result.all()}
|
||
rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks]
|
||
|
||
await self._replace_pivot_rows_async(session, BlocksAgents.__table__, aid, rows)
|
||
session.expire(agent, ["core_memory"])
|
||
|
||
if agent_update.identity_ids is not None:
|
||
await self._replace_pivot_rows_async(
|
||
session,
|
||
IdentitiesAgents.__table__,
|
||
aid,
|
||
[{"agent_id": aid, "identity_id": iid} for iid in new_idents],
|
||
)
|
||
session.expire(agent, ["identities"])
|
||
|
||
if agent_update.tags is not None:
|
||
await self._replace_pivot_rows_async(
|
||
session,
|
||
AgentsTags.__table__,
|
||
aid,
|
||
[{"agent_id": aid, "tag": tag} for tag in new_tags],
|
||
)
|
||
session.expire(agent, ["tags"])
|
||
|
||
agent_secrets = agent_update.secrets if agent_update.secrets is not None else agent_update.tool_exec_environment_variables
|
||
if agent_secrets is not None:
|
||
# Fetch existing environment variables to check if values changed
|
||
result = await session.execute(select(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
|
||
existing_env_vars = {env.key: env for env in result.scalars().all()}
|
||
|
||
# TODO: do we need to delete each time or can we just upsert?
|
||
await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
|
||
|
||
# Decrypt existing values to check for changes (async to avoid blocking)
|
||
existing_values: dict[str, str | None] = {}
|
||
for k, existing_env in existing_env_vars.items():
|
||
if existing_env.value_enc:
|
||
existing_secret = Secret.from_encrypted(existing_env.value_enc)
|
||
existing_values[k] = await existing_secret.get_plaintext_async()
|
||
else:
|
||
existing_values[k] = None
|
||
|
||
# Identify values that need encryption (new or changed)
|
||
to_encrypt = {
|
||
k: v
|
||
for k, v in agent_secrets.items()
|
||
if k not in existing_env_vars or existing_values.get(k) != v or not existing_env_vars[k].value_enc
|
||
}
|
||
|
||
# Batch encrypt new/changed values concurrently (async to avoid blocking event loop)
|
||
new_secrets = await Secret.from_plaintexts_async(to_encrypt) if to_encrypt else {}
|
||
|
||
# Build rows, reusing existing encrypted values where unchanged
|
||
env_rows = []
|
||
for k, v in agent_secrets.items():
|
||
if k in new_secrets:
|
||
# New or changed value - use newly encrypted value
|
||
value_enc = new_secrets[k].get_encrypted()
|
||
else:
|
||
# Value unchanged - reuse existing encrypted value
|
||
value_enc = existing_env_vars[k].value_enc
|
||
|
||
row = {
|
||
"agent_id": aid,
|
||
"key": k,
|
||
"value": "", # Empty string for NOT NULL constraint (deprecated, use value_enc)
|
||
"value_enc": value_enc,
|
||
"organization_id": agent.organization_id,
|
||
}
|
||
env_rows.append(row)
|
||
|
||
if env_rows:
|
||
await self._bulk_insert_pivot_async(session, AgentEnvironmentVariable.__table__, env_rows)
|
||
session.expire(agent, ["tool_exec_environment_variables"])
|
||
|
||
if agent_update.enable_sleeptime and agent_update.system is None:
|
||
agent.system = derive_system_message(
|
||
agent_type=agent.agent_type,
|
||
enable_sleeptime=agent_update.enable_sleeptime,
|
||
system=agent.system,
|
||
)
|
||
|
||
await session.flush()
|
||
await session.refresh(agent)
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||
|
||
# Decrypt secrets outside session
|
||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def update_message_ids_async(
|
||
self,
|
||
agent_id: str,
|
||
message_ids: List[str],
|
||
actor: PydanticUser,
|
||
) -> None:
|
||
async with db_registry.async_session() as session:
|
||
query = select(AgentModel)
|
||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||
query = query.where(AgentModel.id == agent_id)
|
||
query = _apply_relationship_filters(query, include_relationships=[])
|
||
|
||
result = await session.execute(query)
|
||
agent = result.scalar_one_or_none()
|
||
|
||
agent.updated_at = datetime.now(timezone.utc)
|
||
agent.last_updated_by_id = actor.id
|
||
agent.message_ids = message_ids
|
||
|
||
await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
|
||
@trace_method
|
||
async def list_agents_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
name: Optional[str] = None,
|
||
tags: Optional[List[str]] = None,
|
||
match_all_tags: bool = False,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
limit: Optional[int] = 50,
|
||
query_text: Optional[str] = None,
|
||
project_id: Optional[str] = None,
|
||
template_id: Optional[str] = None,
|
||
base_template_id: Optional[str] = None,
|
||
identity_id: Optional[str] = None,
|
||
identifier_keys: Optional[List[str]] = None,
|
||
include_relationships: Optional[List[str]] = None,
|
||
include: List[str] = [],
|
||
ascending: bool = True,
|
||
sort_by: Optional[str] = "created_at",
|
||
show_hidden_agents: Optional[bool] = None,
|
||
last_stop_reason: Optional[StopReasonType] = None,
|
||
) -> List[PydanticAgentState]:
|
||
"""
|
||
Retrieves agents with optimized filtering and optional field selection.
|
||
|
||
Args:
|
||
actor: The User requesting the list
|
||
name (Optional[str]): Filter by agent name.
|
||
tags (Optional[List[str]]): Filter agents by tags.
|
||
match_all_tags (bool): If True, only return agents that match ALL given tags.
|
||
before (Optional[str]): Cursor for pagination.
|
||
after (Optional[str]): Cursor for pagination.
|
||
limit (Optional[int]): Maximum number of agents to return.
|
||
query_text (Optional[str]): Search agents by name.
|
||
project_id (Optional[str]): Filter by project ID.
|
||
template_id (Optional[str]): Filter by template ID.
|
||
base_template_id (Optional[str]): Filter by base template ID.
|
||
identity_id (Optional[str]): Filter by identifier ID.
|
||
identifier_keys (Optional[List[str]]): Search agents by identifier keys.
|
||
include_relationships (Optional[List[str]]): List of fields to load for performance optimization.
|
||
ascending (bool): Sort agents in ascending order.
|
||
sort_by (Optional[str]): Sort agents by this field.
|
||
show_hidden_agents (bool): If True, include agents marked as hidden in the results.
|
||
last_stop_reason (Optional[str]): Filter by the agent's last stop reason (e.g., 'requires_approval', 'error').
|
||
|
||
Returns:
|
||
List[PydanticAgentState]: The filtered list of matching agents.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
query = select(AgentModel)
|
||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||
|
||
# Apply filters
|
||
query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id, last_stop_reason)
|
||
query = _apply_identity_filters(query, identity_id, identifier_keys)
|
||
query = _apply_tag_filter(query, tags, match_all_tags)
|
||
query = _apply_relationship_filters(query, include_relationships, include)
|
||
|
||
# Apply hidden filter
|
||
if not show_hidden_agents:
|
||
query = query.where((AgentModel.hidden.is_(None)) | (AgentModel.hidden == False))
|
||
query = await _apply_pagination_async(query, before, after, session, ascending=ascending, sort_by=sort_by)
|
||
|
||
if limit:
|
||
query = query.limit(limit)
|
||
result = await session.execute(query)
|
||
agents = result.scalars().all()
|
||
|
||
# Convert to pydantic without decrypting (keeps encrypted values)
|
||
# This allows us to release the DB connection before expensive PBKDF2 operations
|
||
agents_encrypted = await bounded_gather(
|
||
[agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False) for agent in agents]
|
||
)
|
||
|
||
# DB session released - now decrypt secrets outside session to prevent connection holding
|
||
return await decrypt_agent_secrets(agents_encrypted)
|
||
|
||
@trace_method
|
||
async def count_agents_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
name: Optional[str] = None,
|
||
tags: Optional[List[str]] = None,
|
||
match_all_tags: bool = False,
|
||
query_text: Optional[str] = None,
|
||
project_id: Optional[str] = None,
|
||
template_id: Optional[str] = None,
|
||
base_template_id: Optional[str] = None,
|
||
identity_id: Optional[str] = None,
|
||
identifier_keys: Optional[List[str]] = None,
|
||
show_hidden_agents: Optional[bool] = None,
|
||
last_stop_reason: Optional[StopReasonType] = None,
|
||
) -> int:
|
||
"""
|
||
Count agents matching the specified filters using an efficient database-level COUNT query.
|
||
|
||
Args:
|
||
actor: The User requesting the count
|
||
name (Optional[str]): Filter by agent name.
|
||
tags (Optional[List[str]]): Filter agents by tags.
|
||
match_all_tags (bool): If True, only count agents that match ALL given tags.
|
||
query_text (Optional[str]): Search agents by name.
|
||
project_id (Optional[str]): Filter by project ID.
|
||
template_id (Optional[str]): Filter by template ID.
|
||
base_template_id (Optional[str]): Filter by base template ID.
|
||
identity_id (Optional[str]): Filter by identifier ID.
|
||
identifier_keys (Optional[List[str]]): Search agents by identifier keys.
|
||
show_hidden_agents (bool): If True, include agents marked as hidden in the results.
|
||
last_stop_reason (Optional[str]): Filter by the agent's last stop reason (e.g., 'requires_approval', 'error').
|
||
|
||
Returns:
|
||
int: The count of agents matching the filters.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
query = select(func.count()).select_from(AgentModel)
|
||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||
|
||
# Apply filters
|
||
query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id, last_stop_reason)
|
||
query = _apply_identity_filters(query, identity_id, identifier_keys)
|
||
query = _apply_tag_filter(query, tags, match_all_tags)
|
||
|
||
# Apply hidden filter
|
||
if not show_hidden_agents:
|
||
query = query.where((AgentModel.hidden.is_(None)) | (AgentModel.hidden == False))
|
||
|
||
result = await session.execute(query)
|
||
return result.scalar_one()
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_agents_matching_tags_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
match_all: List[str],
|
||
match_some: List[str],
|
||
limit: Optional[int] = 50,
|
||
) -> List[PydanticAgentState]:
|
||
"""
|
||
Retrieves agents in the same organization that match all specified `match_all` tags
|
||
and at least one tag from `match_some`. The query is optimized for efficiency by
|
||
leveraging indexed filtering and aggregation.
|
||
|
||
Args:
|
||
actor (PydanticUser): The user requesting the agent list.
|
||
match_all (List[str]): Agents must have all these tags.
|
||
match_some (List[str]): Agents must have at least one of these tags.
|
||
limit (Optional[int]): Maximum number of agents to return.
|
||
|
||
Returns:
|
||
List[PydanticAgentState: The filtered list of matching agents.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)
|
||
|
||
if match_all:
|
||
# Subquery to find agent IDs that contain all match_all tags
|
||
subquery = (
|
||
select(AgentsTags.agent_id)
|
||
.where(AgentsTags.tag.in_(match_all))
|
||
.group_by(AgentsTags.agent_id)
|
||
.having(func.count(AgentsTags.tag) == literal(len(match_all)))
|
||
)
|
||
query = query.where(AgentModel.id.in_(subquery))
|
||
|
||
if match_some:
|
||
# Ensures agents match at least one tag in match_some
|
||
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
|
||
|
||
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
|
||
result = await session.execute(query)
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agents_encrypted = await bounded_gather([agent.to_pydantic_async(decrypt=False) for agent in result.scalars()])
|
||
|
||
# Decrypt secrets outside session
|
||
return await decrypt_agent_secrets(agents_encrypted)
|
||
|
||
@trace_method
|
||
async def size_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
) -> int:
|
||
"""
|
||
Get the total count of agents for the given user.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
return await AgentModel.size_async(db_session=session, actor=actor)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_agent_by_id_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
include_relationships: Optional[List[str]] = None,
|
||
include: List[str] = [],
|
||
) -> PydanticAgentState:
|
||
"""Fetch an agent by its ID."""
|
||
|
||
try:
|
||
async with db_registry.async_session() as session:
|
||
query = select(AgentModel)
|
||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||
query = query.where(AgentModel.id == agent_id)
|
||
query = _apply_relationship_filters(query, include_relationships, include)
|
||
|
||
result = await session.execute(query)
|
||
agent = result.scalar_one_or_none()
|
||
|
||
if agent is None:
|
||
raise NoResultFound(f"Agent with ID {agent_id} not found")
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agent_encrypted = await agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False)
|
||
|
||
# Decrypt secrets outside session
|
||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||
except NoResultFound:
|
||
# Re-raise NoResultFound without logging to preserve 404 handling
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
|
||
raise
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def get_agents_by_ids_async(
|
||
self,
|
||
agent_ids: list[str],
|
||
actor: PydanticUser,
|
||
include_relationships: Optional[List[str]] = None,
|
||
) -> list[PydanticAgentState]:
|
||
"""Fetch a list of agents by their IDs."""
|
||
try:
|
||
async with db_registry.async_session() as session:
|
||
query = select(AgentModel)
|
||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||
query = query.where(AgentModel.id.in_(agent_ids))
|
||
query = _apply_relationship_filters(query, include_relationships)
|
||
|
||
result = await session.execute(query)
|
||
agents = result.scalars().all()
|
||
|
||
if not agents:
|
||
logger.warning(f"No agents found with IDs: {agent_ids}")
|
||
return []
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agents_encrypted = await bounded_gather(
|
||
[agent.to_pydantic_async(include_relationships=include_relationships, decrypt=False) for agent in agents]
|
||
)
|
||
|
||
# Decrypt secrets outside session
|
||
return await decrypt_agent_secrets(agents_encrypted)
|
||
except Exception as e:
|
||
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
|
||
raise
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]:
|
||
"""Get all archive IDs associated with an agent."""
|
||
from letta.orm import ArchivesAgents
|
||
|
||
async with db_registry.async_session() as session:
|
||
# Direct query to archives_agents table for performance
|
||
query = select(ArchivesAgents.archive_id).where(ArchivesAgents.agent_id == agent_id)
|
||
result = await session.execute(query)
|
||
archive_ids = [row[0] for row in result.fetchall()]
|
||
return archive_ids
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def validate_agent_exists_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||
"""
|
||
Validate that an agent exists and user has access to it.
|
||
Lightweight method that doesn't load the full agent object.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to validate
|
||
actor: User performing the action
|
||
|
||
Raises:
|
||
LettaAgentNotFoundError: If agent doesn't exist or user doesn't have access
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||
"""
|
||
Deletes an agent and its associated relationships.
|
||
Ensures proper permission checks and cascades where applicable.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to be deleted.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If agent doesn't exist
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# Retrieve the agent
|
||
logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}")
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
agents_to_delete = [agent]
|
||
sleeptime_group_to_delete = None
|
||
|
||
# Delete sleeptime agent and group (TODO this is flimsy pls fix)
|
||
if agent.multi_agent_group:
|
||
participant_agent_ids = agent.multi_agent_group.agent_ids
|
||
if agent.multi_agent_group.manager_type in {ManagerType.sleeptime, ManagerType.voice_sleeptime} and participant_agent_ids:
|
||
for participant_agent_id in participant_agent_ids:
|
||
try:
|
||
sleeptime_agent = await AgentModel.read_async(db_session=session, identifier=participant_agent_id, actor=actor)
|
||
agents_to_delete.append(sleeptime_agent)
|
||
except NoResultFound:
|
||
pass # agent already deleted
|
||
sleeptime_agent_group = await GroupModel.read_async(
|
||
db_session=session, identifier=agent.multi_agent_group.id, actor=actor
|
||
)
|
||
sleeptime_group_to_delete = sleeptime_agent_group
|
||
|
||
try:
|
||
if sleeptime_group_to_delete is not None:
|
||
await session.delete(sleeptime_group_to_delete)
|
||
await session.commit()
|
||
for agent in agents_to_delete:
|
||
await session.delete(agent)
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
except Exception as e:
|
||
await session.rollback()
|
||
logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
|
||
raise ValueError(f"Failed to hard delete Agent with ID {agent_id}: {e}")
|
||
else:
|
||
logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
|
||
|
||
# ======================================================================================================================
|
||
# Per Agent Environment Variable Management
|
||
# ======================================================================================================================
|
||
|
||
# ======================================================================================================================
|
||
# In Context Messages Management
|
||
# ======================================================================================================================
|
||
# TODO: There are several assumptions here that are not explicitly checked
|
||
# TODO: 1) These message ids are valid
|
||
# TODO: 2) These messages are ordered from oldest to newest
|
||
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
|
||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||
return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
||
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
|
||
|
||
# TODO: This is duplicated below
|
||
# TODO: This is legacy code and should be cleaned up
|
||
# TODO: A lot of the memory "compilation" should be offset to a separate class
|
||
@enforce_types
|
||
@trace_method
|
||
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
|
||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||
|
||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||
|
||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||
"""
|
||
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||
|
||
curr_system_message = self.get_system_message(
|
||
agent_id=agent_id, actor=actor
|
||
) # this is the system + memory bank, not just the system prompt
|
||
|
||
if curr_system_message is None:
|
||
logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
|
||
return agent_state
|
||
|
||
curr_system_message_openai = curr_system_message.to_openai_dict()
|
||
|
||
# note: we only update the system prompt if the core memory is changed
|
||
# this means that the archival/recall memory statistics may be someout out of date
|
||
curr_memory_str = agent_state.memory.compile(sources=agent_state.sources, llm_config=agent_state.llm_config)
|
||
if curr_memory_str in curr_system_message_openai["content"] and not force:
|
||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||
logger.debug(
|
||
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
||
)
|
||
return agent_state
|
||
|
||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||
# For example, if we're doing a system prompt swap, this should probably be False
|
||
if update_timestamp:
|
||
memory_edit_timestamp = get_utc_time()
|
||
else:
|
||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||
memory_edit_timestamp = curr_system_message.created_at
|
||
|
||
num_messages = self.message_manager.size(actor=actor, agent_id=agent_id)
|
||
num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id)
|
||
|
||
# update memory (TODO: potentially update recall/archival stats separately)
|
||
new_system_message_str = compile_system_message(
|
||
system_prompt=agent_state.system,
|
||
in_context_memory=agent_state.memory,
|
||
in_context_memory_last_edit=memory_edit_timestamp,
|
||
timezone=agent_state.timezone,
|
||
previous_message_count=num_messages - len(agent_state.message_ids),
|
||
archival_memory_size=num_archival_memories,
|
||
sources=agent_state.sources,
|
||
max_files_open=agent_state.max_files_open,
|
||
llm_config=agent_state.llm_config,
|
||
)
|
||
|
||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||
if len(diff) > 0: # there was a diff
|
||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||
|
||
# Swap the system message out (only if there is a diff)
|
||
message = PydanticMessage.dict_to_message(
|
||
agent_id=agent_id,
|
||
model=agent_state.llm_config.model,
|
||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||
)
|
||
message = self.message_manager.update_message_by_id(
|
||
message_id=curr_system_message.id,
|
||
message_update=MessageUpdate(**message.model_dump()),
|
||
actor=actor,
|
||
)
|
||
return self.set_in_context_messages(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
|
||
else:
|
||
return agent_state
|
||
|
||
# Do not remove comment. (cliandy)
|
||
# TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish
|
||
@enforce_types
|
||
@trace_method
|
||
async def rebuild_system_prompt_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
force=False,
|
||
update_timestamp=True,
|
||
dry_run: bool = False,
|
||
) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
|
||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||
|
||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||
|
||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||
"""
|
||
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
|
||
|
||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||
|
||
if agent_state.message_ids == []:
|
||
curr_system_message = None
|
||
else:
|
||
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||
|
||
if curr_system_message is None:
|
||
logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
|
||
return agent_state, curr_system_message, num_messages, num_archival_memories
|
||
|
||
curr_system_message_openai = curr_system_message.to_openai_dict()
|
||
|
||
# note: we only update the system prompt if the core memory is changed
|
||
# this means that the archival/recall memory statistics may be someout out of date
|
||
curr_memory_str = agent_state.memory.compile(
|
||
sources=agent_state.sources,
|
||
tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
|
||
max_files_open=agent_state.max_files_open,
|
||
llm_config=agent_state.llm_config,
|
||
)
|
||
if curr_memory_str in curr_system_message_openai["content"] and not force:
|
||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||
logger.debug(
|
||
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
||
)
|
||
return agent_state, curr_system_message, num_messages, num_archival_memories
|
||
|
||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||
# For example, if we're doing a system prompt swap, this should probably be False
|
||
if update_timestamp:
|
||
memory_edit_timestamp = get_utc_time()
|
||
else:
|
||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||
memory_edit_timestamp = curr_system_message.created_at
|
||
|
||
# update memory (TODO: potentially update recall/archival stats separately)
|
||
|
||
new_system_message_str = PromptGenerator.get_system_message_from_compiled_memory(
|
||
system_prompt=agent_state.system,
|
||
memory_with_sources=curr_memory_str,
|
||
in_context_memory_last_edit=memory_edit_timestamp,
|
||
timezone=agent_state.timezone,
|
||
previous_message_count=num_messages - len(agent_state.message_ids),
|
||
archival_memory_size=num_archival_memories,
|
||
)
|
||
|
||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||
if len(diff) > 0: # there was a diff
|
||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||
|
||
# Swap the system message out (only if there is a diff)
|
||
temp_message = PydanticMessage.dict_to_message(
|
||
agent_id=agent_id,
|
||
model=agent_state.llm_config.model,
|
||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||
)
|
||
temp_message.id = curr_system_message.id
|
||
|
||
if not dry_run:
|
||
await self.message_manager.update_message_by_id_async(
|
||
message_id=curr_system_message.id,
|
||
message_update=MessageUpdate(**temp_message.model_dump()),
|
||
actor=actor,
|
||
project_id=agent_state.project_id,
|
||
)
|
||
else:
|
||
curr_system_message = temp_message
|
||
|
||
return agent_state, curr_system_message, num_messages, num_archival_memories
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
|
||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||
# TODO: How do we know this?
|
||
new_messages = [message_ids[0]] # 0 is system message
|
||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
|
||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
|
||
message_ids += [m.id for m in messages]
|
||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def append_to_in_context_messages_async(
|
||
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
|
||
) -> PydanticAgentState:
|
||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||
messages = await self.message_manager.create_many_messages_async(
|
||
messages, actor=actor, project_id=agent.project_id, template_id=agent.template_id
|
||
)
|
||
message_ids = agent.message_ids or []
|
||
message_ids += [m.id for m in messages]
|
||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def reset_messages_async(
|
||
self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False, needs_agent_state: bool = True
|
||
) -> Optional[PydanticAgentState]:
|
||
"""
|
||
Clears all in-context messages for the specified agent except the original system message by:
|
||
1) Preserving the first message ID (original system message).
|
||
2) Updating the agent's message_ids to only contain the system message.
|
||
3) Optionally adding default initial messages after the system message.
|
||
|
||
Note: This only clears messages from the agent's context, it does not delete them from the database.
|
||
|
||
Args:
|
||
add_default_initial_messages: If true, adds the default initial messages after resetting.
|
||
agent_id (str): The ID of the agent whose messages will be reset.
|
||
actor (PydanticUser): The user performing this action.
|
||
needs_agent_state: If True, returns the updated agent state. If False, returns None (for performance optimization)
|
||
|
||
Returns:
|
||
Optional[PydanticAgentState]: The updated agent state with only the original system message preserved, or None if needs_agent_state=False.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
|
||
if not agent.message_ids or len(agent.message_ids) == 0:
|
||
logger.error(
|
||
f"Agent {agent_id} has no message_ids. Agent details: "
|
||
f"name={agent.name}, created_at={agent.created_at}, "
|
||
f"message_ids={agent.message_ids}, organization_id={actor.organization_id}"
|
||
)
|
||
raise ValueError(f"Agent {agent_id} has no message_ids - cannot preserve system message")
|
||
|
||
system_message_id = agent.message_ids[0]
|
||
agent.message_ids = [system_message_id]
|
||
await agent.update_async(db_session=session, actor=actor)
|
||
|
||
# Only convert to pydantic if we need to return it or add initial messages
|
||
if add_default_initial_messages or needs_agent_state:
|
||
agent_state = await agent.to_pydantic_async(include_relationships=["sources"] if add_default_initial_messages else None)
|
||
else:
|
||
agent_state = None
|
||
|
||
# Optionally add default initial messages after the system message
|
||
if add_default_initial_messages:
|
||
init_messages = await initialize_message_sequence_async(
|
||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||
)
|
||
# Skip index 0 (system message) since we preserved the original
|
||
non_system_messages = [
|
||
PydanticMessage.dict_to_message(
|
||
agent_id=agent_state.id,
|
||
model=agent_state.llm_config.model,
|
||
openai_message_dict=msg,
|
||
)
|
||
for msg in init_messages[1:]
|
||
]
|
||
return await self.append_to_in_context_messages_async(non_system_messages, agent_id=agent_state.id, actor=actor)
|
||
else:
|
||
return agent_state
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||
"""
|
||
Update internal memory object and system prompt if there have been modifications.
|
||
|
||
Args:
|
||
actor:
|
||
agent_id:
|
||
new_memory (Memory): the new memory object to compare to the current memory object
|
||
|
||
Returns:
|
||
modified (bool): whether the memory was updated
|
||
"""
|
||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
|
||
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||
new_memory_str = new_memory.compile(
|
||
sources=agent_state.sources,
|
||
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
|
||
max_files_open=agent_state.max_files_open,
|
||
llm_config=agent_state.llm_config,
|
||
)
|
||
if new_memory_str not in system_message.content[0].text:
|
||
# update the blocks (LRW) in the DB
|
||
for label in new_memory.list_block_labels():
|
||
if label in agent_state.memory.list_block_labels():
|
||
# Block exists in both old and new memory - check if value changed
|
||
updated_value = new_memory.get_block(label).value
|
||
if updated_value != agent_state.memory.get_block(label).value:
|
||
# update the block if it's changed
|
||
block_id = agent_state.memory.get_block(label).id
|
||
await self.block_manager.update_block_async(
|
||
block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor
|
||
)
|
||
|
||
# Note: New blocks are already persisted in the creation methods,
|
||
# so we don't need to handle them here
|
||
|
||
# refresh memory from DB (using block ids from the new memory)
|
||
blocks = await self.block_manager.get_all_blocks_by_ids_async(block_ids=[b.id for b in new_memory.get_blocks()], actor=actor)
|
||
|
||
agent_state.memory = Memory(
|
||
blocks=blocks,
|
||
file_blocks=agent_state.memory.file_blocks,
|
||
agent_type=agent_state.agent_type,
|
||
)
|
||
|
||
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
|
||
# rebuild memory - this records the last edited timestamp of the memory
|
||
# TODO: pass in update timestamp from block edit time
|
||
await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)
|
||
|
||
return agent_state
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||
# TODO: This will NOT work for new blocks/file blocks added intra-step
|
||
block_ids = [b.id for b in agent_state.memory.blocks]
|
||
file_block_names = [b.label for b in agent_state.memory.file_blocks]
|
||
|
||
if block_ids:
|
||
blocks = await self.block_manager.get_all_blocks_by_ids_async(block_ids=[b.id for b in agent_state.memory.blocks], actor=actor)
|
||
agent_state.memory.blocks = [b for b in blocks if b is not None]
|
||
|
||
if file_block_names:
|
||
file_blocks = await self.file_agent_manager.get_all_file_blocks_by_name(
|
||
file_names=file_block_names,
|
||
agent_id=agent_state.id,
|
||
actor=actor,
|
||
per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit,
|
||
)
|
||
agent_state.memory.file_blocks = [b for b in file_blocks if b is not None]
|
||
|
||
return agent_state
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def refresh_file_blocks(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||
"""
|
||
Refresh the file blocks in an agent's memory with current file content.
|
||
|
||
This method synchronizes the agent's in-memory file blocks with the actual
|
||
file content from attached sources. It respects the per-file view window
|
||
limit to prevent excessive memory usage.
|
||
|
||
Args:
|
||
agent_state: The current agent state containing memory configuration
|
||
actor: The user performing this action (for permission checking)
|
||
|
||
Returns:
|
||
Updated agent state with refreshed file blocks
|
||
|
||
Important:
|
||
- File blocks are truncated based on per_file_view_window_char_limit
|
||
- None values are filtered out (files that couldn't be loaded)
|
||
- This does NOT persist changes to the database, only updates the state object
|
||
- Call this before agent interactions if files may have changed externally
|
||
"""
|
||
file_blocks = await self.file_agent_manager.list_files_for_agent(
|
||
agent_id=agent_state.id,
|
||
per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit,
|
||
actor=actor,
|
||
return_as_blocks=True,
|
||
)
|
||
agent_state.memory.file_blocks = [b for b in file_blocks if b is not None]
|
||
return agent_state
|
||
|
||
# ======================================================================================================================
|
||
# Source Management
|
||
# ======================================================================================================================
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||
@trace_method
|
||
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
"""
|
||
Attaches a source to an agent.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to attach the source to
|
||
source_id: ID of the source to attach
|
||
actor: User performing the action
|
||
|
||
Raises:
|
||
NoResultFound: If either agent or source doesn't exist or actor lacks permission to access them
|
||
IntegrityError: If the source is already attached to the agent
|
||
"""
|
||
|
||
async with db_registry.async_session() as session:
|
||
# Verify both agent and source exist and user has permission to access them
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
|
||
# Verify the actor has permission to access the source
|
||
await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
|
||
|
||
# The _process_relationship helper already handles duplicate checking via unique constraint
|
||
await _process_relationship_async(
|
||
session=session,
|
||
agent=agent,
|
||
relationship_name="sources",
|
||
model_class=SourceModel,
|
||
item_ids=[source_id],
|
||
replace=False,
|
||
)
|
||
|
||
# Commit the changes
|
||
agent = await agent.update_async(session, actor=actor)
|
||
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
|
||
# or even better, not refresh at all.
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||
|
||
# Decrypt secrets outside session
|
||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
def append_system_message(self, agent_id: str, content: str, actor: PydanticUser):
|
||
"""
|
||
Append a system message to an agent's in-context message history.
|
||
|
||
This method is typically used during agent initialization to add system prompts,
|
||
instructions, or context that should be treated as system-level guidance.
|
||
Unlike user messages, system messages directly influence the agent's behavior
|
||
and understanding of its role.
|
||
|
||
Args:
|
||
agent_id: The ID of the agent to append the message to
|
||
content: The system message content (e.g., instructions, context, role definition)
|
||
actor: The user performing this action (for permission checking)
|
||
|
||
Side Effects:
|
||
- Creates a new Message object in the database
|
||
- Updates the agent's in_context_message_ids list
|
||
- The message becomes part of the agent's permanent context window
|
||
|
||
Note:
|
||
System messages consume tokens in the context window and cannot be
|
||
removed without rebuilding the agent's message history.
|
||
"""
|
||
|
||
# get the agent
|
||
agent = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||
message = PydanticMessage.dict_to_message(
|
||
agent_id=agent.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content}
|
||
)
|
||
|
||
# update agent in-context message IDs
|
||
self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
|
||
"""
|
||
Async version of append_system_message.
|
||
|
||
Append a system message to an agent's in-context message history.
|
||
See append_system_message for detailed documentation.
|
||
|
||
This async version is preferred for high-throughput scenarios or when
|
||
called within other async operations to avoid blocking the event loop.
|
||
"""
|
||
|
||
# get the agent
|
||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||
message = PydanticMessage.dict_to_message(
|
||
agent_id=agent.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content}
|
||
)
|
||
|
||
# update agent in-context message IDs
|
||
await self.append_to_in_context_messages_async(messages=[message], agent_id=agent_id, actor=actor)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_attached_sources_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
limit: Optional[int] = None,
|
||
ascending: bool = False,
|
||
) -> List[PydanticSource]:
|
||
"""
|
||
Lists all sources attached to an agent with pagination.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to list sources for
|
||
actor: User performing the action
|
||
before: Source ID cursor for pagination. Returns sources that come before this source ID.
|
||
after: Source ID cursor for pagination. Returns sources that come after this source ID.
|
||
limit: Maximum number of sources to return.
|
||
ascending: Sort order by creation time.
|
||
|
||
Returns:
|
||
List[PydanticSource]: List of sources attached to the agent
|
||
|
||
Raises:
|
||
NoResultFound: If agent doesn't exist or user doesn't have access
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# Validate agent exists and user has access
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# Use raw SQL to efficiently fetch sources - much faster than lazy loading
|
||
# Fast query without relationship loading
|
||
query = (
|
||
select(SourceModel)
|
||
.join(SourcesAgents, SourceModel.id == SourcesAgents.source_id)
|
||
.where(
|
||
SourcesAgents.agent_id == agent_id,
|
||
SourceModel.organization_id == actor.organization_id,
|
||
SourceModel.is_deleted == False,
|
||
)
|
||
)
|
||
|
||
# Apply cursor-based pagination
|
||
if before:
|
||
query = query.where(SourceModel.id < before)
|
||
if after:
|
||
query = query.where(SourceModel.id > after)
|
||
|
||
# Apply sorting
|
||
if ascending:
|
||
query = query.order_by(SourceModel.created_at.asc(), SourceModel.id.asc())
|
||
else:
|
||
query = query.order_by(SourceModel.created_at.desc(), SourceModel.id.desc())
|
||
|
||
# Apply limit
|
||
if limit:
|
||
query = query.limit(limit)
|
||
|
||
result = await session.execute(query)
|
||
sources = result.scalars().all()
|
||
|
||
return [source.to_pydantic() for source in sources]
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||
@trace_method
|
||
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
"""
|
||
Detaches a source from an agent.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to detach the source from
|
||
source_id: ID of the source to detach
|
||
actor: User performing the action
|
||
|
||
Raises:
|
||
NoResultFound: If agent doesn't exist or user doesn't have access
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# Validate agent exists and user has access
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# Check if the source is actually attached to this agent using junction table
|
||
attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
|
||
attachment_result = await session.execute(attachment_check_query)
|
||
attachment = attachment_result.scalar_one_or_none()
|
||
|
||
if not attachment:
|
||
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
|
||
else:
|
||
# Delete the association directly from the junction table
|
||
delete_query = delete(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
|
||
await session.execute(delete_query)
|
||
await session.commit()
|
||
|
||
# Get agent without loading relationships for return value
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
|
||
# or even better, not refresh at all.
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||
|
||
# Decrypt secrets outside session
|
||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||
|
||
# ======================================================================================================================
|
||
# Block management
|
||
# ======================================================================================================================
|
||
@enforce_types
|
||
@trace_method
|
||
async def get_block_with_label_async(
|
||
self,
|
||
agent_id: str,
|
||
block_label: str,
|
||
actor: PydanticUser,
|
||
) -> PydanticBlock:
|
||
"""Gets a block attached to an agent by its label."""
|
||
async with db_registry.async_session() as session:
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
for block in agent.core_memory:
|
||
if block.label == block_label:
|
||
return block.to_pydantic()
|
||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def modify_block_by_label_async(
|
||
self,
|
||
agent_id: str,
|
||
block_label: str,
|
||
block_update: BlockUpdate,
|
||
actor: PydanticUser,
|
||
) -> PydanticBlock:
|
||
"""Gets a block attached to an agent by its label."""
|
||
async with db_registry.async_session() as session:
|
||
matched_block = None
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
for block in agent.core_memory:
|
||
if block.label == block_label:
|
||
matched_block = block
|
||
break
|
||
if not matched_block:
|
||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||
|
||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||
|
||
# Validate limit constraints before updating
|
||
validate_block_limit_constraint(update_data, matched_block)
|
||
|
||
for key, value in update_data.items():
|
||
setattr(matched_block, key, value)
|
||
|
||
await matched_block.update_async(session, actor=actor)
|
||
return matched_block.to_pydantic()
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||
@trace_method
|
||
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
|
||
async with db_registry.async_session() as session:
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||
|
||
# Attach block to the main agent
|
||
agent.core_memory.append(block)
|
||
# await agent.update_async(session, actor=actor, no_commit=True)
|
||
await agent.update_async(session)
|
||
|
||
# If agent is part of a sleeptime group, attach block to the sleeptime_agent
|
||
if agent.multi_agent_group and agent.multi_agent_group.manager_type == ManagerType.sleeptime:
|
||
group = agent.multi_agent_group
|
||
# Find the sleeptime_agent in the group
|
||
for other_agent_id in group.agent_ids or []:
|
||
if other_agent_id != agent_id:
|
||
try:
|
||
other_agent = await AgentModel.read_async(db_session=session, identifier=other_agent_id, actor=actor)
|
||
if other_agent.agent_type == AgentType.sleeptime_agent and block not in other_agent.core_memory:
|
||
other_agent.core_memory.append(block)
|
||
# await other_agent.update_async(session, actor=actor, no_commit=True)
|
||
await other_agent.update_async(session, actor=actor)
|
||
except NoResultFound:
|
||
# Agent might not exist anymore, skip
|
||
continue
|
||
|
||
# TODO: @andy/caren
|
||
# TODO: Ideally we do two no commits on the update_async calls, and then commit here - but that errors for some reason?
|
||
# TODO: I have too many things rn so lets look at this later
|
||
# await session.commit()
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||
|
||
# Decrypt secrets outside session
|
||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def detach_block_async(
|
||
self,
|
||
agent_id: str,
|
||
block_id: str,
|
||
actor: PydanticUser,
|
||
) -> PydanticAgentState:
|
||
"""Detaches a block from an agent."""
|
||
async with db_registry.async_session() as session:
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
original_length = len(agent.core_memory)
|
||
|
||
agent.core_memory = [b for b in agent.core_memory if b.id != block_id]
|
||
|
||
if len(agent.core_memory) == original_length:
|
||
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
|
||
|
||
await agent.update_async(session, actor=actor)
|
||
|
||
# Convert without decrypting to release DB connection before PBKDF2
|
||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||
|
||
# Decrypt secrets outside session
|
||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||
|
||
# ======================================================================================================================
|
||
# Passage Management
|
||
# ======================================================================================================================
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_passages(
|
||
self,
|
||
actor: PydanticUser,
|
||
agent_id: Optional[str] = None,
|
||
file_id: Optional[str] = None,
|
||
limit: Optional[int] = 50,
|
||
query_text: Optional[str] = None,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
source_id: Optional[str] = None,
|
||
embed_query: bool = False,
|
||
ascending: bool = True,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
agent_only: bool = False,
|
||
) -> List[PydanticPassage]:
|
||
"""Lists all passages attached to an agent."""
|
||
async with db_registry.async_session() as session:
|
||
main_query = await build_passage_query(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
file_id=file_id,
|
||
query_text=query_text,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
before=before,
|
||
after=after,
|
||
source_id=source_id,
|
||
embed_query=embed_query,
|
||
ascending=ascending,
|
||
embedding_config=embedding_config,
|
||
agent_only=agent_only,
|
||
)
|
||
|
||
# Add limit (enforce default if not provided)
|
||
main_query = main_query.limit(limit)
|
||
|
||
# Execute query
|
||
result = await session.execute(main_query)
|
||
|
||
passages = []
|
||
for row in result:
|
||
data = dict(row._mapping)
|
||
if data.get("archive_id", None):
|
||
# This is an ArchivalPassage - remove source fields
|
||
data.pop("source_id", None)
|
||
data.pop("file_id", None)
|
||
data.pop("file_name", None)
|
||
passage = ArchivalPassage(**data)
|
||
elif data.get("source_id", None):
|
||
# This is a SourcePassage - remove archive field
|
||
data.pop("archive_id", None)
|
||
data.pop("agent_id", None) # For backward compatibility
|
||
passage = SourcePassage(**data)
|
||
else:
|
||
raise ValueError(f"Passage data is malformed, is neither ArchivalPassage nor SourcePassage {data}")
|
||
passages.append(passage)
|
||
|
||
return [p.to_pydantic() for p in passages]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_passages_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
agent_id: Optional[str] = None,
|
||
file_id: Optional[str] = None,
|
||
limit: Optional[int] = 50,
|
||
query_text: Optional[str] = None,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
source_id: Optional[str] = None,
|
||
embed_query: bool = False,
|
||
ascending: bool = True,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
agent_only: bool = False,
|
||
) -> List[PydanticPassage]:
|
||
"""
|
||
DEPRECATED: Use query_source_passages_async or query_agent_passages_async instead.
|
||
This method is kept only for test compatibility and will be removed in a future version.
|
||
|
||
Lists all passages attached to an agent (combines both source and agent passages).
|
||
"""
|
||
import warnings
|
||
|
||
logger.warning(
|
||
"list_passages_async is deprecated. Use query_source_passages_async or query_agent_passages_async instead.",
|
||
stacklevel=2,
|
||
)
|
||
|
||
async with db_registry.async_session() as session:
|
||
main_query = await build_passage_query(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
file_id=file_id,
|
||
query_text=query_text,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
before=before,
|
||
after=after,
|
||
source_id=source_id,
|
||
embed_query=embed_query,
|
||
ascending=ascending,
|
||
embedding_config=embedding_config,
|
||
agent_only=agent_only,
|
||
)
|
||
|
||
# Add limit (enforce default if not provided)
|
||
main_query = main_query.limit(limit)
|
||
|
||
# Execute query
|
||
result = await session.execute(main_query)
|
||
|
||
passages = []
|
||
for row in result:
|
||
data = dict(row._mapping)
|
||
if data.get("archive_id", None):
|
||
# This is an ArchivalPassage - remove source fields
|
||
data.pop("source_id", None)
|
||
data.pop("file_id", None)
|
||
data.pop("file_name", None)
|
||
passage = ArchivalPassage(**data)
|
||
elif data.get("source_id", None):
|
||
# This is a SourcePassage - remove archive field
|
||
data.pop("archive_id", None)
|
||
data.pop("agent_id", None) # For backward compatibility
|
||
passage = SourcePassage(**data)
|
||
else:
|
||
raise ValueError(f"Passage data is malformed, is neither ArchivalPassage nor SourcePassage {data}")
|
||
passages.append(passage)
|
||
|
||
return [p.to_pydantic() for p in passages]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def query_source_passages_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
agent_id: Optional[str] = None,
|
||
file_id: Optional[str] = None,
|
||
limit: Optional[int] = 50,
|
||
query_text: Optional[str] = None,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
source_id: Optional[str] = None,
|
||
embed_query: bool = False,
|
||
ascending: bool = True,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
) -> List[PydanticPassage]:
|
||
"""Lists all passages attached to an agent."""
|
||
async with db_registry.async_session() as session:
|
||
main_query = await build_source_passage_query(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
file_id=file_id,
|
||
query_text=query_text,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
before=before,
|
||
after=after,
|
||
source_id=source_id,
|
||
embed_query=embed_query,
|
||
ascending=ascending,
|
||
embedding_config=embedding_config,
|
||
)
|
||
|
||
# Add limit (enforce default if not provided)
|
||
main_query = main_query.limit(limit)
|
||
|
||
# Execute query
|
||
result = await session.execute(main_query)
|
||
|
||
# Get ORM objects directly using scalars()
|
||
passages = result.scalars().all()
|
||
|
||
# Convert to Pydantic models
|
||
return [p.to_pydantic() for p in passages]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def query_agent_passages_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
agent_id: Optional[str] = None,
|
||
archive_id: Optional[str] = None,
|
||
limit: Optional[int] = 50,
|
||
query_text: Optional[str] = None,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
embed_query: bool = False,
|
||
ascending: bool = True,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
tags: Optional[List[str]] = None,
|
||
tag_match_mode: Optional[TagMatchMode] = None,
|
||
) -> List[Tuple[PydanticPassage, float, dict]]:
|
||
"""Lists all passages attached to an agent."""
|
||
# Check if we should use Turbopuffer for vector search
|
||
# Support searching by either agent_id or archive_id directly
|
||
if embed_query and query_text and embedding_config:
|
||
target_archive_id = None
|
||
|
||
if agent_id:
|
||
# Get archive IDs for the agent
|
||
archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
|
||
|
||
if archive_ids:
|
||
# TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
|
||
if len(archive_ids) > 1:
|
||
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")
|
||
target_archive_id = archive_ids[0]
|
||
elif archive_id:
|
||
# Use the provided archive_id directly
|
||
target_archive_id = archive_id
|
||
|
||
if target_archive_id:
|
||
# Get archive to check vector_db_provider
|
||
archive = await self.archive_manager.get_archive_by_id_async(archive_id=target_archive_id, actor=actor)
|
||
|
||
# Use Turbopuffer for vector search if archive is configured for TPUF
|
||
if archive.vector_db_provider == VectorDBProvider.TPUF:
|
||
from letta.helpers.tpuf_client import TurbopufferClient
|
||
from letta.llm_api.llm_client import LLMClient
|
||
|
||
# Generate embedding for query
|
||
embedding_client = LLMClient.create(
|
||
provider_type=embedding_config.embedding_endpoint_type,
|
||
actor=actor,
|
||
)
|
||
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
|
||
query_embedding = embeddings[0]
|
||
|
||
# Query Turbopuffer - use hybrid search when text is available
|
||
tpuf_client = TurbopufferClient()
|
||
# use hybrid search to combine vector and full-text search
|
||
passages_with_scores = await tpuf_client.query_passages(
|
||
archive_id=target_archive_id,
|
||
query_text=query_text, # pass text for potential hybrid search
|
||
search_mode="hybrid", # use hybrid mode for better results
|
||
top_k=limit,
|
||
tags=tags,
|
||
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
actor=actor,
|
||
)
|
||
|
||
# Return full tuples with metadata
|
||
return passages_with_scores
|
||
|
||
# Fall back to SQL-based search for non-vector queries or NATIVE archives
|
||
async with db_registry.async_session() as session:
|
||
main_query = await build_agent_passage_query(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
archive_id=archive_id,
|
||
query_text=query_text,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
before=before,
|
||
after=after,
|
||
embed_query=embed_query,
|
||
ascending=ascending,
|
||
embedding_config=embedding_config,
|
||
)
|
||
|
||
# Add limit
|
||
if limit:
|
||
main_query = main_query.limit(limit)
|
||
|
||
# Execute query
|
||
result = await session.execute(main_query)
|
||
|
||
# Get ORM objects directly using scalars()
|
||
passages = result.scalars().all()
|
||
|
||
# Convert to Pydantic models
|
||
pydantic_passages = [p.to_pydantic() for p in passages]
|
||
|
||
# TODO: Integrate tag filtering directly into the SQL query for better performance.
|
||
# Currently using post-filtering which is less efficient but simpler to implement.
|
||
# Future optimization: Add JOIN with passage_tags table and WHERE clause for tag filtering.
|
||
if tags:
|
||
filtered_passages = []
|
||
for passage in pydantic_passages:
|
||
if passage.tags:
|
||
passage_tags = set(passage.tags)
|
||
query_tags = set(tags)
|
||
|
||
if tag_match_mode == TagMatchMode.ALL:
|
||
# ALL mode: passage must have all query tags
|
||
if query_tags.issubset(passage_tags):
|
||
filtered_passages.append(passage)
|
||
else:
|
||
# ANY mode (default): passage must have at least one query tag
|
||
if query_tags.intersection(passage_tags):
|
||
filtered_passages.append(passage)
|
||
|
||
# Return as tuples with empty metadata for SQL path
|
||
return [(p, 0.0, {}) for p in filtered_passages]
|
||
|
||
# Return as tuples with empty metadata for SQL path
|
||
return [(p, 0.0, {}) for p in pydantic_passages]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def search_agent_archival_memory_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
query: str,
|
||
tags: Optional[List[str]] = None,
|
||
tag_match_mode: Literal["any", "all"] = "any",
|
||
top_k: Optional[int] = None,
|
||
start_datetime: Optional[str] = None,
|
||
end_datetime: Optional[str] = None,
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
|
||
|
||
This is a shared method used by both the agent tool and API endpoint to ensure consistent behavior.
|
||
|
||
Args:
|
||
agent_id: ID of the agent whose archival memory to search
|
||
actor: User performing the search
|
||
query: String to search for using semantic similarity
|
||
tags: Optional list of tags to filter search results
|
||
tag_match_mode: How to match tags - "any" or "all"
|
||
top_k: Maximum number of results to return
|
||
start_datetime: Filter results after this datetime (ISO 8601 format)
|
||
end_datetime: Filter results before this datetime (ISO 8601 format)
|
||
|
||
Returns:
|
||
List of formatted results with relevance metadata
|
||
"""
|
||
# Handle empty or whitespace-only queries
|
||
if not query or not query.strip():
|
||
return []
|
||
|
||
# Get the agent to access timezone and embedding config
|
||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||
|
||
# Parse datetime parameters if provided
|
||
start_date = None
|
||
end_date = None
|
||
|
||
if start_datetime:
|
||
try:
|
||
# Try parsing as full datetime first (with time)
|
||
start_date = datetime.fromisoformat(start_datetime)
|
||
except ValueError:
|
||
try:
|
||
# Fall back to date-only format
|
||
start_date = datetime.strptime(start_datetime, "%Y-%m-%d")
|
||
# Set to beginning of day
|
||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||
except ValueError:
|
||
raise ValueError(
|
||
f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
|
||
)
|
||
|
||
# Apply agent's timezone if datetime is naive
|
||
if start_date.tzinfo is None and agent_state.timezone:
|
||
tz = ZoneInfo(agent_state.timezone)
|
||
start_date = start_date.replace(tzinfo=tz)
|
||
|
||
if end_datetime:
|
||
try:
|
||
# Try parsing as full datetime first (with time)
|
||
end_date = datetime.fromisoformat(end_datetime)
|
||
except ValueError:
|
||
try:
|
||
# Fall back to date-only format
|
||
end_date = datetime.strptime(end_datetime, "%Y-%m-%d")
|
||
# Set to end of day for end dates
|
||
end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||
except ValueError:
|
||
raise ValueError(f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
|
||
|
||
# Apply agent's timezone if datetime is naive
|
||
if end_date.tzinfo is None and agent_state.timezone:
|
||
tz = ZoneInfo(agent_state.timezone)
|
||
end_date = end_date.replace(tzinfo=tz)
|
||
|
||
# Convert string to TagMatchMode enum
|
||
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
|
||
|
||
# Get results using existing passage query method
|
||
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||
passages_with_metadata = await self.query_agent_passages_async(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
query_text=query,
|
||
limit=limit,
|
||
embedding_config=agent_state.embedding_config,
|
||
embed_query=True,
|
||
tags=tags,
|
||
tag_match_mode=tag_mode,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
)
|
||
|
||
# Format results to include tags with friendly timestamps and relevance metadata
|
||
formatted_results = []
|
||
for passage, score, metadata in passages_with_metadata:
|
||
# Format timestamp in agent's timezone if available
|
||
timestamp = passage.created_at
|
||
if timestamp and agent_state.timezone:
|
||
try:
|
||
# Convert to agent's timezone
|
||
tz = ZoneInfo(agent_state.timezone)
|
||
local_time = timestamp.astimezone(tz)
|
||
# Format as ISO string with timezone
|
||
formatted_timestamp = local_time.isoformat()
|
||
except Exception:
|
||
# Fallback to ISO format if timezone conversion fails
|
||
formatted_timestamp = str(timestamp)
|
||
else:
|
||
# Use ISO format if no timezone is set
|
||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||
|
||
result_dict = {"timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
|
||
|
||
# Add relevance metadata if available
|
||
if metadata:
|
||
relevance_info = {
|
||
k: v
|
||
for k, v in {
|
||
"rrf_score": metadata.get("combined_score"),
|
||
"vector_rank": metadata.get("vector_rank"),
|
||
"fts_rank": metadata.get("fts_rank"),
|
||
}.items()
|
||
if v is not None
|
||
}
|
||
|
||
if relevance_info: # Only add if we have metadata
|
||
result_dict["relevance"] = relevance_info
|
||
|
||
formatted_results.append(result_dict)
|
||
|
||
return formatted_results
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def passage_size(
|
||
self,
|
||
actor: PydanticUser,
|
||
agent_id: Optional[str] = None,
|
||
file_id: Optional[str] = None,
|
||
query_text: Optional[str] = None,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
source_id: Optional[str] = None,
|
||
embed_query: bool = False,
|
||
ascending: bool = True,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
agent_only: bool = False,
|
||
) -> int:
|
||
"""Returns the count of passages matching the given criteria."""
|
||
async with db_registry.async_session() as session:
|
||
main_query = await build_passage_query(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
file_id=file_id,
|
||
query_text=query_text,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
before=before,
|
||
after=after,
|
||
source_id=source_id,
|
||
embed_query=embed_query,
|
||
ascending=ascending,
|
||
embedding_config=embedding_config,
|
||
agent_only=agent_only,
|
||
)
|
||
|
||
# Convert to count query
|
||
count_query = select(func.count()).select_from(main_query.subquery())
|
||
return (await session.scalar(count_query)) or 0
|
||
|
||
@enforce_types
|
||
async def passage_size_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
agent_id: Optional[str] = None,
|
||
file_id: Optional[str] = None,
|
||
query_text: Optional[str] = None,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
source_id: Optional[str] = None,
|
||
embed_query: bool = False,
|
||
ascending: bool = True,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
agent_only: bool = False,
|
||
) -> int:
|
||
async with db_registry.async_session() as session:
|
||
main_query = await build_passage_query(
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
file_id=file_id,
|
||
query_text=query_text,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
before=before,
|
||
after=after,
|
||
source_id=source_id,
|
||
embed_query=embed_query,
|
||
ascending=ascending,
|
||
embedding_config=embedding_config,
|
||
agent_only=agent_only,
|
||
)
|
||
|
||
# Convert to count query
|
||
count_query = select(func.count()).select_from(main_query.subquery())
|
||
return (await session.execute(count_query)).scalar() or 0
|
||
|
||
# ======================================================================================================================
|
||
# Tool Management
|
||
# ======================================================================================================================
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||
@trace_method
|
||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||
"""
|
||
Attaches a tool to an agent.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to attach the tool to.
|
||
tool_id: ID of the tool to attach.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If the agent or tool is not found.
|
||
|
||
Returns:
|
||
PydanticAgentState: The updated agent state.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# Verify the agent exists and user has permission to access it
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# verify tool exists and belongs to organization in a single query with the insert
|
||
# first, check if tool exists with correct organization
|
||
tool_check_query = select(ToolModel.name, ToolModel.default_requires_approval).where(
|
||
ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id
|
||
)
|
||
result = await session.execute(tool_check_query)
|
||
tool_rows = result.fetchall()
|
||
|
||
if len(tool_rows) == 0:
|
||
raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}")
|
||
tool_name, default_requires_approval = tool_rows[0]
|
||
|
||
# use postgresql on conflict or mysql on duplicate key update for atomic operation
|
||
if settings.letta_pg_uri_no_default:
|
||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||
|
||
insert_stmt = pg_insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id)
|
||
# on conflict do nothing - silently ignore if already exists
|
||
insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
|
||
result = await session.execute(insert_stmt)
|
||
if result.rowcount == 0:
|
||
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
|
||
else:
|
||
# for sqlite/mysql, check then insert
|
||
existing_query = (
|
||
select(func.count()).select_from(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
|
||
)
|
||
existing_result = await session.execute(existing_query)
|
||
if existing_result.scalar() == 0:
|
||
insert_stmt = insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id)
|
||
await session.execute(insert_stmt)
|
||
else:
|
||
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
|
||
|
||
if default_requires_approval:
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
existing_rules = [rule for rule in agent.tool_rules if rule.tool_name == tool_name and rule.type == "requires_approval"]
|
||
if len(existing_rules) == 0:
|
||
# Create a new list to ensure SQLAlchemy detects the change
|
||
# This is critical for JSON columns - modifying in place doesn't trigger change detection
|
||
tool_rules = list(agent.tool_rules) if agent.tool_rules else []
|
||
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_name))
|
||
agent.tool_rules = tool_rules
|
||
session.add(agent)
|
||
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||
"""
|
||
Efficiently attaches multiple tools to an agent in a single operation.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to attach the tools to.
|
||
tool_ids: List of tool IDs to attach.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If the agent or any tool is not found.
|
||
"""
|
||
if not tool_ids:
|
||
# no tools to attach, nothing to do
|
||
return
|
||
|
||
async with db_registry.async_session() as session:
|
||
# Verify the agent exists and user has permission to access it
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# verify all tools exist and belong to organization in a single query
|
||
tool_check_query = select(func.count(ToolModel.id)).where(
|
||
ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id
|
||
)
|
||
tool_result = await session.execute(tool_check_query)
|
||
found_count = tool_result.scalar()
|
||
|
||
if found_count != len(tool_ids):
|
||
# find which tools are missing for better error message
|
||
existing_query = select(ToolModel.id).where(ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id)
|
||
existing_result = await session.execute(existing_query)
|
||
existing_ids = {row[0] for row in existing_result}
|
||
missing_ids = set(tool_ids) - existing_ids
|
||
raise NoResultFound(f"Tools with ids={missing_ids} not found in organization={actor.organization_id}")
|
||
|
||
if settings.letta_pg_uri_no_default:
|
||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||
|
||
# prepare bulk values
|
||
values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in tool_ids]
|
||
|
||
# bulk insert with on conflict do nothing
|
||
insert_stmt = pg_insert(ToolsAgents).values(values)
|
||
insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
|
||
result = await session.execute(insert_stmt)
|
||
logger.info(
|
||
f"Attached {result.rowcount} new tools to agent {agent_id} (skipped {len(tool_ids) - result.rowcount} already attached)"
|
||
)
|
||
else:
|
||
# for sqlite/mysql, first check which tools are already attached
|
||
existing_query = select(ToolsAgents.tool_id).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids))
|
||
existing_result = await session.execute(existing_query)
|
||
already_attached = {row[0] for row in existing_result}
|
||
|
||
# only insert tools that aren't already attached
|
||
new_tool_ids = [tid for tid in tool_ids if tid not in already_attached]
|
||
|
||
if new_tool_ids:
|
||
# bulk insert new attachments
|
||
values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in new_tool_ids]
|
||
insert_stmt = insert(ToolsAgents).values(values)
|
||
await session.execute(insert_stmt)
|
||
logger.info(
|
||
f"Attached {len(new_tool_ids)} new tools to agent {agent_id} (skipped {len(already_attached)} already attached)"
|
||
)
|
||
else:
|
||
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
|
||
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def attach_missing_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||
"""
|
||
Attaches missing core file tools to an agent.
|
||
|
||
Args:
|
||
agent_state: The current agent state with tools already loaded.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If the agent or tool is not found.
|
||
|
||
Returns:
|
||
PydanticAgentState: The updated agent state.
|
||
"""
|
||
# get current file tools attached to the agent
|
||
attached_file_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||
|
||
# determine which file tools are missing
|
||
missing_tool_names = set(FILES_TOOLS) - attached_file_tool_names
|
||
|
||
if not missing_tool_names:
|
||
# agent already has all file tools
|
||
return agent_state
|
||
|
||
# get full tool objects for all missing file tools in one query
|
||
async with db_registry.async_session() as session:
|
||
query = select(ToolModel).where(
|
||
ToolModel.name.in_(missing_tool_names),
|
||
ToolModel.organization_id == actor.organization_id,
|
||
ToolModel.tool_type == ToolType.LETTA_FILES_CORE,
|
||
)
|
||
result = await session.execute(query)
|
||
found_tool_models = result.scalars().all()
|
||
|
||
if not found_tool_models:
|
||
logger.warning(f"No file tools found for organization {actor.organization_id}. Expected tools: {missing_tool_names}")
|
||
return agent_state
|
||
|
||
# convert to pydantic tools
|
||
found_tools = [tool.to_pydantic() for tool in found_tool_models]
|
||
found_tool_names = {tool.name for tool in found_tools}
|
||
|
||
# log if any expected tools weren't found
|
||
still_missing = missing_tool_names - found_tool_names
|
||
if still_missing:
|
||
logger.warning(f"File tools {still_missing} not found in organization {actor.organization_id}")
|
||
|
||
# extract tool IDs for bulk attach
|
||
tool_ids_to_attach = [tool.id for tool in found_tools]
|
||
|
||
# bulk attach all found file tools
|
||
await self.bulk_attach_tools_async(agent_id=agent_state.id, tool_ids=tool_ids_to_attach, actor=actor)
|
||
|
||
# create a shallow copy with updated tools list to avoid modifying input
|
||
agent_state_dict = agent_state.model_dump()
|
||
agent_state_dict["tools"] = agent_state.tools + found_tools
|
||
|
||
return PydanticAgentState(**agent_state_dict)
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def detach_all_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||
"""
|
||
Detach all core file tools from an agent.
|
||
|
||
Args:
|
||
agent_state: The current agent state with tools already loaded.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If the agent is not found.
|
||
|
||
Returns:
|
||
PydanticAgentState: The updated agent state.
|
||
"""
|
||
# extract file tool IDs directly from agent_state.tools
|
||
file_tool_ids = [tool.id for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE]
|
||
|
||
if not file_tool_ids:
|
||
# no file tools to detach
|
||
return agent_state
|
||
|
||
# bulk detach all file tools in one operation
|
||
await self.bulk_detach_tools_async(agent_id=agent_state.id, tool_ids=file_tool_ids, actor=actor)
|
||
|
||
# create a shallow copy with updated tools list to avoid modifying input
|
||
agent_state_dict = agent_state.model_dump()
|
||
agent_state_dict["tools"] = [tool for tool in agent_state.tools if tool.tool_type != ToolType.LETTA_FILES_CORE]
|
||
|
||
return PydanticAgentState(**agent_state_dict)
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||
@trace_method
|
||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||
"""
|
||
Detaches a tool from an agent.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to detach the tool from.
|
||
tool_id: ID of the tool to detach.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If the agent is not found.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# Verify the agent exists and user has permission to access it
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# Delete the association directly - if it doesn't exist, rowcount will be 0
|
||
delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
|
||
result = await session.execute(delete_query)
|
||
|
||
if result.rowcount == 0:
|
||
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
|
||
else:
|
||
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
|
||
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||
"""
|
||
Efficiently detaches multiple tools from an agent in a single operation.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to detach the tools from.
|
||
tool_ids: List of tool IDs to detach.
|
||
actor: User performing the action.
|
||
|
||
Raises:
|
||
NoResultFound: If the agent is not found.
|
||
"""
|
||
if not tool_ids:
|
||
# no tools to detach, nothing to do
|
||
return
|
||
|
||
async with db_registry.async_session() as session:
|
||
# Verify the agent exists and user has permission to access it
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# Delete all associations in a single query
|
||
delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids))
|
||
result = await session.execute(delete_query)
|
||
|
||
detached_count = result.rowcount
|
||
if detached_count == 0:
|
||
logger.warning(f"No tools from list {tool_ids} were attached to agent id={agent_id}")
|
||
elif detached_count < len(tool_ids):
|
||
logger.info(f"Detached {detached_count} tools from agent {agent_id} ({len(tool_ids) - detached_count} were not attached)")
|
||
else:
|
||
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
|
||
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def modify_approvals_async(self, agent_id: str, tool_name: str, requires_approval: bool, actor: PydanticUser) -> None:
|
||
def is_target_rule(rule):
|
||
return rule.tool_name == tool_name and rule.type == "requires_approval"
|
||
|
||
async with db_registry.async_session() as session:
|
||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
existing_rules = [rule for rule in agent.tool_rules if is_target_rule(rule)]
|
||
|
||
if len(existing_rules) == 1 and not requires_approval:
|
||
tool_rules = [rule for rule in agent.tool_rules if not is_target_rule(rule)]
|
||
elif len(existing_rules) == 0 and requires_approval:
|
||
# Create a new list to ensure SQLAlchemy detects the change
|
||
# This is critical for JSON columns - modifying in place doesn't trigger change detection
|
||
tool_rules = list(agent.tool_rules) if agent.tool_rules else []
|
||
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_name))
|
||
else:
|
||
tool_rules = None
|
||
|
||
if tool_rules is None:
|
||
return
|
||
|
||
agent.tool_rules = tool_rules
|
||
session.add(agent)
|
||
# context manager now handles commits
|
||
# await session.commit()
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_attached_tools_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
limit: Optional[int] = None,
|
||
ascending: bool = False,
|
||
) -> List[PydanticTool]:
|
||
"""
|
||
List all tools attached to an agent (async version with optimized performance).
|
||
Uses direct SQL queries to avoid SqlAlchemyBase overhead.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to list tools for.
|
||
actor: User performing the action.
|
||
before: Tool ID cursor for pagination. Returns tools that come before this tool ID.
|
||
after: Tool ID cursor for pagination. Returns tools that come after this tool ID.
|
||
limit: Maximum number of tools to return.
|
||
ascending: Sort order by creation time.
|
||
|
||
Returns:
|
||
List[PydanticTool]: List of tools attached to the agent.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# lightweight check for agent access
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# direct query for tools via join - much more performant
|
||
query = (
|
||
select(ToolModel)
|
||
.join(ToolsAgents, ToolModel.id == ToolsAgents.tool_id)
|
||
.where(ToolsAgents.agent_id == agent_id, ToolModel.organization_id == actor.organization_id)
|
||
)
|
||
|
||
# Apply cursor-based pagination
|
||
if before:
|
||
query = query.where(ToolModel.id < before)
|
||
if after:
|
||
query = query.where(ToolModel.id > after)
|
||
|
||
# Apply sorting
|
||
if ascending:
|
||
query = query.order_by(ToolModel.created_at.asc())
|
||
else:
|
||
query = query.order_by(ToolModel.created_at.desc())
|
||
|
||
# Apply limit
|
||
if limit:
|
||
query = query.limit(limit)
|
||
|
||
result = await session.execute(query)
|
||
tools = result.scalars().all()
|
||
return [tool.to_pydantic() for tool in tools]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_agent_blocks_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
limit: Optional[int] = None,
|
||
ascending: bool = False,
|
||
) -> List[PydanticBlock]:
|
||
"""
|
||
List all blocks for a specific agent with pagination.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to find blocks for.
|
||
actor: User performing the action.
|
||
before: Block ID cursor for pagination. Returns blocks that come before this block ID.
|
||
after: Block ID cursor for pagination. Returns blocks that come after this block ID.
|
||
limit: Maximum number of blocks to return.
|
||
ascending: Sort order by creation time.
|
||
|
||
Returns:
|
||
List[PydanticBlock]: List of blocks for the agent.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# First verify agent exists and user has access
|
||
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||
|
||
# Build query to get blocks for this agent with pagination
|
||
query = (
|
||
select(BlockModel)
|
||
.join(BlocksAgents, BlockModel.id == BlocksAgents.block_id)
|
||
.where(BlocksAgents.agent_id == agent_id, BlockModel.organization_id == actor.organization_id)
|
||
)
|
||
|
||
# Apply cursor-based pagination
|
||
if before:
|
||
query = query.where(BlockModel.id < before)
|
||
if after:
|
||
query = query.where(BlockModel.id > after)
|
||
|
||
# Apply sorting - use id instead of created_at for core memory blocks
|
||
if ascending:
|
||
query = query.order_by(BlockModel.id.asc())
|
||
else:
|
||
query = query.order_by(BlockModel.id.desc())
|
||
|
||
# Apply limit
|
||
if limit:
|
||
query = query.limit(limit)
|
||
|
||
result = await session.execute(query)
|
||
blocks = result.scalars().all()
|
||
|
||
return [block.to_pydantic() for block in blocks]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_groups_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
manager_type: Optional[str] = None,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
limit: Optional[int] = None,
|
||
ascending: bool = False,
|
||
) -> List[PydanticGroup]:
|
||
"""
|
||
List all groups that contain the specified agent.
|
||
|
||
Args:
|
||
agent_id: ID of the agent to find groups for.
|
||
actor: User performing the action.
|
||
manager_type: Optional manager type to filter by.
|
||
before: Group ID cursor for pagination. Returns groups that come before this group ID.
|
||
after: Group ID cursor for pagination. Returns groups that come after this group ID.
|
||
limit: Maximum number of groups to return.
|
||
ascending: Sort order by creation time.
|
||
|
||
Returns:
|
||
List[PydanticGroup]: List of groups containing the agent.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
query = (
|
||
select(GroupModel)
|
||
.join(GroupsAgents, GroupModel.id == GroupsAgents.group_id)
|
||
.where(GroupsAgents.agent_id == agent_id, GroupModel.organization_id == actor.organization_id)
|
||
)
|
||
|
||
if manager_type:
|
||
query = query.where(GroupModel.manager_type == manager_type)
|
||
|
||
# Apply cursor-based pagination
|
||
if before:
|
||
query = query.where(GroupModel.id < before)
|
||
if after:
|
||
query = query.where(GroupModel.id > after)
|
||
|
||
# Apply sorting
|
||
if ascending:
|
||
query = query.order_by(GroupModel.created_at.asc())
|
||
else:
|
||
query = query.order_by(GroupModel.created_at.desc())
|
||
|
||
# Apply limit
|
||
if limit:
|
||
query = query.limit(limit)
|
||
|
||
result = await session.execute(query)
|
||
groups = result.scalars().all()
|
||
return [group.to_pydantic() for group in groups]
|
||
|
||
# ======================================================================================================================
|
||
# File Management
|
||
# ======================================================================================================================
|
||
async def insert_file_into_context_windows(
|
||
self,
|
||
source_id: str,
|
||
file_metadata_with_content: PydanticFileMetadata,
|
||
actor: PydanticUser,
|
||
agent_states: Optional[List[PydanticAgentState]] = None,
|
||
) -> List[PydanticAgentState]:
|
||
"""
|
||
Insert the uploaded document into the context window of all agents
|
||
attached to the given source.
|
||
"""
|
||
agent_states = agent_states or await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||
|
||
# Return early
|
||
if not agent_states:
|
||
return []
|
||
|
||
logger.info(f"Inserting document into context window for source: {source_id}")
|
||
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
|
||
|
||
# Generate visible content for the file
|
||
line_chunker = LineChunker()
|
||
content_lines = line_chunker.chunk_text(file_metadata=file_metadata_with_content)
|
||
visible_content = "\n".join(content_lines)
|
||
visible_content_map = {file_metadata_with_content.file_name: visible_content}
|
||
|
||
all_closed_files: List[str] = []
|
||
|
||
for agent_state in agent_states:
|
||
# To avoid exhausting the db connection pool when many agents are attached,
|
||
# perform the operations sequentially instead of concurrently.
|
||
closed_for_agent = await self.file_agent_manager.attach_files_bulk(
|
||
agent_id=agent_state.id,
|
||
files_metadata=[file_metadata_with_content],
|
||
visible_content_map=visible_content_map,
|
||
actor=actor,
|
||
max_files_open=agent_state.max_files_open,
|
||
)
|
||
all_closed_files.extend(closed_for_agent)
|
||
|
||
# Log if any files were closed
|
||
closed_files = all_closed_files
|
||
if closed_files:
|
||
logger.info(f"LRU eviction closed {len(closed_files)} files during bulk attach: {closed_files}")
|
||
|
||
return agent_states
|
||
|
||
async def insert_files_into_context_window(
|
||
self, agent_state: PydanticAgentState, file_metadata_with_content: List[PydanticFileMetadata], actor: PydanticUser
|
||
) -> None:
|
||
"""
|
||
Insert the uploaded documents into the context window of an agent
|
||
attached to the given source.
|
||
"""
|
||
logger.info(f"Inserting {len(file_metadata_with_content)} documents into context window for agent_state: {agent_state.id}")
|
||
|
||
# Generate visible content for each file
|
||
line_chunker = LineChunker()
|
||
visible_content_map = {}
|
||
for i, file_metadata in enumerate(file_metadata_with_content):
|
||
content_lines = line_chunker.chunk_text(file_metadata=file_metadata)
|
||
visible_content_map[file_metadata.file_name] = "\n".join(content_lines)
|
||
|
||
# Yield to event loop every 100 files to prevent saturation
|
||
if i > 0 and i % 100 == 0:
|
||
await asyncio.sleep(0)
|
||
|
||
# Use bulk attach to avoid race conditions and duplicate LRU eviction decisions
|
||
closed_files = await self.file_agent_manager.attach_files_bulk(
|
||
agent_id=agent_state.id,
|
||
files_metadata=file_metadata_with_content,
|
||
visible_content_map=visible_content_map,
|
||
actor=actor,
|
||
max_files_open=agent_state.max_files_open,
|
||
)
|
||
|
||
if closed_files:
|
||
logger.info(f"LRU eviction closed {len(closed_files)} files during bulk insert: {closed_files}")
|
||
|
||
# ======================================================================================================================
|
||
# Tag Management
|
||
# ======================================================================================================================
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def list_tags_async(
|
||
self,
|
||
actor: PydanticUser,
|
||
before: Optional[str] = None,
|
||
after: Optional[str] = None,
|
||
limit: Optional[int] = 50,
|
||
query_text: Optional[str] = None,
|
||
ascending: bool = True,
|
||
) -> List[str]:
|
||
"""
|
||
Get all tags a user has created, ordered alphabetically.
|
||
|
||
Args:
|
||
actor: User performing the action.
|
||
before: Cursor for backward pagination (tags before this tag).
|
||
after: Cursor for forward pagination (tags after this tag).
|
||
limit: Maximum number of tags to return (default: 50).
|
||
query_text: Filter tags by text search.
|
||
ascending: Sort order - True for alphabetical, False for reverse (default: True).
|
||
|
||
Returns:
|
||
List[str]: List of all tags matching the criteria.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# Build the query using select() for async SQLAlchemy
|
||
query = (
|
||
select(AgentsTags.tag)
|
||
.join(AgentModel, AgentModel.id == AgentsTags.agent_id)
|
||
.where(AgentModel.organization_id == actor.organization_id)
|
||
.distinct()
|
||
)
|
||
|
||
if query_text:
|
||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||
# PostgreSQL: Use ILIKE for case-insensitive search
|
||
query = query.where(AgentsTags.tag.ilike(f"%{query_text}%"))
|
||
else:
|
||
# SQLite: Use LIKE with LOWER for case-insensitive search
|
||
query = query.where(func.lower(AgentsTags.tag).like(func.lower(f"%{query_text}%")))
|
||
|
||
# Handle pagination cursors
|
||
if after:
|
||
if ascending:
|
||
query = query.where(AgentsTags.tag > after)
|
||
else:
|
||
query = query.where(AgentsTags.tag < after)
|
||
|
||
if before:
|
||
if ascending:
|
||
query = query.where(AgentsTags.tag < before)
|
||
else:
|
||
query = query.where(AgentsTags.tag > before)
|
||
|
||
# Apply ordering based on ascending parameter
|
||
if ascending:
|
||
query = query.order_by(AgentsTags.tag.asc())
|
||
else:
|
||
query = query.order_by(AgentsTags.tag.desc())
|
||
|
||
query = query.limit(limit)
|
||
|
||
# Execute the query asynchronously
|
||
result = await session.execute(query)
|
||
# Extract the tag values from the result
|
||
results = [row[0] for row in result.all()]
|
||
return results
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_agent_files_config_async(self, agent_id: str, actor: PydanticUser) -> Tuple[int, int]:
|
||
"""Get per_file_view_window_char_limit and max_files_open for an agent.
|
||
|
||
This is a performant query that only fetches the specific fields needed.
|
||
|
||
Args:
|
||
agent_id: The ID of the agent
|
||
actor: The user making the request
|
||
|
||
Returns:
|
||
Tuple of per_file_view_window_char_limit, max_files_open values
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
result = await session.execute(
|
||
select(AgentModel.per_file_view_window_char_limit, AgentModel.max_files_open)
|
||
.where(AgentModel.id == agent_id)
|
||
.where(AgentModel.organization_id == actor.organization_id)
|
||
.where(AgentModel.is_deleted == False)
|
||
)
|
||
row = result.one_or_none()
|
||
|
||
if row is None:
|
||
raise ValueError(f"Agent {agent_id} not found")
|
||
|
||
per_file_limit, max_files = row[0], row[1]
|
||
|
||
# Handle None values by calculating defaults based on context window
|
||
if per_file_limit is None or max_files is None:
|
||
# Get the agent's model context window to calculate appropriate defaults
|
||
model_result = await session.execute(
|
||
select(AgentModel.llm_config)
|
||
.where(AgentModel.id == agent_id)
|
||
.where(AgentModel.organization_id == actor.organization_id)
|
||
.where(AgentModel.is_deleted == False)
|
||
)
|
||
model_row = model_result.one_or_none()
|
||
context_window = model_row[0].context_window if model_row and model_row[0] else None
|
||
|
||
default_max_files, default_per_file_limit = calculate_file_defaults_based_on_context_window(context_window)
|
||
|
||
# Use calculated defaults for None values
|
||
if per_file_limit is None:
|
||
per_file_limit = default_per_file_limit
|
||
if max_files is None:
|
||
max_files = default_max_files
|
||
|
||
# FINAL fallback: ensure neither is None (should never happen, but just in case)
|
||
if per_file_limit is None:
|
||
per_file_limit = DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT
|
||
if max_files is None:
|
||
max_files = DEFAULT_MAX_FILES_OPEN
|
||
|
||
return per_file_limit, max_files
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_agent_max_files_open_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||
"""Get max_files_open for an agent.
|
||
|
||
This is a performant query that only fetches the specific field needed.
|
||
|
||
Args:
|
||
agent_id: The ID of the agent
|
||
actor: The user making the request
|
||
|
||
Returns:
|
||
max_files_open value
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
result = await session.execute(
|
||
select(AgentModel.max_files_open)
|
||
.where(AgentModel.id == agent_id)
|
||
.where(AgentModel.organization_id == actor.organization_id)
|
||
.where(AgentModel.is_deleted == False)
|
||
)
|
||
row = result.scalar_one_or_none()
|
||
|
||
if row is None:
|
||
raise ValueError(f"Agent {agent_id} not found")
|
||
|
||
return row
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_agent_per_file_view_window_char_limit_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||
"""Get per_file_view_window_char_limit for an agent.
|
||
|
||
This is a performant query that only fetches the specific field needed.
|
||
|
||
Args:
|
||
agent_id: The ID of the agent
|
||
actor: The user making the request
|
||
|
||
Returns:
|
||
per_file_view_window_char_limit value
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
result = await session.execute(
|
||
select(AgentModel.per_file_view_window_char_limit)
|
||
.where(AgentModel.id == agent_id)
|
||
.where(AgentModel.organization_id == actor.organization_id)
|
||
.where(AgentModel.is_deleted == False)
|
||
)
|
||
row = result.scalar_one_or_none()
|
||
|
||
if row is None:
|
||
raise ValueError(f"Agent {agent_id} not found")
|
||
|
||
return row
|
||
|
||
@enforce_types
|
||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||
@trace_method
|
||
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
|
||
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
|
||
agent_id=agent_id, actor=actor, force=True, dry_run=True
|
||
)
|
||
calculator = ContextWindowCalculator()
|
||
|
||
# Create the appropriate token counter based on model configuration
|
||
token_counter = create_token_counter(
|
||
model_endpoint_type=agent_state.llm_config.model_endpoint_type,
|
||
model=agent_state.llm_config.model,
|
||
actor=actor,
|
||
agent_id=agent_id,
|
||
)
|
||
|
||
try:
|
||
result = await calculator.calculate_context_window(
|
||
agent_state=agent_state,
|
||
actor=actor,
|
||
token_counter=token_counter,
|
||
message_manager=self.message_manager,
|
||
system_message_compiled=system_message,
|
||
num_archival_memories=num_archival_memories,
|
||
num_messages=num_messages,
|
||
)
|
||
except Exception as e:
|
||
raise e
|
||
|
||
return result
|