3743 lines
166 KiB
Python
3743 lines
166 KiB
Python
import asyncio
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple
|
||
from zoneinfo import ZoneInfo
|
||
|
||
import sqlalchemy as sa
|
||
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
|
||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||
|
||
from letta.constants import (
|
||
BASE_MEMORY_TOOLS,
|
||
BASE_MEMORY_TOOLS_V2,
|
||
BASE_SLEEPTIME_CHAT_TOOLS,
|
||
BASE_SLEEPTIME_TOOLS,
|
||
BASE_TOOLS,
|
||
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
|
||
BASE_VOICE_SLEEPTIME_TOOLS,
|
||
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||
DEFAULT_MAX_FILES_OPEN,
|
||
DEFAULT_TIMEZONE,
|
||
DEPRECATED_LETTA_TOOLS,
|
||
EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES,
|
||
FILES_TOOLS,
|
||
INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES,
|
||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||
)
|
||
from letta.helpers import ToolRulesSolver
|
||
from letta.helpers.datetime_helpers import get_utc_time
|
||
from letta.llm_api.llm_client import LLMClient
|
||
from letta.log import get_logger
|
||
from letta.orm import (
|
||
Agent as AgentModel,
|
||
AgentsTags,
|
||
ArchivalPassage,
|
||
Block as BlockModel,
|
||
BlocksAgents,
|
||
Group as GroupModel,
|
||
GroupsAgents,
|
||
IdentitiesAgents,
|
||
Source as SourceModel,
|
||
SourcePassage,
|
||
SourcesAgents,
|
||
Tool as ToolModel,
|
||
ToolsAgents,
|
||
)
|
||
from letta.orm.errors import NoResultFound
|
||
from letta.orm.sandbox_config import AgentEnvironmentVariable, AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||
from letta.orm.sqlalchemy_base import AccessType
|
||
from letta.otel.tracing import trace_method
|
||
from letta.prompts.prompt_generator import PromptGenerator
|
||
from letta.schemas.agent import (
|
||
AgentState as PydanticAgentState,
|
||
AgentType,
|
||
CreateAgent,
|
||
InternalTemplateAgentCreate,
|
||
UpdateAgent,
|
||
get_prompt_template_for_agent_type,
|
||
)
|
||
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
|
||
from letta.schemas.embedding_config import EmbeddingConfig
|
||
from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||
from letta.schemas.group import Group as PydanticGroup, ManagerType
|
||
from letta.schemas.llm_config import LLMConfig
|
||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||
from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate
|
||
from letta.schemas.passage import Passage as PydanticPassage
|
||
from letta.schemas.source import Source as PydanticSource
|
||
from letta.schemas.tool import Tool as PydanticTool
|
||
from letta.schemas.tool_rule import ContinueToolRule, RequiresApprovalToolRule, TerminalToolRule
|
||
from letta.schemas.user import User as PydanticUser
|
||
from letta.serialize_schemas import MarshmallowAgentSchema
|
||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||
from letta.server.db import db_registry
|
||
from letta.services.archive_manager import ArchiveManager
|
||
from letta.services.block_manager import BlockManager
|
||
from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
|
||
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter
|
||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||
from letta.services.files_agents_manager import FileAgentManager
|
||
from letta.services.helpers.agent_manager_helper import (
|
||
_apply_filters,
|
||
_apply_identity_filters,
|
||
_apply_pagination,
|
||
_apply_pagination_async,
|
||
_apply_relationship_filters,
|
||
_apply_tag_filter,
|
||
_process_relationship,
|
||
_process_relationship_async,
|
||
build_agent_passage_query,
|
||
build_passage_query,
|
||
build_source_passage_query,
|
||
calculate_base_tools,
|
||
calculate_multi_agent_tools,
|
||
check_supports_structured_output,
|
||
compile_system_message,
|
||
derive_system_message,
|
||
initialize_message_sequence,
|
||
initialize_message_sequence_async,
|
||
package_initial_message_sequence,
|
||
validate_agent_exists_async,
|
||
)
|
||
from letta.services.identity_manager import IdentityManager
|
||
from letta.services.message_manager import MessageManager
|
||
from letta.services.passage_manager import PassageManager
|
||
from letta.services.source_manager import SourceManager
|
||
from letta.services.tool_manager import ToolManager
|
||
from letta.settings import DatabaseChoice, settings
|
||
from letta.utils import calculate_file_defaults_based_on_context_window, enforce_types, united_diff
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class AgentManager:
|
||
"""Manager class to handle business logic related to Agents."""
|
||
|
||
def __init__(self):
    """Construct the collaborating managers that AgentManager delegates to.

    Every collaborator is built with no arguments; the other AgentManager
    methods reach them through the attributes assigned here.
    """
    self.block_manager = BlockManager()
    self.tool_manager = ToolManager()
    self.source_manager = SourceManager()

    self.message_manager = MessageManager()
    self.passage_manager = PassageManager()

    self.identity_manager = IdentityManager()
    self.file_agent_manager = FileAgentManager()
    self.archive_manager = ArchiveManager()
|
||
|
||
@staticmethod
def _should_exclude_model_from_base_tool_rules(model: str) -> bool:
    """Decide whether base tool rules should be skipped for ``model``.

    Matching is plain substring containment against the configured keyword
    lists. Include keywords take precedence: if any of them occurs in
    ``model`` the answer is False, even when an exclude keyword matches too.

    Returns:
        True only when no include keyword matches and at least one exclude
        keyword does.
    """
    if any(keyword in model for keyword in INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES):
        return False
    return any(keyword in model for keyword in EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES)
|
||
|
||
@staticmethod
def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Look up tools by name and/or id inside one organization with a single query.

    Every tool in ``org_id`` whose name is in ``names`` or whose id is in ``ids`` is
    fetched, and every fetched row contributes to both returned maps.

    Returns:
        ``(name_to_id, id_to_name)``

    Raises:
        ValueError: if a requested name (checked first) or a requested id is not found.
    """
    requested = or_(ToolModel.name.in_(names), ToolModel.id.in_(ids))
    query = select(ToolModel.id, ToolModel.name).where(ToolModel.organization_id == org_id, requested)
    found = session.execute(query).all()

    name_to_id = {tool_name: tool_id for tool_id, tool_name in found}
    id_to_name = {tool_id: tool_name for tool_id, tool_name in found}

    unknown_names = names.difference(name_to_id)
    unknown_ids = ids.difference(id_to_name)
    if unknown_names:
        raise ValueError(f"Tools not found by name: {unknown_names}")
    if unknown_ids:
        raise ValueError(f"Tools not found by id: {unknown_ids}")

    return name_to_id, id_to_name
|
||
|
||
@staticmethod
async def _resolve_tools_async(
    session, names: Set[str], ids: Set[str], org_id: str
) -> Tuple[Dict[str, str], Dict[str, str], List[str]]:
    """
    Bulk-fetch all ToolModel rows matching either name ∈ names or id ∈ ids
    (scoped to ``org_id``) and return three values:

    - ``name_to_id``: tool name -> tool id for every matched row
    - ``id_to_name``: tool id -> tool name for every matched row
    - ``requires_approval``: names of matched tools whose ``default_requires_approval``
      is truthy

    Rows matched via ``ids`` also appear in ``name_to_id`` (and vice versa), so the
    caller can use ``name_to_id.keys()`` as the canonical set of tool names.

    Raises:
        ValueError: if a requested name (checked first) or a requested id is not found.
    """
    stmt = select(ToolModel.id, ToolModel.name, ToolModel.default_requires_approval).where(
        ToolModel.organization_id == org_id,
        or_(
            ToolModel.name.in_(names),
            ToolModel.id.in_(ids),
        ),
    )
    result = await session.execute(stmt)
    rows = result.fetchall()

    name_to_id = {name: tool_id for tool_id, name, _ in rows}
    id_to_name = {tool_id: name for tool_id, name, _ in rows}
    requires_approval = [name for _, name, needs_approval in rows if needs_approval]

    missing_names = names - set(name_to_id.keys())
    missing_ids = ids - set(id_to_name.keys())
    if missing_names:
        raise ValueError(f"Tools not found by name: {missing_names}")
    if missing_ids:
        raise ValueError(f"Tools not found by id: {missing_ids}")

    return name_to_id, id_to_name, requires_approval
|
||
|
||
@staticmethod
def _bulk_insert_pivot(session, table, rows: list[dict]):
    """
    Insert ``rows`` into the pivot ``table`` with a single statement, tolerating duplicates.

    The strategy depends on the session's dialect:
    - postgresql: ``INSERT ... ON CONFLICT DO NOTHING``
    - sqlite: ``INSERT OR IGNORE``
    - anything else: rows that are exact duplicates of an earlier row in ``rows`` are
      dropped in Python, then a plain INSERT with no conflict clause is issued.

    An empty ``rows`` list is a no-op.
    """
    if not rows:
        return

    backend = session.bind.dialect.name
    if backend == "postgresql":
        insert_stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
    elif backend == "sqlite":
        insert_stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
    else:
        # keep the first occurrence of each distinct row, preserving input order
        first_seen: dict = {}
        for candidate in rows:
            first_seen.setdefault(tuple(sorted(candidate.items())), candidate)
        insert_stmt = sa.insert(table).values(list(first_seen.values()))

    session.execute(insert_stmt)
|
||
|
||
@staticmethod
async def _bulk_insert_pivot_async(session, table, rows: list[dict]):
    """
    Async counterpart of ``_bulk_insert_pivot``: insert ``rows`` into ``table`` in one
    statement while tolerating duplicates.

    - postgresql: ``INSERT ... ON CONFLICT DO NOTHING``
    - sqlite: ``INSERT OR IGNORE``
    - anything else: exact-duplicate rows within ``rows`` are removed in Python, then a
      plain INSERT with no conflict clause is issued.

    An empty ``rows`` list is a no-op.
    """
    if not rows:
        return

    backend = session.bind.dialect.name
    if backend == "postgresql":
        insert_stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
    elif backend == "sqlite":
        insert_stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
    else:
        # keep the first occurrence of each distinct row, preserving input order
        first_seen: dict = {}
        for candidate in rows:
            first_seen.setdefault(tuple(sorted(candidate.items())), candidate)
        insert_stmt = sa.insert(table).values(list(first_seen.values()))

    await session.execute(insert_stmt)
|
||
|
||
@staticmethod
def _replace_pivot_rows(session, table, agent_id: str, rows: list[dict]):
    """
    Make ``rows`` the complete set of pivot rows for ``agent_id`` in ``table``.

    Every existing row for the agent is removed with one DELETE, then the new rows
    (if any) are written with one duplicate-tolerant bulk INSERT via ``_bulk_insert_pivot``.
    """
    wipe_existing = delete(table).where(table.c.agent_id == agent_id)
    session.execute(wipe_existing)
    if not rows:
        return
    AgentManager._bulk_insert_pivot(session, table, rows)
|
||
|
||
@staticmethod
async def _replace_pivot_rows_async(session, table, agent_id: str, rows: list[dict]):
    """
    Make ``rows`` the complete set of pivot rows for ``agent_id`` in ``table``.

    On PostgreSQL and SQLite the new rows are written first with a conflict-tolerant
    INSERT. Every row of this agent whose primary key is not among ``rows`` is then
    deleted. Other dialects fall back to DELETE-all followed by a bulk insert.

    All statements run on ``session``. Atomicity therefore comes from the caller's
    transaction, not from this method.

    Args:
        session: async session; the dialect name of its bind selects the strategy.
        table: pivot ``Table`` with an ``agent_id`` column.
        agent_id: agent whose pivot rows are replaced.
        rows: complete desired set of rows. On PostgreSQL/SQLite each dict must contain
            every primary-key column of ``table``; these values are used to find stale rows.
    """
    dialect = session.bind.dialect.name
    delete_all_for_agent = delete(table).where(table.c.agent_id == agent_id)

    if dialect not in ("postgresql", "sqlite"):
        # fallback: original DELETE + INSERT pattern
        await session.execute(delete_all_for_agent)
        if rows:
            await AgentManager._bulk_insert_pivot_async(session, table, rows)
        return

    if not rows:
        # nothing to keep: drop every row for this agent
        await session.execute(delete_all_for_agent)
        return

    if dialect == "postgresql":
        # NOTE(review): DO NOTHING leaves non-key columns of rows that already exist untouched,
        # whereas SQLite's OR REPLACE below overwrites them -- confirm this divergence is intended.
        upsert = pg_insert(table).values(rows).on_conflict_do_nothing()
    else:
        upsert = sa.insert(table).values(rows).prefix_with("OR REPLACE")
    await session.execute(upsert)

    # delete this agent's rows whose primary key is not in the new set
    pk_columns = list(table.primary_key.columns)
    new_keys = [tuple(r[c.name] for c in pk_columns) for r in rows]
    not_in_new_set = ~tuple_(*[table.c[c.name] for c in pk_columns]).in_(new_keys)
    await session.execute(delete(table).where(table.c.agent_id == agent_id, not_in_new_set))
|
||
|
||
# ======================================================================================================================
|
||
# Basic CRUD operations
|
||
# ======================================================================================================================
|
||
@trace_method
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState:
    """
    Synchronously create an agent together with its memory blocks, tool/source/tag/identity
    links, tool-execution environment variables and initial message sequence.

    Args:
        agent_create: Creation request. ``include_base_tool_rules`` is overwritten on this
            object with the resolved True/False value.
        actor: User performing the creation; supplies the organization scope.
        _test_only_force_id: If set, used as the new agent's id (tests only).

    Returns:
        The persisted agent as a pydantic ``AgentState``.

    Raises:
        ValueError: if ``llm_config``/``embedding_config`` are missing, or if a requested
            tool name/id cannot be resolved in the actor's organization.
    """
    # validate required configs
    if not agent_create.llm_config or not agent_create.embedding_config:
        raise ValueError("llm_config and embedding_config are required")

    # blocks
    block_ids = list(agent_create.block_ids or [])
    if agent_create.memory_blocks:
        pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]
        created_blocks = self.block_manager.batch_create_blocks(
            pydantic_blocks,
            actor=actor,
        )
        block_ids.extend([blk.id for blk in created_blocks])

    # tools
    tool_names = set(agent_create.tools or [])
    if agent_create.include_base_tools:
        if agent_create.agent_type == AgentType.voice_sleeptime_agent:
            tool_names |= set(BASE_VOICE_SLEEPTIME_TOOLS)
        elif agent_create.agent_type == AgentType.voice_convo_agent:
            tool_names |= set(BASE_VOICE_SLEEPTIME_CHAT_TOOLS)
        elif agent_create.agent_type == AgentType.sleeptime_agent:
            tool_names |= set(BASE_SLEEPTIME_TOOLS)
        elif agent_create.enable_sleeptime:
            tool_names |= set(BASE_SLEEPTIME_CHAT_TOOLS)
        elif agent_create.agent_type == AgentType.memgpt_v2_agent:
            tool_names |= calculate_base_tools(is_v2=True)
        elif agent_create.agent_type == AgentType.react_agent:
            pass  # no default tools
        elif agent_create.agent_type == AgentType.workflow_agent:
            pass  # no default tools
        else:
            tool_names |= calculate_base_tools(is_v2=False)
    if agent_create.include_multi_agent_tools:
        tool_names |= calculate_multi_agent_tools()

    # Take out deprecated tool names, as create_agent_async does; _resolve_tools raises
    # ValueError for any requested name that is not found in the organization.
    tool_names.difference_update(DEPRECATED_LETTA_TOOLS)

    supplied_ids = set(agent_create.tool_ids or [])

    source_ids = agent_create.source_ids or []
    identity_ids = agent_create.identity_ids or []
    tag_values = agent_create.tags or []

    # Loop-invariant rule classification, hoisted out of the per-tool loop.
    terminal_tool_names = {"send_message", "send_message_to_agent_async", "memory_finish_edits"}
    continue_tool_names = set(BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_SLEEPTIME_TOOLS)

    with db_registry.session() as session:
        with session.begin():
            name_to_id, id_to_name = self._resolve_tools(
                session,
                tool_names,
                supplied_ids,
                actor.organization_id,
            )

            tool_ids = set(name_to_id.values()) | set(id_to_name.keys())
            tool_names = set(name_to_id.keys())  # now canonical

            tool_rules = list(agent_create.tool_rules or [])

            # Override include_base_tool_rules to False if model matches exclusion keywords and include_base_tool_rules is not explicitly set to True
            if (
                (
                    self._should_exclude_model_from_base_tool_rules(agent_create.llm_config.model)
                    and agent_create.include_base_tool_rules is None
                )
                and agent_create.agent_type != AgentType.sleeptime_agent
            ) or agent_create.include_base_tool_rules is False:
                agent_create.include_base_tool_rules = False
                logger.info(f"Overriding include_base_tool_rules to False for model: {agent_create.llm_config.model}")
            else:
                agent_create.include_base_tool_rules = True

            if agent_create.include_base_tool_rules:
                for tn in tool_names:
                    if tn in terminal_tool_names:
                        tool_rules.append(TerminalToolRule(tool_name=tn))
                    elif tn in continue_tool_names:
                        tool_rules.append(ContinueToolRule(tool_name=tn))

            if tool_rules:
                check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules)

            new_agent = AgentModel(
                name=agent_create.name,
                system=derive_system_message(
                    agent_type=agent_create.agent_type,
                    enable_sleeptime=agent_create.enable_sleeptime,
                    system=agent_create.system,
                ),
                hidden=agent_create.hidden,
                agent_type=agent_create.agent_type,
                llm_config=agent_create.llm_config,
                embedding_config=agent_create.embedding_config,
                organization_id=actor.organization_id,
                description=agent_create.description,
                metadata_=agent_create.metadata,
                tool_rules=tool_rules,
                project_id=agent_create.project_id,
                template_id=agent_create.template_id,
                base_template_id=agent_create.base_template_id,
                message_buffer_autoclear=agent_create.message_buffer_autoclear,
                enable_sleeptime=agent_create.enable_sleeptime,
                response_format=agent_create.response_format,
                created_by_id=actor.id,
                last_updated_by_id=actor.id,
                # Same default as create_agent_async when no timezone is supplied.
                timezone=agent_create.timezone if agent_create.timezone else DEFAULT_TIMEZONE,
                max_files_open=agent_create.max_files_open,
                per_file_view_window_char_limit=agent_create.per_file_view_window_char_limit,
            )

            # Set template fields for InternalTemplateAgentCreate (similar to group creation)
            if isinstance(agent_create, InternalTemplateAgentCreate):
                new_agent.base_template_id = agent_create.base_template_id
                new_agent.template_id = agent_create.template_id
                new_agent.deployment_id = agent_create.deployment_id
                new_agent.entity_id = agent_create.entity_id

            if _test_only_force_id:
                new_agent.id = _test_only_force_id

            session.add(new_agent)
            session.flush()
            aid = new_agent.id

            self._bulk_insert_pivot(
                session,
                ToolsAgents.__table__,
                [{"agent_id": aid, "tool_id": tid} for tid in tool_ids],
            )

            if block_ids:
                block_rows = session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(block_ids)))
                self._bulk_insert_pivot(
                    session,
                    BlocksAgents.__table__,
                    [{"agent_id": aid, "block_id": bid, "block_label": lbl} for bid, lbl in block_rows.all()],
                )

            self._bulk_insert_pivot(
                session,
                SourcesAgents.__table__,
                [{"agent_id": aid, "source_id": sid} for sid in source_ids],
            )
            self._bulk_insert_pivot(
                session,
                AgentsTags.__table__,
                [{"agent_id": aid, "tag": tag} for tag in tag_values],
            )
            self._bulk_insert_pivot(
                session,
                IdentitiesAgents.__table__,
                [{"agent_id": aid, "identity_id": iid} for iid in identity_ids],
            )

            if agent_create.tool_exec_environment_variables:
                env_rows = [
                    {
                        "agent_id": aid,
                        "key": key,
                        "value": val,
                        "organization_id": actor.organization_id,
                    }
                    for key, val in agent_create.tool_exec_environment_variables.items()
                ]
                session.execute(insert(AgentEnvironmentVariable).values(env_rows))

            # initial message sequence
            init_messages = self._generate_initial_message_sequence(
                actor,
                agent_state=new_agent.to_pydantic(include_relationships={"memory"}),
                supplied_initial_message_sequence=agent_create.initial_message_sequence,
            )
            new_agent.message_ids = [msg.id for msg in init_messages]

        # Refresh only after the transaction block above has committed. Session.refresh() expires
        # the instance before reloading it, which would discard the still-unflushed message_ids.
        session.refresh(new_agent)

        # Messages go through the message manager, which is not handed this session
        # (presumably it opens its own), after the agent row has been committed.
        self.message_manager.create_many_messages(pydantic_msgs=init_messages, actor=actor)
        return new_agent.to_pydantic()
|
||
|
||
@trace_method
async def create_agent_async(
    self,
    agent_create: CreateAgent,
    actor: PydanticUser,
    _test_only_force_id: Optional[str] = None,
    _init_with_no_messages: bool = False,
) -> PydanticAgentState:
    """
    Create an agent together with its memory blocks, tool/source/tag/identity links,
    tool-execution environment variables and (optionally) initial message sequence.

    Args:
        agent_create: Creation request. Some fields are rewritten on this object:
            ``llm_config`` (when ``reasoning`` is set), ``include_base_tool_rules``
            (resolved to True/False) and ``message_buffer_autoclear`` (forced True for
            workflow agents). ``source_ids`` is NOT modified.
        actor: User performing the creation; supplies the organization scope.
        _test_only_force_id: If set, used as the new agent's id (tests only).
        _init_with_no_messages: If True, no initial message sequence is generated or stored.

    Returns:
        The new agent as a pydantic ``AgentState``.

    Raises:
        ValueError: if ``llm_config``/``embedding_config`` are missing, or if a requested
            tool name/id cannot be resolved in the actor's organization.
    """
    # validate required configs
    if not agent_create.llm_config or not agent_create.embedding_config:
        raise ValueError("llm_config and embedding_config are required")

    if agent_create.reasoning is not None:
        agent_create.llm_config = LLMConfig.apply_reasoning_setting_to_config(agent_create.llm_config, agent_create.reasoning)

    # blocks
    block_ids = list(agent_create.block_ids or [])
    if agent_create.memory_blocks:
        pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]

        # Inject a description for the default blocks (`persona`, `human`, etc.) if the user didn't specify one
        default_blocks = {block.label: block for block in DEFAULT_BLOCKS}
        for block in pydantic_blocks:
            if block.label in default_blocks and block.description is None:
                block.description = default_blocks[block.label].description

        created_blocks = await self.block_manager.batch_create_blocks_async(
            pydantic_blocks,
            actor=actor,
        )
        block_ids.extend([blk.id for blk in created_blocks])

    # tools
    tool_names = set(agent_create.tools or [])
    if agent_create.include_base_tools:
        if agent_create.agent_type == AgentType.voice_sleeptime_agent:
            tool_names |= set(BASE_VOICE_SLEEPTIME_TOOLS)
        elif agent_create.agent_type == AgentType.voice_convo_agent:
            tool_names |= set(BASE_VOICE_SLEEPTIME_CHAT_TOOLS)
        elif agent_create.agent_type == AgentType.sleeptime_agent:
            tool_names |= set(BASE_SLEEPTIME_TOOLS)
        elif agent_create.enable_sleeptime:
            tool_names |= set(BASE_SLEEPTIME_CHAT_TOOLS)
        elif agent_create.agent_type == AgentType.memgpt_v2_agent:
            tool_names |= calculate_base_tools(is_v2=True)
        elif agent_create.agent_type == AgentType.react_agent:
            pass  # no default tools
        elif agent_create.agent_type == AgentType.workflow_agent:
            pass  # no default tools
        else:
            tool_names |= calculate_base_tools(is_v2=False)
    if agent_create.include_multi_agent_tools:
        tool_names |= calculate_multi_agent_tools()

    # take out the deprecated tool names
    tool_names.difference_update(set(DEPRECATED_LETTA_TOOLS))

    supplied_ids = set(agent_create.tool_ids or [])

    # Copy, so that appending the default source below never mutates agent_create.source_ids.
    source_ids = list(agent_create.source_ids or [])

    # Create default source if requested
    if agent_create.include_default_source:
        default_source = PydanticSource(
            name=f"{agent_create.name} External Data Source",
            embedding_config=agent_create.embedding_config,
        )
        created_source = await self.source_manager.create_source(default_source, actor)
        source_ids.append(created_source.id)

    identity_ids = agent_create.identity_ids or []
    tag_values = agent_create.tags or []

    # if the agent type is workflow, we set the autoclear to forced true
    if agent_create.agent_type == AgentType.workflow_agent:
        agent_create.message_buffer_autoclear = True

    # Loop-invariant rule classification, hoisted out of the per-tool loop.
    terminal_tool_names = {"send_message", "send_message_to_agent_async", "memory_finish_edits"}
    continue_tool_names = set(BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_SLEEPTIME_TOOLS)

    async with db_registry.async_session() as session:
        async with session.begin():
            name_to_id, id_to_name, requires_approval = await self._resolve_tools_async(
                session,
                tool_names,
                supplied_ids,
                actor.organization_id,
            )

            tool_ids = set(name_to_id.values()) | set(id_to_name.keys())
            tool_names = set(name_to_id.keys())  # now canonical
            tool_rules = list(agent_create.tool_rules or [])

            # Override include_base_tool_rules to False if model matches exclusion keywords and include_base_tool_rules is not explicitly set to True
            if (
                (
                    self._should_exclude_model_from_base_tool_rules(agent_create.llm_config.model)
                    and agent_create.include_base_tool_rules is None
                )
                and agent_create.agent_type != AgentType.sleeptime_agent
            ) or agent_create.include_base_tool_rules is False:
                agent_create.include_base_tool_rules = False
                logger.info(f"Overriding include_base_tool_rules to False for model: {agent_create.llm_config.model}")
            else:
                agent_create.include_base_tool_rules = True

            if agent_create.include_base_tool_rules:
                for tn in tool_names:
                    if tn in terminal_tool_names:
                        tool_rules.append(TerminalToolRule(tool_name=tn))
                    elif tn in continue_tool_names:
                        tool_rules.append(ContinueToolRule(tool_name=tn))

            # tools flagged default_requires_approval in the DB get an approval rule
            for tool_with_requires_approval in requires_approval:
                tool_rules.append(RequiresApprovalToolRule(tool_name=tool_with_requires_approval))

            if tool_rules:
                check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules)

            new_agent = AgentModel(
                name=agent_create.name,
                system=derive_system_message(
                    agent_type=agent_create.agent_type,
                    enable_sleeptime=agent_create.enable_sleeptime,
                    system=agent_create.system,
                ),
                agent_type=agent_create.agent_type,
                llm_config=agent_create.llm_config,
                embedding_config=agent_create.embedding_config,
                organization_id=actor.organization_id,
                description=agent_create.description,
                metadata_=agent_create.metadata,
                tool_rules=tool_rules,
                hidden=agent_create.hidden,
                project_id=agent_create.project_id,
                template_id=agent_create.template_id,
                base_template_id=agent_create.base_template_id,
                message_buffer_autoclear=agent_create.message_buffer_autoclear,
                enable_sleeptime=agent_create.enable_sleeptime,
                response_format=agent_create.response_format,
                created_by_id=actor.id,
                last_updated_by_id=actor.id,
                timezone=agent_create.timezone if agent_create.timezone else DEFAULT_TIMEZONE,
                max_files_open=agent_create.max_files_open,
                per_file_view_window_char_limit=agent_create.per_file_view_window_char_limit,
            )

            # Set template fields for InternalTemplateAgentCreate (similar to group creation)
            if isinstance(agent_create, InternalTemplateAgentCreate):
                new_agent.base_template_id = agent_create.base_template_id
                new_agent.template_id = agent_create.template_id
                new_agent.deployment_id = agent_create.deployment_id
                new_agent.entity_id = agent_create.entity_id

            if _test_only_force_id:
                new_agent.id = _test_only_force_id

            session.add(new_agent)
            await session.flush()
            aid = new_agent.id

            await self._bulk_insert_pivot_async(
                session,
                ToolsAgents.__table__,
                [{"agent_id": aid, "tool_id": tid} for tid in tool_ids],
            )

            if block_ids:
                block_rows = await session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(block_ids)))
                await self._bulk_insert_pivot_async(
                    session,
                    BlocksAgents.__table__,
                    [{"agent_id": aid, "block_id": bid, "block_label": lbl} for bid, lbl in block_rows.all()],
                )

            await self._bulk_insert_pivot_async(
                session,
                SourcesAgents.__table__,
                [{"agent_id": aid, "source_id": sid} for sid in source_ids],
            )
            await self._bulk_insert_pivot_async(
                session,
                AgentsTags.__table__,
                [{"agent_id": aid, "tag": tag} for tag in tag_values],
            )
            await self._bulk_insert_pivot_async(
                session,
                IdentitiesAgents.__table__,
                [{"agent_id": aid, "identity_id": iid} for iid in identity_ids],
            )

            env_rows = []
            if agent_create.tool_exec_environment_variables:
                env_rows = [
                    {
                        "agent_id": aid,
                        "key": key,
                        "value": val,
                        "organization_id": actor.organization_id,
                    }
                    for key, val in agent_create.tool_exec_environment_variables.items()
                ]
                inserted = await session.execute(insert(AgentEnvironmentVariable).values(env_rows).returning(AgentEnvironmentVariable.id))
                # attach the generated ids to the rows, in insertion order
                env_rows = [{**row, "id": env_var_id} for row, env_var_id in zip(env_rows, inserted.scalars().all())]

            # only eager-load the relationships that were actually populated above
            include_relationships = []
            if tool_ids:
                include_relationships.append("tools")
            if source_ids:
                include_relationships.append("sources")
            if block_ids:
                include_relationships.append("memory")
            if identity_ids:
                include_relationships.append("identity_ids")
            if tag_values:
                include_relationships.append("tags")

            agent_state = await new_agent.to_pydantic_async(include_relationships=include_relationships)

            if env_rows:
                # NOTE(review): AgentEnvironmentVariable is the ORM class imported from letta.orm.sandbox_config;
                # confirm the AgentState field is meant to hold ORM instances rather than the pydantic schema.
                agent_state.tool_exec_environment_variables = [AgentEnvironmentVariable(**row) for row in env_rows]

            # initial message sequence (skip if _init_with_no_messages is True)
            init_messages = []
            if not _init_with_no_messages:
                init_messages = await self._generate_initial_message_sequence_async(
                    actor,
                    agent_state=agent_state,
                    supplied_initial_message_sequence=agent_create.initial_message_sequence,
                )
                message_ids = [msg.id for msg in init_messages]
                agent_state.message_ids = message_ids
                new_agent.message_ids = list(message_ids)
                await new_agent.update_async(session, no_refresh=True)

    # Only create messages if we initialized with messages; this runs after the agent
    # transaction block above has exited.
    if not _init_with_no_messages:
        await self.message_manager.create_many_messages_async(
            pydantic_msgs=init_messages, actor=actor, embedding_config=agent_state.embedding_config
        )
    return agent_state
|
||
|
||
@enforce_types
def _generate_initial_message_sequence(
    self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
) -> List[Message]:
    """
    Build the opening in-context messages for ``agent_state``.

    The default sequence from ``initialize_message_sequence`` (including the initial boot
    message) is always generated first. When ``supplied_initial_message_sequence`` is None,
    every generated entry is converted to a ``Message``. Otherwise only the first generated
    entry (the system prompt) is kept, followed by the caller's messages as packaged by
    ``package_initial_message_sequence``.
    """
    default_dicts = initialize_message_sequence(
        agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
    )
    model_name = agent_state.llm_config.model

    if supplied_initial_message_sequence is None:
        return [
            PydanticMessage.dict_to_message(agent_id=agent_state.id, model=model_name, openai_message_dict=entry)
            for entry in default_dicts
        ]

    # The system prompt always leads; the rest of the generated sequence is discarded.
    system_prompt = PydanticMessage.dict_to_message(
        agent_id=agent_state.id,
        model=model_name,
        openai_message_dict=default_dicts[0],
    )
    caller_messages = package_initial_message_sequence(
        agent_state.id, supplied_initial_message_sequence, model_name, agent_state.timezone, actor
    )
    return [system_prompt, *caller_messages]
|
||
|
||
@enforce_types
async def _generate_initial_message_sequence_async(
    self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
) -> List[Message]:
    """
    Async counterpart of ``_generate_initial_message_sequence``.

    The default sequence is produced by ``initialize_message_sequence_async`` (including the
    initial boot message). With no supplied sequence, every generated entry is converted to a
    ``Message``. Otherwise the system prompt (first generated entry) is kept and followed by
    the caller's messages as packaged by ``package_initial_message_sequence``.
    """
    default_dicts = await initialize_message_sequence_async(
        agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
    )
    model_name = agent_state.llm_config.model

    if supplied_initial_message_sequence is None:
        return [
            PydanticMessage.dict_to_message(agent_id=agent_state.id, model=model_name, openai_message_dict=entry)
            for entry in default_dicts
        ]

    # The system prompt always leads; the rest of the generated sequence is discarded.
    system_prompt = PydanticMessage.dict_to_message(
        agent_id=agent_state.id,
        model=model_name,
        openai_message_dict=default_dicts[0],
    )
    caller_messages = package_initial_message_sequence(
        agent_state.id, supplied_initial_message_sequence, model_name, agent_state.timezone, actor
    )
    return [system_prompt, *caller_messages]
|
||
|
||
@enforce_types
@trace_method
def append_initial_message_sequence_to_in_context_messages(
    self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
) -> PydanticAgentState:
    """
    Generate an initial message sequence for ``agent_state`` and append it to the agent's
    in-context messages.

    Returns whatever ``append_to_in_context_messages`` returns for the agent.
    """
    messages_to_append = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
    updated_state = self.append_to_in_context_messages(messages_to_append, agent_id=agent_state.id, actor=actor)
    return updated_state
|
||
|
||
@enforce_types
@trace_method
async def append_initial_message_sequence_to_in_context_messages_async(
    self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
) -> PydanticAgentState:
    """Async variant: generate the initial message sequence and append it to the agent's context.

    Args:
        actor: User performing the operation.
        agent_state: Agent whose in-context messages are extended.
        initial_message_sequence: Optional caller-supplied messages that replace
            the default pre-generated messages (the system prompt is always kept).

    Returns:
        The agent state returned by ``append_to_in_context_messages_async``.
    """
    messages_to_append = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence)
    updated_state = await self.append_to_in_context_messages_async(messages_to_append, agent_id=agent_state.id, actor=actor)
    return updated_state
|
||
|
||
@enforce_types
@trace_method
def update_agent(
    self,
    agent_id: str,
    agent_update: UpdateAgent,
    actor: PydanticUser,
) -> PydanticAgentState:
    """Apply a partial update to an agent and return its refreshed state.

    Scalar fields left as ``None`` on ``agent_update`` are not touched. For each
    relationship list (tools, sources, blocks, identities, tags), a non-``None``
    value is handed to ``_replace_pivot_rows`` to rewrite the agent's pivot rows.
    A non-``None`` ``tool_exec_environment_variables`` deletes all of the agent's
    existing variables and inserts the given ones. All writes run inside a single
    ``session.begin()`` transaction.

    Args:
        agent_id: ID of the agent to update.
        agent_update: The fields to change.
        actor: User performing the update; used for access checks and audit columns.

    Returns:
        The updated agent as a Pydantic model.

    Raises:
        NoResultFound: If any ID in ``agent_update.block_ids`` does not name a
            block in the actor's organization. Presumably also raised by
            ``AgentModel.read`` when the agent is not visible to ``actor``.
    """
    new_tools = set(agent_update.tool_ids or [])
    new_sources = set(agent_update.source_ids or [])
    new_blocks = set(agent_update.block_ids or [])
    new_idents = set(agent_update.identity_ids or [])
    new_tags = set(agent_update.tags or [])

    with db_registry.session() as session, session.begin():
        agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        agent.updated_at = datetime.now(timezone.utc)
        agent.last_updated_by_id = actor.id

        # Fold the reasoning toggle into the LLM config, as update_agent_async does;
        # previously the sync path silently ignored ``agent_update.reasoning``.
        # A local is used so the caller's ``agent_update`` is not mutated.
        llm_config = agent_update.llm_config
        if agent_update.reasoning is not None:
            llm_config = LLMConfig.apply_reasoning_setting_to_config(llm_config or agent.llm_config, agent_update.reasoning)

        # Scalar columns: ``None`` means "leave unchanged".
        scalar_updates = {
            "name": agent_update.name,
            "system": agent_update.system,
            "llm_config": llm_config,
            "embedding_config": agent_update.embedding_config,
            "message_ids": agent_update.message_ids,
            "tool_rules": agent_update.tool_rules,
            "description": agent_update.description,
            "project_id": agent_update.project_id,
            "template_id": agent_update.template_id,
            "base_template_id": agent_update.base_template_id,
            "message_buffer_autoclear": agent_update.message_buffer_autoclear,
            "enable_sleeptime": agent_update.enable_sleeptime,
            "response_format": agent_update.response_format,
            "last_run_completion": agent_update.last_run_completion,
            "last_run_duration_ms": agent_update.last_run_duration_ms,
            "max_files_open": agent_update.max_files_open,
            "per_file_view_window_char_limit": agent_update.per_file_view_window_char_limit,
            "timezone": agent_update.timezone,
        }
        for col, val in scalar_updates.items():
            if val is not None:
                setattr(agent, col, val)

        if agent_update.metadata is not None:
            agent.metadata_ = agent_update.metadata

        aid = agent.id

        # After each pivot rewrite, expire the cached relationship so it is reloaded.
        if agent_update.tool_ids is not None:
            self._replace_pivot_rows(
                session,
                ToolsAgents.__table__,
                aid,
                [{"agent_id": aid, "tool_id": tid} for tid in new_tools],
            )
            session.expire(agent, ["tools"])

        if agent_update.source_ids is not None:
            self._replace_pivot_rows(
                session,
                SourcesAgents.__table__,
                aid,
                [{"agent_id": aid, "source_id": sid} for sid in new_sources],
            )
            session.expire(agent, ["sources"])

        if agent_update.block_ids is not None:
            rows = []
            if new_blocks:
                # Scope the lookup to the actor's organization so blocks owned by another
                # tenant cannot be attached by ID.
                label_map = {
                    bid: lbl
                    for bid, lbl in session.execute(
                        select(BlockModel.id, BlockModel.label).where(
                            BlockModel.id.in_(new_blocks),
                            BlockModel.organization_id == actor.organization_id,
                        )
                    )
                }
                # Previously an unknown ID surfaced as a bare KeyError.
                missing = new_blocks.difference(label_map)
                if missing:
                    raise NoResultFound(f"Block(s) not found: {sorted(missing)}")
                rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks]

            self._replace_pivot_rows(session, BlocksAgents.__table__, aid, rows)
            session.expire(agent, ["core_memory"])

        if agent_update.identity_ids is not None:
            self._replace_pivot_rows(
                session,
                IdentitiesAgents.__table__,
                aid,
                [{"agent_id": aid, "identity_id": iid} for iid in new_idents],
            )
            session.expire(agent, ["identities"])

        if agent_update.tags is not None:
            self._replace_pivot_rows(
                session,
                AgentsTags.__table__,
                aid,
                [{"agent_id": aid, "tag": tag} for tag in new_tags],
            )
            session.expire(agent, ["tags"])

        if agent_update.tool_exec_environment_variables is not None:
            # Full replacement: delete every existing variable, then insert the new set.
            session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
            env_rows = [
                {
                    "agent_id": aid,
                    "key": k,
                    "value": v,
                    "organization_id": agent.organization_id,
                }
                for k, v in agent_update.tool_exec_environment_variables.items()
            ]
            if env_rows:
                self._bulk_insert_pivot(session, AgentEnvironmentVariable.__table__, env_rows)
            session.expire(agent, ["tool_exec_environment_variables"])

        # Enabling sleeptime without an explicit system prompt re-derives the prompt.
        if agent_update.enable_sleeptime and agent_update.system is None:
            agent.system = derive_system_message(
                agent_type=agent.agent_type,
                enable_sleeptime=agent_update.enable_sleeptime,
                system=agent.system,
            )

        session.flush()
        session.refresh(agent)

        return agent.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def update_agent_async(
    self,
    agent_id: str,
    agent_update: UpdateAgent,
    actor: PydanticUser,
) -> PydanticAgentState:
    """Async variant: apply a partial update to an agent and return its refreshed state.

    Scalar fields left as ``None`` on ``agent_update`` are not touched. A
    non-``None`` ``reasoning`` is applied to the LLM config via
    ``LLMConfig.apply_reasoning_setting_to_config``. For each relationship list
    (tools, sources, blocks, identities, tags), a non-``None`` value is handed to
    ``_replace_pivot_rows_async`` to rewrite the agent's pivot rows. A
    non-``None`` ``tool_exec_environment_variables`` deletes all of the agent's
    existing variables and inserts the given ones. All writes run inside a
    single ``session.begin()`` transaction.

    Args:
        agent_id: ID of the agent to update.
        agent_update: The fields to change. It is not mutated.
        actor: User performing the update; used for access checks and audit columns.

    Returns:
        The updated agent as a Pydantic model.

    Raises:
        NoResultFound: If any ID in ``agent_update.block_ids`` does not name a
            block in the actor's organization. Presumably also raised by
            ``AgentModel.read_async`` when the agent is not visible to ``actor``.
    """
    new_tools = set(agent_update.tool_ids or [])
    new_sources = set(agent_update.source_ids or [])
    new_blocks = set(agent_update.block_ids or [])
    new_idents = set(agent_update.identity_ids or [])
    new_tags = set(agent_update.tags or [])

    async with db_registry.async_session() as session, session.begin():
        agent: AgentModel = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
        agent.updated_at = datetime.now(timezone.utc)
        agent.last_updated_by_id = actor.id

        # Resolve the effective LLM config in a local. The previous code wrote the
        # result back into ``agent_update.llm_config``, mutating the caller's object.
        llm_config = agent_update.llm_config
        if agent_update.reasoning is not None:
            llm_config = LLMConfig.apply_reasoning_setting_to_config(llm_config or agent.llm_config, agent_update.reasoning)

        # Scalar columns: ``None`` means "leave unchanged".
        scalar_updates = {
            "name": agent_update.name,
            "system": agent_update.system,
            "llm_config": llm_config,
            "embedding_config": agent_update.embedding_config,
            "message_ids": agent_update.message_ids,
            "tool_rules": agent_update.tool_rules,
            "description": agent_update.description,
            "project_id": agent_update.project_id,
            "template_id": agent_update.template_id,
            "base_template_id": agent_update.base_template_id,
            "message_buffer_autoclear": agent_update.message_buffer_autoclear,
            "enable_sleeptime": agent_update.enable_sleeptime,
            "response_format": agent_update.response_format,
            "last_run_completion": agent_update.last_run_completion,
            "last_run_duration_ms": agent_update.last_run_duration_ms,
            "timezone": agent_update.timezone,
            "max_files_open": agent_update.max_files_open,
            "per_file_view_window_char_limit": agent_update.per_file_view_window_char_limit,
        }
        for col, val in scalar_updates.items():
            if val is not None:
                setattr(agent, col, val)

        if agent_update.metadata is not None:
            agent.metadata_ = agent_update.metadata

        aid = agent.id

        # After each pivot rewrite, expire the cached relationship so it is reloaded.
        if agent_update.tool_ids is not None:
            await self._replace_pivot_rows_async(
                session,
                ToolsAgents.__table__,
                aid,
                [{"agent_id": aid, "tool_id": tid} for tid in new_tools],
            )
            session.expire(agent, ["tools"])

        if agent_update.source_ids is not None:
            await self._replace_pivot_rows_async(
                session,
                SourcesAgents.__table__,
                aid,
                [{"agent_id": aid, "source_id": sid} for sid in new_sources],
            )
            session.expire(agent, ["sources"])

        if agent_update.block_ids is not None:
            rows = []
            if new_blocks:
                # Scope the lookup to the actor's organization so blocks owned by another
                # tenant cannot be attached by ID.
                result = await session.execute(
                    select(BlockModel.id, BlockModel.label).where(
                        BlockModel.id.in_(new_blocks),
                        BlockModel.organization_id == actor.organization_id,
                    )
                )
                label_map = {bid: lbl for bid, lbl in result.all()}
                # Previously an unknown ID surfaced as a bare KeyError.
                missing = new_blocks.difference(label_map)
                if missing:
                    raise NoResultFound(f"Block(s) not found: {sorted(missing)}")
                rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks]

            await self._replace_pivot_rows_async(session, BlocksAgents.__table__, aid, rows)
            session.expire(agent, ["core_memory"])

        if agent_update.identity_ids is not None:
            await self._replace_pivot_rows_async(
                session,
                IdentitiesAgents.__table__,
                aid,
                [{"agent_id": aid, "identity_id": iid} for iid in new_idents],
            )
            session.expire(agent, ["identities"])

        if agent_update.tags is not None:
            await self._replace_pivot_rows_async(
                session,
                AgentsTags.__table__,
                aid,
                [{"agent_id": aid, "tag": tag} for tag in new_tags],
            )
            session.expire(agent, ["tags"])

        if agent_update.tool_exec_environment_variables is not None:
            # Full replacement: delete every existing variable, then insert the new set.
            await session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
            env_rows = [
                {
                    "agent_id": aid,
                    "key": k,
                    "value": v,
                    "organization_id": agent.organization_id,
                }
                for k, v in agent_update.tool_exec_environment_variables.items()
            ]
            if env_rows:
                await self._bulk_insert_pivot_async(session, AgentEnvironmentVariable.__table__, env_rows)
            session.expire(agent, ["tool_exec_environment_variables"])

        # Enabling sleeptime without an explicit system prompt re-derives the prompt.
        if agent_update.enable_sleeptime and agent_update.system is None:
            agent.system = derive_system_message(
                agent_type=agent.agent_type,
                enable_sleeptime=agent_update.enable_sleeptime,
                system=agent.system,
            )

        await session.flush()
        await session.refresh(agent)

        return await agent.to_pydantic_async()
|
||
|
||
@enforce_types
@trace_method
async def update_message_ids_async(
    self,
    agent_id: str,
    message_ids: List[str],
    actor: PydanticUser,
) -> None:
    """Replace an agent's in-context message ID list and commit.

    Also stamps ``updated_at`` and ``last_updated_by_id``.

    Args:
        agent_id: ID of the agent to update.
        message_ids: The new ordered list of in-context message IDs.
        actor: User performing the update; used for the access predicate and audit column.

    Raises:
        NoResultFound: If no agent with ``agent_id`` is visible to ``actor``.
            Previously this case crashed with ``AttributeError`` on ``None``.
    """
    async with db_registry.async_session() as session:
        query = select(AgentModel)
        query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
        query = query.where(AgentModel.id == agent_id)
        # include_relationships=[]: presumably skips eager-loading relationships,
        # since only scalar columns are written here.
        query = _apply_relationship_filters(query, include_relationships=[])

        result = await session.execute(query)
        agent = result.scalar_one_or_none()
        if agent is None:
            raise NoResultFound(f"Agent with ID {agent_id} not found")

        agent.updated_at = datetime.now(timezone.utc)
        agent.last_updated_by_id = actor.id
        agent.message_ids = message_ids

        await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
        await session.commit()
|
||
|
||
# TODO: Make this general and think about how to roll this into sqlalchemybase
@trace_method
def list_agents(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    base_template_id: Optional[str] = None,
    identity_id: Optional[str] = None,
    identifier_keys: Optional[List[str]] = None,
    include_relationships: Optional[List[str]] = None,
    ascending: bool = True,
    sort_by: Optional[str] = "created_at",
) -> List[PydanticAgentState]:
    """List agents visible to ``actor``, with filtering and cursor pagination.

    Args:
        actor: The user requesting the list; only agents readable by this user's
            organization are returned.
        name: Filter by agent name.
        tags: Filter agents by tags.
        match_all_tags: If True, only return agents that match ALL given tags.
        before: Pagination cursor (return items before this one).
        after: Pagination cursor (return items after this one).
        limit: Maximum number of agents to return; a falsy value means no limit.
        query_text: Search agents by name.
        project_id: Filter by project ID.
        template_id: Filter by template ID.
        base_template_id: Filter by base template ID.
        identity_id: Filter by identity ID.
        identifier_keys: Filter agents by identifier keys.
        include_relationships: Relationships to include when converting to Pydantic.
        ascending: Sort direction used by the pagination helper.
        sort_by: Column the pagination helper sorts by.

    Returns:
        The matching agents as Pydantic models.
    """
    with db_registry.session() as session:
        # NOTE(review): ``.distinct(col, ...)`` renders as PostgreSQL DISTINCT ON;
        # presumably there to de-duplicate rows produced by tag/identity joins.
        stmt = AgentModel.apply_access_predicate(
            select(AgentModel).distinct(AgentModel.created_at, AgentModel.id),
            actor,
            ["read"],
            AccessType.ORGANIZATION,
        )

        stmt = _apply_filters(stmt, name, query_text, project_id, template_id, base_template_id)
        stmt = _apply_identity_filters(stmt, identity_id, identifier_keys)
        stmt = _apply_tag_filter(stmt, tags, match_all_tags)
        stmt = _apply_pagination(stmt, before, after, session, ascending=ascending, sort_by=sort_by)

        if limit:
            stmt = stmt.limit(limit)

        rows = session.execute(stmt).scalars().all()
        return [row.to_pydantic(include_relationships=include_relationships) for row in rows]
|
||
|
||
@trace_method
async def list_agents_async(
    self,
    actor: PydanticUser,
    name: Optional[str] = None,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    project_id: Optional[str] = None,
    template_id: Optional[str] = None,
    base_template_id: Optional[str] = None,
    identity_id: Optional[str] = None,
    identifier_keys: Optional[List[str]] = None,
    include_relationships: Optional[List[str]] = None,
    ascending: bool = True,
    sort_by: Optional[str] = "created_at",
    show_hidden_agents: Optional[bool] = None,
) -> List[PydanticAgentState]:
    """Async variant: list agents visible to ``actor``, with filtering and cursor pagination.

    Args:
        actor: The user requesting the list; only agents readable by this user's
            organization are returned.
        name: Filter by agent name.
        tags: Filter agents by tags.
        match_all_tags: If True, only return agents that match ALL given tags.
        before: Pagination cursor (return items before this one).
        after: Pagination cursor (return items after this one).
        limit: Maximum number of agents to return; a falsy value means no limit.
        query_text: Search agents by name.
        project_id: Filter by project ID.
        template_id: Filter by template ID.
        base_template_id: Filter by base template ID.
        identity_id: Filter by identity ID.
        identifier_keys: Filter agents by identifier keys.
        include_relationships: Relationships to load and include in the result.
        ascending: Sort direction used by the pagination helper.
        sort_by: Column the pagination helper sorts by.
        show_hidden_agents: If truthy, agents flagged ``hidden`` are included;
            otherwise only agents whose ``hidden`` is NULL or False are returned.

    Returns:
        The matching agents as Pydantic models, converted concurrently.
    """
    async with db_registry.async_session() as session:
        stmt = AgentModel.apply_access_predicate(select(AgentModel), actor, ["read"], AccessType.ORGANIZATION)

        stmt = _apply_filters(stmt, name, query_text, project_id, template_id, base_template_id)
        stmt = _apply_identity_filters(stmt, identity_id, identifier_keys)
        stmt = _apply_tag_filter(stmt, tags, match_all_tags)
        stmt = _apply_relationship_filters(stmt, include_relationships)

        if not show_hidden_agents:
            # A NULL ``hidden`` flag counts as visible.
            stmt = stmt.where((AgentModel.hidden.is_(None)) | (AgentModel.hidden == False))  # noqa: E712

        stmt = await _apply_pagination_async(stmt, before, after, session, ascending=ascending, sort_by=sort_by)

        if limit:
            stmt = stmt.limit(limit)

        rows = (await session.execute(stmt)).scalars().all()
        pending = [row.to_pydantic_async(include_relationships=include_relationships) for row in rows]
        return await asyncio.gather(*pending)
|
||
|
||
@enforce_types
@trace_method
def list_agents_matching_tags(
    self,
    actor: PydanticUser,
    match_all: List[str],
    match_some: List[str],
    limit: Optional[int] = 50,
) -> List[PydanticAgentState]:
    """
    Retrieves agents in the same organization that match all specified `match_all` tags
    and at least one tag from `match_some`. The query is optimized for efficiency by
    leveraging indexed filtering and aggregation.

    Args:
        actor (PydanticUser): The user requesting the agent list.
        match_all (List[str]): Agents must have all these tags. Duplicates are ignored.
        match_some (List[str]): Agents must have at least one of these tags.
        limit (Optional[int]): Maximum number of agents to return.

    Returns:
        List[PydanticAgentState]: The filtered list of matching agents.
    """
    # De-duplicate (order-preserving): the HAVING clause compares the per-agent
    # match count against this length, so duplicates made the filter unsatisfiable.
    required_tags = list(dict.fromkeys(match_all or []))

    with db_registry.session() as session:
        query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)

        if required_tags:
            # Subquery to find agent IDs that contain all required tags
            subquery = (
                select(AgentsTags.agent_id)
                .where(AgentsTags.tag.in_(required_tags))
                .group_by(AgentsTags.agent_id)
                .having(func.count(AgentsTags.tag) == literal(len(required_tags)))
            )
            query = query.where(AgentModel.id.in_(subquery))

        if match_some:
            # Ensures agents match at least one tag in match_some
            query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))

        query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)

        # Convert while the session is open. The previous version returned raw ORM
        # rows, not the annotated PydanticAgentState, detached after the session closed.
        return [agent.to_pydantic() for agent in session.execute(query).scalars()]
|
||
|
||
@enforce_types
@trace_method
async def list_agents_matching_tags_async(
    self,
    actor: PydanticUser,
    match_all: List[str],
    match_some: List[str],
    limit: Optional[int] = 50,
) -> List[PydanticAgentState]:
    """
    Retrieves agents in the same organization that match all specified `match_all` tags
    and at least one tag from `match_some`. The query is optimized for efficiency by
    leveraging indexed filtering and aggregation.

    Args:
        actor (PydanticUser): The user requesting the agent list.
        match_all (List[str]): Agents must have all these tags. Duplicates are ignored.
        match_some (List[str]): Agents must have at least one of these tags.
        limit (Optional[int]): Maximum number of agents to return.

    Returns:
        List[PydanticAgentState]: The filtered list of matching agents.
    """
    # De-duplicate (order-preserving): the HAVING clause compares the per-agent
    # match count against this length, so duplicates made the filter unsatisfiable.
    required_tags = list(dict.fromkeys(match_all or []))

    async with db_registry.async_session() as session:
        query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)

        if required_tags:
            # Subquery to find agent IDs that contain all required tags
            subquery = (
                select(AgentsTags.agent_id)
                .where(AgentsTags.tag.in_(required_tags))
                .group_by(AgentsTags.agent_id)
                .having(func.count(AgentsTags.tag) == literal(len(required_tags)))
            )
            query = query.where(AgentModel.id.in_(subquery))

        if match_some:
            # Ensures agents match at least one tag in match_some
            query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))

        query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
        result = await session.execute(query)
        return await asyncio.gather(*[agent.to_pydantic_async() for agent in result.scalars()])
|
||
|
||
@trace_method
def size(
    self,
    actor: PydanticUser,
) -> int:
    """Return the agent count reported by ``AgentModel.size`` for ``actor``."""
    with db_registry.session() as session:
        total = AgentModel.size(db_session=session, actor=actor)
    return total
|
||
|
||
@trace_method
async def size_async(
    self,
    actor: PydanticUser,
) -> int:
    """Async variant: return the agent count reported by ``AgentModel.size_async`` for ``actor``."""
    async with db_registry.async_session() as session:
        total = await AgentModel.size_async(db_session=session, actor=actor)
    return total
|
||
|
||
@enforce_types
@trace_method
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Fetch an agent by its ID and return it as a Pydantic model.

    The lookup goes through ``AgentModel.read`` with ``actor``; conversion
    happens while the session is still open.
    """
    with db_registry.session() as session:
        agent_row = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        return agent_row.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def get_agent_by_id_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    include_relationships: Optional[List[str]] = None,
) -> PydanticAgentState:
    """Fetch an agent by its ID (async).

    Args:
        agent_id: ID of the agent to load.
        actor: User whose organization must be able to read the agent.
        include_relationships: Relationships to load and include in the result.

    Returns:
        The agent as a Pydantic model.

    Raises:
        NoResultFound: If no readable agent has ``agent_id``. Every error,
            including this one, is logged before being re-raised.
    """
    async with db_registry.async_session() as session:
        try:
            stmt = AgentModel.apply_access_predicate(select(AgentModel), actor, ["read"], AccessType.ORGANIZATION)
            stmt = stmt.where(AgentModel.id == agent_id)
            stmt = _apply_relationship_filters(stmt, include_relationships)

            found = (await session.execute(stmt)).scalar_one_or_none()
            if found is None:
                raise NoResultFound(f"Agent with ID {agent_id} not found")

            return await found.to_pydantic_async(include_relationships=include_relationships)
        except Exception as e:
            logger.error(f"Error fetching agent {agent_id}: {str(e)}")
            raise
|
||
|
||
@enforce_types
@trace_method
async def get_agents_by_ids_async(
    self,
    agent_ids: list[str],
    actor: PydanticUser,
    include_relationships: Optional[List[str]] = None,
) -> list[PydanticAgentState]:
    """Fetch the readable agents among ``agent_ids`` (async).

    IDs that are missing or not readable by ``actor`` are silently omitted. An
    empty result logs a warning and returns ``[]``. Errors are logged and re-raised.

    Args:
        agent_ids: IDs to look up.
        actor: User whose organization must be able to read the agents.
        include_relationships: Relationships to load and include in each result.

    Returns:
        The found agents as Pydantic models (order follows the query result).
    """
    async with db_registry.async_session() as session:
        try:
            stmt = AgentModel.apply_access_predicate(select(AgentModel), actor, ["read"], AccessType.ORGANIZATION)
            stmt = stmt.where(AgentModel.id.in_(agent_ids))
            stmt = _apply_relationship_filters(stmt, include_relationships)

            rows = (await session.execute(stmt)).scalars().all()
            if not rows:
                logger.warning(f"No agents found with IDs: {agent_ids}")
                return []

            pending = [row.to_pydantic_async(include_relationships=include_relationships) for row in rows]
            return await asyncio.gather(*pending)
        except Exception as e:
            logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
            raise
|
||
|
||
@enforce_types
@trace_method
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
    """Fetch an agent by its name (not its ID) and return it as a Pydantic model.

    The lookup goes through ``AgentModel.read(name=...)`` with ``actor``;
    conversion happens while the session is still open.
    """
    with db_registry.session() as session:
        agent_row = AgentModel.read(db_session=session, name=agent_name, actor=actor)
        return agent_row.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]:
    """Get all archive IDs associated with an agent.

    The lookup is scoped to ``actor``'s organization. Previously ``actor`` was
    accepted but ignored, so any caller could list another tenant's archive IDs.
    An agent outside the actor's organization yields an empty list.

    Args:
        agent_id: ID of the agent whose archives are listed.
        actor: User performing the lookup.

    Returns:
        The archive IDs linked to the agent (possibly empty).
    """
    from letta.orm import ArchivesAgents

    async with db_registry.async_session() as session:
        # Query the archives_agents pivot directly; the join only enforces the org check.
        query = (
            select(ArchivesAgents.archive_id)
            .join(AgentModel, AgentModel.id == ArchivesAgents.agent_id)
            .where(
                ArchivesAgents.agent_id == agent_id,
                AgentModel.organization_id == actor.organization_id,
            )
        )
        result = await session.execute(query)
        return list(result.scalars().all())
|
||
|
||
@enforce_types
@trace_method
def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
    """
    Deletes an agent and its associated relationships.
    Ensures proper permission checks and cascades where applicable.

    If the agent belongs to a sleeptime / voice-sleeptime group, the group and
    its participant agents are deleted too. All deletions are committed in a
    single transaction, so a failure leaves nothing half-deleted.

    Args:
        agent_id: ID of the agent to be deleted.
        actor: User performing the action.

    Raises:
        NoResultFound: If agent doesn't exist
        ValueError: If the deletion fails after the agent was loaded (chained to
            the underlying exception).
    """
    with db_registry.session() as session:
        # Retrieve the agent
        logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}")
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        agents_to_delete = [agent]
        sleeptime_group_to_delete = None

        # Delete sleeptime agent and group (TODO this is flimsy pls fix)
        if agent.multi_agent_group:
            participant_agent_ids = agent.multi_agent_group.agent_ids
            if agent.multi_agent_group.manager_type in {ManagerType.sleeptime, ManagerType.voice_sleeptime} and participant_agent_ids:
                for participant_agent_id in participant_agent_ids:
                    try:
                        sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_id, actor=actor)
                        agents_to_delete.append(sleeptime_agent)
                    except NoResultFound:
                        pass  # agent already deleted
                sleeptime_group_to_delete = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor)

        try:
            if sleeptime_group_to_delete is not None:
                session.delete(sleeptime_group_to_delete)
                # Flush rather than commit: the group row is still removed before the
                # agents, but within the same transaction, so a failure below also
                # rolls back the group deletion.
                session.flush()
            for doomed_agent in agents_to_delete:
                session.delete(doomed_agent)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
            raise ValueError(f"Failed to hard delete Agent with ID {agent_id}: {e}") from e
        else:
            logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
|
||
|
||
@enforce_types
@trace_method
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
    """
    Deletes an agent and its associated relationships.
    Ensures proper permission checks and cascades where applicable.

    If the agent belongs to a sleeptime / voice-sleeptime group, the group and
    its participant agents are deleted too. All deletions are committed in a
    single transaction, so a failure leaves nothing half-deleted.

    Args:
        agent_id: ID of the agent to be deleted.
        actor: User performing the action.

    Raises:
        NoResultFound: If agent doesn't exist
        ValueError: If the deletion fails after the agent was loaded (chained to
            the underlying exception).
    """
    async with db_registry.async_session() as session:
        # Retrieve the agent
        logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}")
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
        agents_to_delete = [agent]
        sleeptime_group_to_delete = None

        # Delete sleeptime agent and group (TODO this is flimsy pls fix)
        if agent.multi_agent_group:
            participant_agent_ids = agent.multi_agent_group.agent_ids
            if agent.multi_agent_group.manager_type in {ManagerType.sleeptime, ManagerType.voice_sleeptime} and participant_agent_ids:
                for participant_agent_id in participant_agent_ids:
                    try:
                        sleeptime_agent = await AgentModel.read_async(db_session=session, identifier=participant_agent_id, actor=actor)
                        agents_to_delete.append(sleeptime_agent)
                    except NoResultFound:
                        pass  # agent already deleted
                sleeptime_agent_group = await GroupModel.read_async(
                    db_session=session, identifier=agent.multi_agent_group.id, actor=actor
                )
                sleeptime_group_to_delete = sleeptime_agent_group

        try:
            if sleeptime_group_to_delete is not None:
                await session.delete(sleeptime_group_to_delete)
                # Flush rather than commit: the group row is still removed before the
                # agents, but within the same transaction, so a failure below also
                # rolls back the group deletion.
                await session.flush()
            for doomed_agent in agents_to_delete:
                await session.delete(doomed_agent)
            await session.commit()
        except Exception as e:
            await session.rollback()
            logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
            raise ValueError(f"Failed to hard delete Agent with ID {agent_id}: {e}") from e
        else:
            logger.debug(f"Agent with ID {agent_id} successfully hard deleted")
|
||
|
||
@enforce_types
@trace_method
def serialize(self, agent_id: str, actor: PydanticUser, max_steps: Optional[int] = None) -> AgentSchema:
    """Export an agent as an ``AgentSchema``.

    Args:
        agent_id: ID of the agent to export.
        actor: User performing the export (passed to the read and the schema).
        max_steps: Forwarded to ``MarshmallowAgentSchema``.

    Returns:
        The dumped agent wrapped in an ``AgentSchema``.
    """
    with db_registry.session() as session:
        agent_row = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        dumped = MarshmallowAgentSchema(session=session, actor=actor, max_steps=max_steps).dump(agent_row)
        return AgentSchema(**dumped)
|
||
|
||
@enforce_types
@trace_method
def deserialize(
    self,
    serialized_agent: AgentSchema,
    actor: PydanticUser,
    append_copy_suffix: bool = True,
    override_existing_tools: bool = True,
    project_id: Optional[str] = None,
    strip_messages: Optional[bool] = False,
    env_vars: Optional[dict[str, Any]] = None,
) -> PydanticAgentState:
    """Recreate an agent from a serialized ``AgentSchema`` under ``actor``'s organization.

    Every serialized message receives a freshly generated ID. The agent's
    ``message_ids`` is rebuilt from the serialized in-context indices. Tools are
    attached one by one after the agent is created.

    Args:
        serialized_agent: The exported agent payload.
        actor: User who will own the new agent; its organization ID is stamped
            onto the agent and its core-memory blocks.
        append_copy_suffix: If True, append ``"_copy"`` to the agent name.
        override_existing_tools: If False, an existing tool with the same name is
            reused instead of being created/updated from the payload. Existing
            LETTA_CORE, LETTA_MULTI_AGENT_CORE and LETTA_MEMORY_CORE tools are
            always reused.
        project_id: If truthy, assigns the new agent to this project.
        strip_messages: If True, keep only the first in-context message ID. All
            serialized messages are still created below; only the in-context
            list is trimmed.
        env_vars: If truthy, each deserialized tool environment variable's value
            is replaced by ``env_vars[key]``, or ``""`` when the key is absent.

    Returns:
        The new agent's state as returned by the last ``attach_tool`` call, or by
        ``to_pydantic`` when there are no tools.
    """
    serialized_agent_dict = serialized_agent.model_dump()
    tool_data_list = serialized_agent_dict.pop("tools", [])
    messages = serialized_agent_dict.pop(MarshmallowAgentSchema.FIELD_MESSAGES, [])

    # Assign new IDs so the copy never collides with the source agent's messages.
    for msg in messages:
        msg[MarshmallowAgentSchema.FIELD_ID] = SerializedMessageSchema.generate_id()  # Generate new ID

    # Map serialized in-context positions to the newly generated message IDs.
    # NOTE(review): pop() has no default, so a payload without this field raises KeyError.
    message_ids = []
    in_context_message_indices = serialized_agent_dict.pop(MarshmallowAgentSchema.FIELD_IN_CONTEXT_INDICES)
    for idx in in_context_message_indices:
        message_ids.append(messages[idx][MarshmallowAgentSchema.FIELD_ID])

    serialized_agent_dict[MarshmallowAgentSchema.FIELD_MESSAGE_IDS] = message_ids

    with db_registry.session() as session:
        schema = MarshmallowAgentSchema(session=session, actor=actor)
        agent = schema.load(serialized_agent_dict, session=session)

        # Re-home the agent and its blocks into the importing actor's organization.
        agent.organization_id = actor.organization_id
        for block in agent.core_memory:
            block.organization_id = actor.organization_id
        if append_copy_suffix:
            agent.name += "_copy"
        if project_id:
            agent.project_id = project_id

        if strip_messages:
            # we want to strip all but the first (system) message
            # NOTE(review): raises IndexError if message_ids is empty.
            agent.message_ids = [agent.message_ids[0]]

        if env_vars:
            for var in agent.tool_exec_environment_variables:
                var.value = env_vars.get(var.key, "")

        agent = agent.create(session, actor=actor)

        pydantic_agent = agent.to_pydantic()

        # Persist every serialized message, re-pointed at the newly created agent.
        pyd_msgs = []
        message_schema = SerializedMessageSchema(session=session, actor=actor)

        for serialized_message in messages:
            pydantic_message = message_schema.load(serialized_message, session=session).to_pydantic()
            pydantic_message.agent_id = agent.id
            pyd_msgs.append(pydantic_message)
        self.message_manager.create_many_messages(pyd_msgs, actor=actor)

        # Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle
        for tool_data in tool_data_list:
            pydantic_tool = SerializedToolSchema(actor=actor).load(tool_data, transient=True).to_pydantic()

            existing_pydantic_tool = self.tool_manager.get_tool_by_name(pydantic_tool.name, actor=actor)
            if existing_pydantic_tool and (
                existing_pydantic_tool.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MULTI_AGENT_CORE, ToolType.LETTA_MEMORY_CORE}
                or not override_existing_tools
            ):
                pydantic_tool = existing_pydantic_tool
            else:
                pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor, bypass_name_check=True)

            pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor)

    return pydantic_agent
|
||
|
||
# ======================================================================================================================
# Per Agent Environment Variable Management
# ======================================================================================================================
@enforce_types
@trace_method
def _set_environment_variables(
    self,
    agent_id: str,
    env_vars: Dict[str, str],
    actor: PydanticUser,
) -> PydanticAgentState:
    """
    Adds or replaces the environment variables for the specified agent.

    Variables whose keys appear in ``env_vars`` are updated in place; new keys
    are created; existing keys absent from ``env_vars`` are dropped from the
    agent's collection.

    Args:
        agent_id: The agent id.
        env_vars: A dictionary of environment variable key-value pairs.
        actor: The user performing the action.

    Returns:
        PydanticAgentState: The updated agent as a Pydantic model.
    """
    with db_registry.session() as session:
        # Retrieve the agent
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        # Fetch existing environment variables as a dictionary
        existing_vars = {var.key: var for var in agent.tool_exec_environment_variables}

        # Update or create environment variables
        updated_vars = []
        for key, value in env_vars.items():
            var = existing_vars.get(key)
            if var is not None:
                # Update existing variable
                var.value = value
            else:
                # Create new variable
                var = AgentEnvironmentVariableModel(
                    key=key,
                    value=value,
                    agent_id=agent_id,
                    organization_id=actor.organization_id,
                    created_by_id=actor.id,
                    last_updated_by_id=actor.id,
                )
            updated_vars.append(var)

        # ``updated_vars`` holds exactly the keys of ``env_vars``, so assigning it drops
        # stale variables from the collection. The former explicit stale-key filter
        # could never remove anything. NOTE(review): whether dropped rows are deleted
        # or merely orphaned depends on the relationship's cascade on AgentModel.
        agent.tool_exec_environment_variables = updated_vars

        # Update the agent in the database
        agent.update(session, actor=actor)

        # Return the updated agent state
        return agent.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]:
    """
    List the groups an agent is a member of, within the actor's organization.

    Args:
        agent_id: ID of the member agent.
        actor: User performing the lookup; results are limited to their organization.
        manager_type: When truthy, only groups with this ``manager_type`` are returned.

    Returns:
        List[PydanticGroup]: The matching groups as Pydantic models.
    """
    conditions = [
        GroupsAgents.agent_id == agent_id,
        GroupModel.organization_id == actor.organization_id,
    ]
    if manager_type:
        conditions.append(GroupModel.manager_type == manager_type)

    stmt = select(GroupModel).join(GroupsAgents, GroupModel.id == GroupsAgents.group_id).where(*conditions)

    with db_registry.session() as session:
        rows = session.execute(stmt).scalars()
        return [row.to_pydantic() for row in rows]
|
||
|
||
# ======================================================================================================================
|
||
# In Context Messages Management
|
||
# ======================================================================================================================
|
||
# TODO: There are several assumptions here that are not explicitly checked
|
||
# TODO: 1) These message ids are valid
|
||
# TODO: 2) These messages are ordered from oldest to newest
|
||
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
|
||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||
@enforce_types
@trace_method
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
    """
    Return the messages referenced by the agent's ``message_ids``.

    The IDs are passed straight to the message manager; neither their validity
    nor their ordering is checked here.
    """
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    context_ids = agent_state.message_ids
    return self.message_manager.get_messages_by_ids(message_ids=context_ids, actor=actor)
|
||
|
||
@enforce_types
@trace_method
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
    """Fetch the message at index 0 of the agent's ``message_ids`` (the system message)."""
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    system_message_id = agent_state.message_ids[0]
    return self.message_manager.get_message_by_id(message_id=system_message_id, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
    """Async variant of ``get_system_message``; loads the agent without any relationships."""
    agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
    system_message_id = agent_state.message_ids[0]
    return await self.message_manager.get_message_by_id_async(message_id=system_message_id, actor=actor)
|
||
|
||
# TODO: This is duplicated below
|
||
# TODO: This is legacy code and should be cleaned up
|
||
# TODO: A lot of the memory "compilation" should be offset to a separate class
|
||
@enforce_types
@trace_method
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
    """Rebuild the system message from the latest memory object and any shared memory block updates.

    Updates to core memory blocks should trigger a "rebuild", which rewrites the
    existing system message (same message ID) with freshly compiled content.

    Updates to the memory header should *not* trigger a rebuild, since that would
    simply flood recall storage with excess messages.

    Args:
        agent_id: ID of the agent whose system prompt is rebuilt.
        actor: User performing the action.
        force: Rebuild even when the compiled memory string already appears in
            the current system message.
        update_timestamp: If True, use the current UTC time as the memory
            last-edit time; otherwise reuse the current system message's
            ``created_at``.

    Returns:
        PydanticAgentState: The agent state. It is re-saved via
        ``set_in_context_messages`` only when the system message actually changed.
    """
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)

    # This is the system + memory bank, not just the system prompt
    curr_system_message = self.get_system_message(agent_id=agent_id, actor=actor)

    if curr_system_message is None:
        logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
        return agent_state

    curr_system_message_openai = curr_system_message.to_openai_dict()

    # note: we only update the system prompt if the core memory is changed
    # this means that the archival/recall memory statistics may be somewhat out of date
    curr_memory_str = agent_state.memory.compile(sources=agent_state.sources)
    if curr_memory_str in curr_system_message_openai["content"] and not force:
        # NOTE: could this cause issues if a block is removed? (substring match would still work)
        logger.debug(
            f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
        )
        return agent_state

    # If the memory didn't update, we probably don't want to update the timestamp inside
    # For example, if we're doing a system prompt swap, this should probably be False
    if update_timestamp:
        memory_edit_timestamp = get_utc_time()
    else:
        # NOTE: a bit of a hack - we reuse the current system message's created_at as the edit time
        memory_edit_timestamp = curr_system_message.created_at

    num_messages = self.message_manager.size(actor=actor, agent_id=agent_id)
    num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id)

    # update memory (TODO: potentially update recall/archival stats separately)
    new_system_message_str = compile_system_message(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        timezone=agent_state.timezone,
        previous_message_count=num_messages - len(agent_state.message_ids),
        archival_memory_size=num_archival_memories,
        sources=agent_state.sources,
        max_files_open=agent_state.max_files_open,
    )

    diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
    if len(diff) == 0:
        return agent_state

    logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

    # Swap the system message content in place (only when there is a diff)
    new_message = PydanticMessage.dict_to_message(
        agent_id=agent_id,
        model=agent_state.llm_config.model,
        openai_message_dict={"role": "system", "content": new_system_message_str},
    )
    # dict_to_message builds a brand-new message; carry over the existing ID so the
    # update payload describes the row being updated (same as rebuild_system_prompt_async).
    new_message.id = curr_system_message.id
    self.message_manager.update_message_by_id(
        message_id=curr_system_message.id,
        message_update=MessageUpdate(**new_message.model_dump()),
        actor=actor,
    )
    return self.set_in_context_messages(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
|
||
|
||
# Do not remove comment. (cliandy)
|
||
# TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish
|
||
@enforce_types
@trace_method
async def rebuild_system_prompt_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    force=False,
    update_timestamp=True,
    dry_run: bool = False,
) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
    """Rebuild the system message from the latest memory object and any shared memory block updates.

    Updates to core memory blocks should trigger a "rebuild", which rewrites the
    existing system message (same message ID) with freshly compiled content.

    Updates to the memory header should *not* trigger a rebuild, since that would
    simply flood recall storage with excess messages.

    Args:
        agent_id: ID of the agent whose system prompt is rebuilt.
        actor: User performing the action.
        force: Rebuild even when the compiled memory string already appears in
            the current system message.
        update_timestamp: If True, use the current UTC time as the memory
            last-edit time; otherwise reuse the current system message's
            ``created_at``.
        dry_run: If True, compute the new system message but do not persist it.

    Returns:
        Tuple of (agent_state, system_message, num_messages, num_archival_memories).
        ``system_message`` reflects the rebuilt content whenever a rebuild happened:
        the persisted message normally, or the unsaved candidate when ``dry_run``.
        It is None only if no system message was found.
    """
    num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
    num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
    agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)

    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)

    curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)

    if curr_system_message is None:
        logger.warning(f"No system message found for agent {agent_state.id} and user {actor}")
        return agent_state, curr_system_message, num_messages, num_archival_memories

    curr_system_message_openai = curr_system_message.to_openai_dict()

    # note: we only update the system prompt if the core memory is changed
    # this means that the archival/recall memory statistics may be somewhat out of date
    curr_memory_str = await agent_state.memory.compile_in_thread_async(
        sources=agent_state.sources,
        tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
        max_files_open=agent_state.max_files_open,
    )
    if curr_memory_str in curr_system_message_openai["content"] and not force:
        # NOTE: could this cause issues if a block is removed? (substring match would still work)
        logger.debug(
            f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
        )
        return agent_state, curr_system_message, num_messages, num_archival_memories

    # If the memory didn't update, we probably don't want to update the timestamp inside
    # For example, if we're doing a system prompt swap, this should probably be False
    if update_timestamp:
        memory_edit_timestamp = get_utc_time()
    else:
        # NOTE: a bit of a hack - we reuse the current system message's created_at as the edit time
        memory_edit_timestamp = curr_system_message.created_at

    # update memory (TODO: potentially update recall/archival stats separately)
    new_system_message_str = PromptGenerator.get_system_message_from_compiled_memory(
        system_prompt=agent_state.system,
        memory_with_sources=curr_memory_str,
        in_context_memory_last_edit=memory_edit_timestamp,
        timezone=agent_state.timezone,
        previous_message_count=num_messages - len(agent_state.message_ids),
        archival_memory_size=num_archival_memories,
    )

    diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
    if len(diff) > 0:  # there was a diff
        logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

        # Swap the system message out (only if there is a diff)
        temp_message = PydanticMessage.dict_to_message(
            agent_id=agent_id,
            model=agent_state.llm_config.model,
            openai_message_dict={"role": "system", "content": new_system_message_str},
        )
        temp_message.id = curr_system_message.id

        if dry_run:
            # Nothing is persisted; hand back the message as it would look after the rebuild.
            curr_system_message = temp_message
        else:
            # Return the persisted message so callers see the rebuilt content,
            # not the stale copy fetched before the update.
            curr_system_message = await self.message_manager.update_message_by_id_async(
                message_id=curr_system_message.id,
                message_update=MessageUpdate(**temp_message.model_dump()),
                actor=actor,
            )

    return agent_state, curr_system_message, num_messages, num_archival_memories
|
||
|
||
@enforce_types
@trace_method
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
    """Overwrite the agent's in-context ``message_ids`` and return the updated agent."""
    patch = UpdateAgent(message_ids=message_ids)
    return self.update_agent(agent_id=agent_id, agent_update=patch, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
    """Async variant of ``set_in_context_messages``: overwrite the agent's ``message_ids``."""
    patch = UpdateAgent(message_ids=message_ids)
    return await self.update_agent_async(agent_id=agent_id, agent_update=patch, actor=actor)
|
||
|
||
@enforce_types
@trace_method
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Drop the oldest non-system messages from the agent's context window.

    The message at index 0 (the system message) is always kept. After it, the
    in-context messages from index ``num`` onward are retained. A ``num`` below 1
    trims nothing.

    Args:
        num: Index into ``message_ids`` of the first non-system message to keep.
        agent_id: ID of the agent to trim.
        actor: User performing the action.

    Returns:
        PydanticAgentState: The agent with its updated ``message_ids``.
    """
    message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    # Clamp to 1: index 0 is the system message and is already prepended below, so
    # slicing from 0 would put the system message into the context twice.
    keep_from = max(num, 1)
    new_messages = [message_ids[0]] + message_ids[keep_from:]  # 0 is system message
    return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||
|
||
@enforce_types
@trace_method
def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Reduce the agent's in-context messages to just the first ID (assumed to be the system message)."""
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    # TODO: How do we know this? Index 0 is assumed to be the system message.
    system_message_id = agent_state.message_ids[0]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=[system_message_id], actor=actor)
|
||
|
||
@enforce_types
@trace_method
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Persist ``messages`` and insert them right after the system message in context.

    The resulting order is: system message, the new messages (in the order
    returned by the message manager), then the previously in-context messages.
    """
    existing_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    created = self.message_manager.create_many_messages(messages, actor=actor)
    reordered_ids = [existing_ids[0], *(msg.id for msg in created), *existing_ids[1:]]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=reordered_ids, actor=actor)
|
||
|
||
@enforce_types
@trace_method
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Persist ``messages`` and append their IDs to the end of the agent's in-context list.

    The messages are created first; the agent is read afterwards. A missing
    ``message_ids`` is treated as an empty list.
    """
    persisted = self.message_manager.create_many_messages(messages, actor=actor)
    current_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
    combined_ids = [*current_ids, *(msg.id for msg in persisted)]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=combined_ids, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def append_to_in_context_messages_async(
    self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
) -> PydanticAgentState:
    """
    Async variant of ``append_to_in_context_messages``.

    The agent is read first so its ``embedding_config`` can be passed to message
    creation. The new message IDs are appended after the existing in-context IDs;
    a missing ``message_ids`` is treated as an empty list.
    """
    agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
    persisted = await self.message_manager.create_many_messages_async(
        messages, actor=actor, embedding_config=agent_state.embedding_config
    )
    combined_ids = [*(agent_state.message_ids or []), *(msg.id for msg in persisted)]
    return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=combined_ids, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def reset_messages_async(
    self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False
) -> PydanticAgentState:
    """
    Wipe an agent's conversation history while keeping its original system message.

    Steps:
        1) Take the first entry of the agent's ``message_ids`` as the system message.
        2) Delete every other message belonging to the agent.
        3) Save ``message_ids`` as just ``[system_message_id]``.
        4) Optionally append the default initial message sequence, minus its own
           leading system message, after the preserved one.

    This action is destructive and cannot be undone once committed.

    Args:
        agent_id (str): The ID of the agent whose messages will be reset.
        actor (PydanticUser): The user performing this action.
        add_default_initial_messages: If true, adds the default initial messages after resetting.

    Returns:
        PydanticAgentState: The updated agent state with only the original system
        message preserved, plus the default messages if requested.

    Raises:
        ValueError: If the agent has no message_ids, so there is no system message to keep.
    """
    async with db_registry.async_session() as session:
        # Retrieve the existing agent (will raise NoResultFound if invalid)
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

        if not agent.message_ids:
            logger.error(
                f"Agent {agent_id} has no message_ids. Agent details: "
                f"name={agent.name}, created_at={agent.created_at}, "
                f"message_ids={agent.message_ids}, organization_id={actor.organization_id}"
            )
            raise ValueError(f"Agent {agent_id} has no message_ids - cannot preserve system message")

        system_message_id = agent.message_ids[0]

        await self.message_manager.delete_all_messages_for_agent_async(agent_id=agent_id, actor=actor, exclude_ids=[system_message_id])

        agent.message_ids = [system_message_id]
        await agent.update_async(db_session=session, actor=actor)
        agent_state = await agent.to_pydantic_async(include_relationships=["sources"])

    if not add_default_initial_messages:
        return agent_state

    init_messages = await initialize_message_sequence_async(
        agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
    )
    # The generated sequence begins with its own system message; skip it because the original was kept.
    follow_up_messages = [
        PydanticMessage.dict_to_message(
            agent_id=agent_state.id,
            model=agent_state.llm_config.model,
            openai_message_dict=raw_message,
        )
        for raw_message in init_messages[1:]
    ]
    return await self.append_to_in_context_messages_async(follow_up_messages, agent_id=agent_state.id, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
    """
    Persist block edits from ``new_memory`` and rebuild the system prompt if memory changed.

    ``new_memory`` is compiled and checked as a substring of the current system
    message's text. If it is already present, nothing is written. Otherwise:

    - each block label in the agent's current memory whose value differs in
      ``new_memory`` is written to the DB;
    - the agent state's memory is reloaded from those block IDs;
    - the system prompt is rebuilt.

    Args:
        agent_id: ID of the agent to update.
        new_memory (Memory): The new memory object to compare to the current memory object.
        actor: User performing the action.

    Returns:
        PydanticAgentState: The agent state. Its ``memory`` is reloaded from the
        DB when a change was detected.
    """
    agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
    system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
    temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    new_memory_str = await new_memory.compile_in_thread_async(
        sources=agent_state.sources,
        tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
        max_files_open=agent_state.max_files_open,
    )

    if new_memory_str not in system_message.content[0].text:
        # update the blocks (LRW) in the DB
        for label in agent_state.memory.list_block_labels():
            updated_value = new_memory.get_block(label).value
            current_block = agent_state.memory.get_block(label)
            if updated_value != current_block.value:
                # update the block only if its value changed
                await self.block_manager.update_block_async(
                    block_id=current_block.id, block_update=BlockUpdate(value=updated_value), actor=actor
                )

        # refresh memory from DB (using block ids)
        blocks = await self.block_manager.get_all_blocks_by_ids_async(
            block_ids=[b.id for b in agent_state.memory.get_blocks()], actor=actor
        )

        agent_state.memory = Memory(
            blocks=blocks,
            file_blocks=agent_state.memory.file_blocks,
            prompt_template=get_prompt_template_for_agent_type(agent_state.agent_type),
        )

        # Rebuild memory: this records the last-edited timestamp of the memory.
        # NOTE(review): an earlier comment suggested skipping this because memory is
        # also rebuilt at the start of each step; confirm whether this call is still needed.
        # TODO: pass in update timestamp from block edit time
        await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor)

    return agent_state
|
||
|
||
@enforce_types
@trace_method
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
    """
    Reload the agent state's memory blocks and file blocks from storage, in place.

    Only blocks already present on ``agent_state.memory`` are refreshed (by ID for
    regular blocks, by label for file blocks). Entries the managers return as None
    are dropped. Returns the same ``agent_state`` object.
    """
    # TODO: This will NOT work for new blocks/file blocks added intra-step
    block_ids = [blk.id for blk in agent_state.memory.blocks]
    file_block_names = [file_blk.label for file_blk in agent_state.memory.file_blocks]

    if block_ids:
        fetched_blocks = await self.block_manager.get_all_blocks_by_ids_async(block_ids=block_ids, actor=actor)
        agent_state.memory.blocks = [blk for blk in fetched_blocks if blk is not None]

    if file_block_names:
        fetched_file_blocks = await self.file_agent_manager.get_all_file_blocks_by_name(
            file_names=file_block_names,
            agent_id=agent_state.id,
            actor=actor,
            per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit,
        )
        agent_state.memory.file_blocks = [file_blk for file_blk in fetched_file_blocks if file_blk is not None]

    return agent_state
|
||
|
||
@enforce_types
@trace_method
async def refresh_file_blocks(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
    """
    Replace the agent state's in-memory file blocks with the files currently attached to the agent.

    The file-agent manager is asked for the agent's files as blocks, using the
    agent's ``per_file_view_window_char_limit``. Any None entries are discarded.

    Args:
        agent_state: Agent state whose ``memory.file_blocks`` is overwritten.
        actor: The user performing this action.

    Returns:
        The same ``agent_state`` object, with refreshed file blocks.

    Important:
        - This does NOT persist anything; it only updates the given state object.
        - Call this before agent interactions if files may have changed externally.
    """
    loaded_blocks = await self.file_agent_manager.list_files_for_agent(
        agent_id=agent_state.id,
        per_file_view_window_char_limit=agent_state.per_file_view_window_char_limit,
        actor=actor,
        return_as_blocks=True,
    )
    agent_state.memory.file_blocks = [fb for fb in loaded_blocks if fb is not None]
    return agent_state
|
||
|
||
# ======================================================================================================================
|
||
# Source Management
|
||
# ======================================================================================================================
|
||
@enforce_types
@trace_method
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Attaches a source to an agent.

    The source is added to the agent's ``sources`` relationship without replacing
    existing attachments, and the agent is then saved.

    Args:
        agent_id: ID of the agent to attach the source to
        source_id: ID of the source to attach
        actor: User performing the action

    Returns:
        PydanticAgentState: The updated agent.

    Raises:
        NoResultFound: If the agent does not exist or is not accessible to ``actor``.
        Errors from ``_process_relationship_async`` (e.g. a missing or already-attached
        source) propagate unchanged. NOTE(review): the exact exception types depend on
        that helper; confirm before relying on them.
    """
    async with db_registry.async_session() as session:
        # Verify the agent exists and the user has permission to access it
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

        # Add the source to the relationship; replace=False keeps existing attachments
        await _process_relationship_async(
            session=session,
            agent=agent,
            relationship_name="sources",
            model_class=SourceModel,
            item_ids=[source_id],
            replace=False,
        )

        # Commit the changes
        agent = await agent.update_async(session, actor=actor)
        return await agent.to_pydantic_async()
|
||
|
||
@enforce_types
@trace_method
def append_system_message(self, agent_id: str, content: str, actor: PydanticUser):
    """
    Append a system-role message to the end of an agent's in-context message history.

    A message with role ``"system"`` and the given content is built using the
    agent's configured model, then persisted and appended via
    ``append_to_in_context_messages``.

    Args:
        agent_id: The ID of the agent to append the message to.
        content: Text of the system message.
        actor: The user performing this action.

    Side Effects:
        - Creates a new Message in the database.
        - Appends its ID to the agent's in-context message IDs.
    """
    agent = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    system_payload = {"role": "system", "content": content}
    new_message = PydanticMessage.dict_to_message(
        agent_id=agent.id,
        model=agent.llm_config.model,
        openai_message_dict=system_payload,
    )
    self.append_to_in_context_messages(messages=[new_message], agent_id=agent_id, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
    """
    Async variant of ``append_system_message``.

    Builds a system-role message with the agent's configured model and appends it
    to the agent's in-context messages via ``append_to_in_context_messages_async``.
    """
    agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
    system_payload = {"role": "system", "content": content}
    new_message = PydanticMessage.dict_to_message(
        agent_id=agent.id,
        model=agent.llm_config.model,
        openai_message_dict=system_payload,
    )
    await self.append_to_in_context_messages_async(messages=[new_message], agent_id=agent_id, actor=actor)
|
||
|
||
@enforce_types
@trace_method
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
    """
    Lists all sources attached to an agent.

    Only non-deleted sources in the actor's organization are returned. They are
    ordered newest first by ``created_at``, with ties broken by source ID.

    Args:
        agent_id: ID of the agent to list sources for
        actor: User performing the action

    Returns:
        List[PydanticSource]: The attached sources as Pydantic models.

    Raises:
        NoResultFound: If agent doesn't exist or user doesn't have access
    """
    async with db_registry.async_session() as session:
        # Validate agent exists and user has access
        await validate_agent_exists_async(session, agent_id, actor)

        # Query through the junction table directly instead of lazy-loading the
        # agent's `sources` relationship.
        query = (
            select(SourceModel)
            .join(SourcesAgents, SourceModel.id == SourcesAgents.source_id)
            .where(
                SourcesAgents.agent_id == agent_id,
                SourceModel.organization_id == actor.organization_id,
                SourceModel.is_deleted == False,  # noqa: E712 - SQLAlchemy column comparison
            )
            .order_by(SourceModel.created_at.desc(), SourceModel.id)
        )

        result = await session.execute(query)
        sources = result.scalars().all()

        return [source.to_pydantic() for source in sources]
|
||
|
||
@enforce_types
@trace_method
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Remove the link between an agent and a source.

    Detaching a source that is not attached is not an error: a warning is logged
    and nothing is deleted. In both cases the agent is re-read and returned.

    Args:
        agent_id: ID of the agent to detach the source from
        source_id: ID of the source to detach
        actor: User performing the action

    Returns:
        PydanticAgentState: The agent after the (possible) detach.

    Raises:
        NoResultFound: If agent doesn't exist or user doesn't have access
    """
    link_conditions = (SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)

    async with db_registry.async_session() as session:
        await validate_agent_exists_async(session, agent_id, actor)

        # Look up the junction-table row that links this agent and source
        lookup = await session.execute(select(SourcesAgents).where(*link_conditions))
        existing_link = lookup.scalar_one_or_none()

        if existing_link:
            await session.execute(delete(SourcesAgents).where(*link_conditions))
            await session.commit()
        else:
            logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")

        refreshed_agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
        return await refreshed_agent.to_pydantic_async()
|
||
|
||
# ======================================================================================================================
|
||
# Block management
|
||
# ======================================================================================================================
|
||
@enforce_types
@trace_method
def get_block_with_label(
    self,
    agent_id: str,
    block_label: str,
    actor: PydanticUser,
) -> PydanticBlock:
    """Return the first core-memory block on the agent whose label equals ``block_label``.

    Raises:
        NoResultFound: If the agent has no block with that label.
    """
    with db_registry.session() as session:
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        match = next((blk for blk in agent.core_memory if blk.label == block_label), None)
        if match is None:
            raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
        return match.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def get_block_with_label_async(
    self,
    agent_id: str,
    block_label: str,
    actor: PydanticUser,
) -> PydanticBlock:
    """Async variant of ``get_block_with_label``.

    Raises:
        NoResultFound: If the agent has no block with that label.
    """
    async with db_registry.async_session() as session:
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
        match = next((blk for blk in agent.core_memory if blk.label == block_label), None)
        if match is None:
            raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
        return match.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def modify_block_by_label_async(
    self,
    agent_id: str,
    block_label: str,
    block_update: BlockUpdate,
    actor: PydanticUser,
) -> PydanticBlock:
    """Apply ``block_update`` to the agent's core-memory block with label ``block_label``.

    Only fields that were explicitly set and are not None on ``block_update`` are
    written to the block.

    Args:
        agent_id: ID of the agent that owns the block.
        block_label: Label of the block to modify.
        block_update: Partial update to apply.
        actor: User performing the action.

    Returns:
        PydanticBlock: The updated block.

    Raises:
        NoResultFound: If the agent has no block with that label.
    """
    async with db_registry.async_session() as session:
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

        # Use an explicit None default: a `for block in ...` loop leaves `block`
        # bound to the last element when nothing matches, which would silently
        # modify the wrong block.
        block = next((b for b in agent.core_memory if b.label == block_label), None)
        if block is None:
            raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")

        update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for key, value in update_data.items():
            setattr(block, key, value)

        await block.update_async(session, actor=actor)
        return block.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
def update_block_with_label(
    self,
    agent_id: str,
    block_label: str,
    new_block_id: str,
    actor: PydanticUser,
) -> PydanticAgentState:
    """
    Make ``new_block_id`` the agent's core-memory block for ``block_label``.

    Any existing blocks carrying that label are removed from the agent's core
    memory, and the new block is added at the end.

    Raises:
        ValueError: If the new block's label differs from ``block_label``.
    """
    with db_registry.session() as session:
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)

        if new_block.label != block_label:
            raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'")

        # Keep every block with a different label, then add the replacement last
        retained = [blk for blk in agent.core_memory if blk.label != block_label]
        retained.append(new_block)
        agent.core_memory = retained

        agent.update(session, actor=actor)
        return agent.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Attach a block to an agent's core memory.

    If the agent belongs to a group managed as ``ManagerType.sleeptime``, the block is also
    attached to every other group member whose type is ``AgentType.sleeptime_agent`` and that
    does not already hold it. All updates use ``no_commit=True`` and are committed together.

    Args:
        agent_id: ID of the agent receiving the block.
        block_id: ID of the block to attach.
        actor: User performing the change.

    Returns:
        The updated state of the agent identified by ``agent_id``.
    """
    with db_registry.session() as session:
        target_agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

        target_agent.core_memory.append(block)
        target_agent.update(session, actor=actor, no_commit=True)

        group = target_agent.multi_agent_group
        if group and group.manager_type == ManagerType.sleeptime:
            for peer_id in group.agent_ids or []:
                if peer_id == agent_id:
                    continue
                try:
                    peer = AgentModel.read(db_session=session, identifier=peer_id, actor=actor)
                    if peer.agent_type == AgentType.sleeptime_agent and block not in peer.core_memory:
                        peer.core_memory.append(block)
                        peer.update(session, actor=actor, no_commit=True)
                except NoResultFound:
                    # The group may still list an agent that no longer exists; skip it.
                    continue

        session.commit()
        return target_agent.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Attach a block to an agent's core memory (async variant).

    If the agent belongs to a group managed as ``ManagerType.sleeptime``, the block is also
    attached to every other member whose type is ``AgentType.sleeptime_agent`` and that does not
    already hold it. Unlike the sync variant, ``update_async`` is called without ``no_commit``
    and there is no final ``session.commit()`` here (see TODO below).

    Args:
        agent_id: ID of the agent receiving the block.
        block_id: ID of the block to attach.
        actor: User performing the change.

    Returns:
        The updated state of the agent identified by ``agent_id``.
    """
    async with db_registry.async_session() as session:
        target_agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
        block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)

        target_agent.core_memory.append(block)
        # NOTE(review): no actor is passed here, unlike the peer update below — confirm this is intended.
        await target_agent.update_async(session)

        group = target_agent.multi_agent_group
        if group and group.manager_type == ManagerType.sleeptime:
            for peer_id in group.agent_ids or []:
                if peer_id == agent_id:
                    continue
                try:
                    peer = await AgentModel.read_async(db_session=session, identifier=peer_id, actor=actor)
                    if peer.agent_type == AgentType.sleeptime_agent and block not in peer.core_memory:
                        peer.core_memory.append(block)
                        await peer.update_async(session, actor=actor)
                except NoResultFound:
                    # The group may still list an agent that no longer exists; skip it.
                    continue

        # TODO(@andy/caren): ideally both update_async calls above would pass no_commit=True and a single
        # `await session.commit()` would happen here, but that currently errors — revisit later.
        return await target_agent.to_pydantic_async()
|
||
|
||
@enforce_types
@trace_method
def detach_block(
    self,
    agent_id: str,
    block_id: str,
    actor: PydanticUser,
) -> PydanticAgentState:
    """Remove the block with ``block_id`` from an agent's core memory.

    Args:
        agent_id: ID of the agent to modify.
        block_id: ID of the block to detach.
        actor: User performing the change.

    Returns:
        The updated agent state.

    Raises:
        NoResultFound: If the agent has no attached block with ``block_id``.
    """
    with db_registry.session() as session:
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        kept_blocks = [b for b in agent.core_memory if b.id != block_id]
        removed_count = len(agent.core_memory) - len(kept_blocks)
        agent.core_memory = kept_blocks

        if removed_count == 0:
            raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")

        agent.update(session, actor=actor)
        return agent.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def detach_block_async(
    self,
    agent_id: str,
    block_id: str,
    actor: PydanticUser,
) -> PydanticAgentState:
    """Remove the block with ``block_id`` from an agent's core memory (async variant).

    Args:
        agent_id: ID of the agent to modify.
        block_id: ID of the block to detach.
        actor: User performing the change.

    Returns:
        The updated agent state.

    Raises:
        NoResultFound: If the agent has no attached block with ``block_id``.
    """
    async with db_registry.async_session() as session:
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

        kept_blocks = [b for b in agent.core_memory if b.id != block_id]
        removed_count = len(agent.core_memory) - len(kept_blocks)
        agent.core_memory = kept_blocks

        if removed_count == 0:
            raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")

        await agent.update_async(session, actor=actor)
        return await agent.to_pydantic_async()
|
||
|
||
@enforce_types
@trace_method
def detach_block_with_label(
    self,
    agent_id: str,
    block_label: str,
    actor: PydanticUser,
) -> PydanticAgentState:
    """Remove every block labelled ``block_label`` from an agent's core memory.

    Args:
        agent_id: ID of the agent to modify.
        block_label: Label of the block(s) to detach.
        actor: User performing the change.

    Returns:
        The updated agent state.

    Raises:
        NoResultFound: If the agent has no attached block with ``block_label``.
    """
    with db_registry.session() as session:
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        kept_blocks = [b for b in agent.core_memory if b.label != block_label]
        removed_count = len(agent.core_memory) - len(kept_blocks)
        agent.core_memory = kept_blocks

        if removed_count == 0:
            raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'")

        agent.update(session, actor=actor)
        return agent.to_pydantic()
|
||
|
||
# ======================================================================================================================
|
||
# Passage Management
|
||
# ======================================================================================================================
|
||
|
||
@enforce_types
@trace_method
async def list_passages(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> List[PydanticPassage]:
    """List passages (archival and/or source) matching the given filters.

    Rows produced by ``build_passage_query`` are rebuilt as ``ArchivalPassage`` when they carry an
    ``archive_id`` and as ``SourcePassage`` when they carry a ``source_id``; columns that do not
    belong to the chosen model are removed first. A truthy ``limit`` caps the row count.

    Raises:
        ValueError: If a row has neither an ``archive_id`` nor a ``source_id``.
    """

    def _row_to_orm(mapping):
        # Rebuild the ORM passage type that matches the row's populated foreign key.
        data = dict(mapping)
        if data.get("archive_id", None):
            for column in ("source_id", "file_id", "file_name"):
                data.pop(column, None)
            return ArchivalPassage(**data)
        if data.get("source_id", None):
            # agent_id is dropped for backward compatibility
            for column in ("archive_id", "agent_id"):
                data.pop(column, None)
            return SourcePassage(**data)
        raise ValueError(f"Passage data is malformed, is neither ArchivalPassage nor SourcePassage {data}")

    async with db_registry.async_session() as session:
        passage_query = await build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            before=before,
            after=after,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        if limit:
            passage_query = passage_query.limit(limit)

        rows = await session.execute(passage_query)
        orm_passages = [_row_to_orm(row._mapping) for row in rows]
        return [orm_passage.to_pydantic() for orm_passage in orm_passages]
|
||
|
||
@enforce_types
@trace_method
async def list_passages_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> List[PydanticPassage]:
    """
    DEPRECATED: Use query_source_passages_async or query_agent_passages_async instead.
    This method is kept only for test compatibility and will be removed in a future version.

    Lists passages (source and agent/archival combined) matching the given filters and emits a
    ``DeprecationWarning`` on every call.

    Raises:
        ValueError: If a row has neither an ``archive_id`` nor a ``source_id``.
    """
    import warnings

    warnings.warn(
        "list_passages_async is deprecated. Use query_source_passages_async or query_agent_passages_async instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    def _row_to_orm(mapping):
        # Rebuild the ORM passage type that matches the row's populated foreign key.
        data = dict(mapping)
        if data.get("archive_id", None):
            for column in ("source_id", "file_id", "file_name"):
                data.pop(column, None)
            return ArchivalPassage(**data)
        if data.get("source_id", None):
            # agent_id is dropped for backward compatibility
            for column in ("archive_id", "agent_id"):
                data.pop(column, None)
            return SourcePassage(**data)
        raise ValueError(f"Passage data is malformed, is neither ArchivalPassage nor SourcePassage {data}")

    query_kwargs = dict(
        actor=actor,
        agent_id=agent_id,
        file_id=file_id,
        query_text=query_text,
        start_date=start_date,
        end_date=end_date,
        before=before,
        after=after,
        source_id=source_id,
        embed_query=embed_query,
        ascending=ascending,
        embedding_config=embedding_config,
        agent_only=agent_only,
    )

    async with db_registry.async_session() as session:
        passage_query = await build_passage_query(**query_kwargs)
        if limit:
            passage_query = passage_query.limit(limit)

        rows = await session.execute(passage_query)
        orm_passages = [_row_to_orm(row._mapping) for row in rows]
        return [orm_passage.to_pydantic() for orm_passage in orm_passages]
|
||
|
||
@enforce_types
@trace_method
async def query_source_passages_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
) -> List[PydanticPassage]:
    """List source (file) passages matching the given filters.

    The query is built by ``build_source_passage_query``; a truthy ``limit`` caps the row count.

    Returns:
        The matching passages converted to Pydantic models.
    """
    async with db_registry.async_session() as session:
        source_query = await build_source_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            before=before,
            after=after,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
        )
        if limit:
            source_query = source_query.limit(limit)

        orm_rows = (await session.execute(source_query)).scalars().all()
        return [orm_row.to_pydantic() for orm_row in orm_rows]
|
||
|
||
@enforce_types
@trace_method
async def query_agent_passages_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    tags: Optional[List[str]] = None,
    tag_match_mode: Optional[TagMatchMode] = None,
) -> List[PydanticPassage]:
    """List an agent's archival passages, optionally via vector search and tag filtering.

    When ``embed_query`` is true and ``agent_id``, ``query_text`` and ``embedding_config`` are all
    given, the agent's archives are inspected first:

    * no archives -> ``[]``;
    * an archive whose ``vector_db_provider`` is TPUF -> Turbopuffer hybrid (vector + full-text)
      search, with tag and date filters pushed down to Turbopuffer;
    * any other provider -> falls through to the SQL path.

    The SQL path uses ``build_agent_passage_query``. Tag filtering there happens in Python AFTER
    ``limit`` is applied, so fewer than ``limit`` passages may be returned even when more matches exist.

    Args:
        tags: If given, only passages carrying these tags are returned.
        tag_match_mode: ``TagMatchMode.ALL`` requires every tag; any other value (including None)
            means ANY tag matches.

    Returns:
        Matching passages as Pydantic models (scores from vector search are discarded).

    Raises:
        ValueError: If vector search is requested for an agent with more than one archive.
    """
    if embed_query and agent_id and query_text and embedding_config:
        archive_ids = await self.get_agent_archive_ids_async(agent_id=agent_id, actor=actor)
        if not archive_ids:
            return []

        # TODO: Remove this restriction once we support multiple archives with mixed vector DB providers
        if len(archive_ids) > 1:
            raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported for vector search")

        archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_ids[0], actor=actor)

        if archive.vector_db_provider == VectorDBProvider.TPUF:
            # Kept as a local import, as in the original (presumably so the client loads only when needed).
            from letta.helpers.tpuf_client import TurbopufferClient

            # Embed the query text with the agent's embedding configuration.
            embedding_client = LLMClient.create(
                provider_type=embedding_config.embedding_endpoint_type,
                actor=actor,
            )
            embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
            query_embedding = embeddings[0]

            tpuf_client = TurbopufferClient()
            passages_with_scores = await tpuf_client.query_passages(
                archive_id=archive_ids[0],
                query_embedding=query_embedding,
                query_text=query_text,  # enables the full-text half of hybrid search
                search_mode="hybrid",
                top_k=limit,
                tags=tags,
                tag_match_mode=tag_match_mode or TagMatchMode.ANY,
                start_date=start_date,
                end_date=end_date,
            )
            return [passage for passage, _score in passages_with_scores]

    # SQL-based search: non-vector queries, or archives not backed by Turbopuffer.
    async with db_registry.async_session() as session:
        main_query = await build_agent_passage_query(
            actor=actor,
            agent_id=agent_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            before=before,
            after=after,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
        )
        if limit:
            main_query = main_query.limit(limit)

        result = await session.execute(main_query)
        pydantic_passages = [p.to_pydantic() for p in result.scalars().all()]

        # TODO: Integrate tag filtering into the SQL query (JOIN passage_tags + WHERE) so it is
        # applied before LIMIT; post-filtering is simpler but can under-fill the page.
        if not tags:
            return pydantic_passages

        query_tags = set(tags)  # built once, not per passage
        if tag_match_mode == TagMatchMode.ALL:
            # ALL mode: passage must carry every query tag
            return [p for p in pydantic_passages if p.tags and query_tags.issubset(p.tags)]
        # ANY mode (default): passage must carry at least one query tag
        return [p for p in pydantic_passages if p.tags and query_tags.intersection(p.tags)]
|
||
|
||
@enforce_types
@trace_method
async def search_agent_archival_memory_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    query: str,
    tags: Optional[List[str]] = None,
    tag_match_mode: Literal["any", "all"] = "any",
    top_k: Optional[int] = None,
    start_datetime: Optional[str] = None,
    end_datetime: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Search archival memory using semantic (embedding-based) search with optional temporal filtering.

    This is a shared method used by both the agent tool and API endpoint to ensure consistent behavior.

    Args:
        agent_id: ID of the agent whose archival memory to search
        actor: User performing the search
        query: String to search for using semantic similarity; empty/whitespace returns ([], 0)
        tags: Optional list of tags to filter search results
        tag_match_mode: How to match tags - "any" or "all"
        top_k: Maximum number of results to return (defaults to RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE)
        start_datetime: Filter results after this datetime (ISO 8601). A date-only value
            (YYYY-MM-DD) means the start of that day.
        end_datetime: Filter results before this datetime (ISO 8601). A date-only value
            (YYYY-MM-DD) means the end of that day (23:59:59.999999).
        Naive datetimes are interpreted in the agent's timezone when the agent has one.

    Returns:
        Tuple of (formatted_results, count); each result is a dict with "timestamp",
        "content" and "tags" keys.

    Raises:
        ValueError: If start_datetime or end_datetime cannot be parsed.
    """
    # Handle empty or whitespace-only queries
    if not query or not query.strip():
        return [], 0

    # The agent supplies both the timezone and the embedding config.
    agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)

    def _parse_bound(raw: str, param_name: str, end_of_day: bool) -> datetime:
        """Parse a datetime filter string, localizing naive values to the agent's timezone."""
        try:
            # Check for a plain date FIRST: datetime.fromisoformat("YYYY-MM-DD") also succeeds but
            # returns midnight, which would make a date-only end bound exclude that entire day.
            parsed = datetime.strptime(raw, "%Y-%m-%d")
        except ValueError:
            try:
                parsed = datetime.fromisoformat(raw)
            except ValueError:
                raise ValueError(
                    f"Invalid {param_name} format: {raw}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
                ) from None
        else:
            if end_of_day:
                parsed = parsed.replace(hour=23, minute=59, second=59, microsecond=999999)

        if parsed.tzinfo is None and agent_state.timezone:
            parsed = parsed.replace(tzinfo=ZoneInfo(agent_state.timezone))
        return parsed

    start_date = _parse_bound(start_datetime, "start_datetime", end_of_day=False) if start_datetime else None
    end_date = _parse_bound(end_datetime, "end_datetime", end_of_day=True) if end_datetime else None

    # Convert string to TagMatchMode enum
    tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL

    limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
    all_results = await self.query_agent_passages_async(
        actor=actor,
        agent_id=agent_id,
        query_text=query,
        limit=limit,
        embedding_config=agent_state.embedding_config,
        embed_query=True,
        tags=tags,
        tag_match_mode=tag_mode,
        start_date=start_date,
        end_date=end_date,
    )

    # Format results to include tags with friendly timestamps
    formatted_results = []
    for result in all_results:
        timestamp = result.created_at
        if timestamp and agent_state.timezone:
            try:
                # Render in the agent's timezone as an ISO string with offset
                formatted_timestamp = timestamp.astimezone(ZoneInfo(agent_state.timezone)).isoformat()
            except Exception:
                # Fall back to the plain string form if timezone conversion fails
                formatted_timestamp = str(timestamp)
        else:
            formatted_timestamp = str(timestamp) if timestamp else "Unknown"

        formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})

    return formatted_results, len(formatted_results)
|
||
|
||
@enforce_types
@trace_method
async def passage_size(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> int:
    """Count passages matching the given filters.

    Wraps the ``build_passage_query`` result in a ``SELECT count(*)`` subquery.

    Returns:
        The number of matching passages (0 if the count comes back NULL).
    """
    async with db_registry.async_session() as session:
        filtered = await build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            before=before,
            after=after,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        total = await session.scalar(select(func.count()).select_from(filtered.subquery()))
        return total or 0
|
||
|
||
@enforce_types
async def passage_size_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> int:
    """Count passages matching the given filters.

    Wraps the ``build_passage_query`` result in a ``SELECT count(*)`` subquery.

    Returns:
        The number of matching passages (0 if the count comes back NULL).
    """
    async with db_registry.async_session() as session:
        filtered = await build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            before=before,
            after=after,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        counted = await session.execute(select(func.count()).select_from(filtered.subquery()))
        return counted.scalar() or 0
|
||
|
||
# ======================================================================================================================
|
||
# Tool Management
|
||
# ======================================================================================================================
|
||
@enforce_types
@trace_method
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Attach a tool to an agent, keeping any tools it already has.

    Args:
        agent_id: ID of the agent to attach the tool to.
        tool_id: ID of the tool to attach.
        actor: User performing the action.

    Raises:
        NoResultFound: If the agent or tool is not found.

    Returns:
        PydanticAgentState: The updated agent state.
    """
    with db_registry.session() as session:
        # Reading through the ORM also enforces that the actor may access this agent.
        agent_row = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        _process_relationship(
            session=session,
            agent=agent_row,
            relationship_name="tools",
            model_class=ToolModel,
            item_ids=[tool_id],
            replace=False,  # extend rather than overwrite the existing tools
            allow_partial=False,  # the tool must exist
        )

        agent_row.update(session, actor=actor)
        return agent_row.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
    """
    Attach a tool to an agent; attaching an already-attached tool is a no-op (logged).

    If the tool has ``default_requires_approval`` set, a ``RequiresApprovalToolRule`` for it is
    added to the agent's tool rules unless a "requires_approval" rule for that tool already exists.

    Args:
        agent_id: ID of the agent to attach the tool to.
        tool_id: ID of the tool to attach.
        actor: User performing the action.

    Raises:
        NoResultFound: If the agent is not found, or the tool is not found in the actor's organization.

    Returns:
        None. Re-read the agent if its updated state is needed.
    """
    async with db_registry.async_session() as session:
        # Verify the agent exists and user has permission to access it
        await validate_agent_exists_async(session, agent_id, actor)

        # Verify the tool exists within the actor's organization.
        tool_check_query = select(ToolModel.name, ToolModel.default_requires_approval).where(
            ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id
        )
        tool_row = (await session.execute(tool_check_query)).first()
        if tool_row is None:
            raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}")
        tool_name, default_requires_approval = tool_row

        if settings.letta_pg_uri_no_default:
            # Postgres: ON CONFLICT DO NOTHING makes the attach atomic and idempotent.
            insert_stmt = (
                pg_insert(ToolsAgents)
                .values(agent_id=agent_id, tool_id=tool_id)
                .on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
            )
            result = await session.execute(insert_stmt)
            if result.rowcount == 0:
                logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
        else:
            # For sqlite/mysql, check then insert
            existing_query = (
                select(func.count()).select_from(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
            )
            existing_result = await session.execute(existing_query)
            if existing_result.scalar() == 0:
                await session.execute(insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id))
            else:
                logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")

        if default_requires_approval:
            agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
            # tool_rules may be None for agents without rules; treat that as an empty list
            # (iterating None directly would raise TypeError).
            current_rules = list(agent.tool_rules or [])
            already_required = any(rule.tool_name == tool_name and rule.type == "requires_approval" for rule in current_rules)
            if not already_required:
                # Assign a NEW list so SQLAlchemy detects the change on the JSON column;
                # in-place modification doesn't trigger change detection.
                current_rules.append(RequiresApprovalToolRule(tool_name=tool_name))
                agent.tool_rules = current_rules
                session.add(agent)

        await session.commit()
|
||
|
||
@enforce_types
@trace_method
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
    """
    Efficiently attaches multiple tools to an agent in a single operation.

    Tools that are already attached are skipped. Duplicate IDs in ``tool_ids`` are ignored.

    Args:
        agent_id: ID of the agent to attach the tools to.
        tool_ids: List of tool IDs to attach.
        actor: User performing the action.

    Raises:
        NoResultFound: If the agent is not found, or any tool is not found in the actor's organization.
    """
    if not tool_ids:
        # no tools to attach, nothing to do
        return

    # De-duplicate while preserving order. Without this, duplicates make the existence check below
    # fail with an empty "missing" set, and the non-Postgres path would insert duplicate rows.
    unique_tool_ids = list(dict.fromkeys(tool_ids))

    async with db_registry.async_session() as session:
        # Verify the agent exists and user has permission to access it
        await validate_agent_exists_async(session, agent_id, actor)

        # Verify all tools exist and belong to the organization in a single query
        tool_check_query = select(func.count(ToolModel.id)).where(
            ToolModel.id.in_(unique_tool_ids), ToolModel.organization_id == actor.organization_id
        )
        found_count = (await session.execute(tool_check_query)).scalar()

        if found_count != len(unique_tool_ids):
            # Work out which tools are missing for a better error message
            existing_query = select(ToolModel.id).where(
                ToolModel.id.in_(unique_tool_ids), ToolModel.organization_id == actor.organization_id
            )
            existing_result = await session.execute(existing_query)
            existing_ids = {row[0] for row in existing_result}
            missing_ids = set(unique_tool_ids) - existing_ids
            raise NoResultFound(f"Tools with ids={missing_ids} not found in organization={actor.organization_id}")

        if settings.letta_pg_uri_no_default:
            # Postgres: one bulk insert, silently skipping rows that already exist.
            values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in unique_tool_ids]
            insert_stmt = pg_insert(ToolsAgents).values(values).on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
            result = await session.execute(insert_stmt)
            logger.info(
                f"Attached {result.rowcount} new tools to agent {agent_id} (skipped {len(unique_tool_ids) - result.rowcount} already attached)"
            )
        else:
            # For sqlite/mysql, first check which tools are already attached
            existing_query = select(ToolsAgents.tool_id).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(unique_tool_ids))
            existing_result = await session.execute(existing_query)
            already_attached = {row[0] for row in existing_result}

            # Only insert tools that aren't already attached
            new_tool_ids = [tid for tid in unique_tool_ids if tid not in already_attached]
            if new_tool_ids:
                values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in new_tool_ids]
                await session.execute(insert(ToolsAgents).values(values))
                logger.info(
                    f"Attached {len(new_tool_ids)} new tools to agent {agent_id} (skipped {len(already_attached)} already attached)"
                )
            else:
                logger.info(f"All {len(unique_tool_ids)} tools already attached to agent {agent_id}")

        await session.commit()
|
||
|
||
@enforce_types
@trace_method
async def attach_missing_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
    """
    Make sure every core file tool listed in ``FILES_TOOLS`` is attached to the agent.

    A file tool counts as already present when a tool with that name and type
    ``ToolType.LETTA_FILES_CORE`` is in ``agent_state.tools``. The remaining names are
    looked up in the actor's organization (one query) and attached with
    ``bulk_attach_tools_async``.

    Args:
        agent_state: The current agent state, with its ``tools`` list already loaded.
        actor: User performing the action.

    Raises:
        NoResultFound: Propagated from ``bulk_attach_tools_async`` if the agent or a tool is not found.

    Returns:
        PydanticAgentState: A new agent state whose ``tools`` list is the original list plus
        the newly attached tools. The input object is never mutated. It is returned as-is
        when no file tool is missing, or when none of the missing tools exist in the organization.
    """
    present_names = {t.name for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE}
    wanted_names = set(FILES_TOOLS) - present_names
    if not wanted_names:
        # The agent already carries every file tool.
        return agent_state

    # Fetch all missing file tools for this organization in a single round-trip.
    async with db_registry.async_session() as session:
        lookup = select(ToolModel).where(
            ToolModel.name.in_(wanted_names),
            ToolModel.organization_id == actor.organization_id,
            ToolModel.tool_type == ToolType.LETTA_FILES_CORE,
        )
        matched_models = (await session.execute(lookup)).scalars().all()
        # Convert while the session is still open.
        new_tools = [model.to_pydantic() for model in matched_models]

    if not new_tools:
        logger.warning(f"No file tools found for organization {actor.organization_id}. Expected tools: {wanted_names}")
        return agent_state

    unresolved = wanted_names - {t.name for t in new_tools}
    if unresolved:
        logger.warning(f"File tools {unresolved} not found in organization {actor.organization_id}")

    await self.bulk_attach_tools_async(agent_id=agent_state.id, tool_ids=[t.id for t in new_tools], actor=actor)

    # Build a fresh state object instead of mutating the caller's instance.
    refreshed_fields = agent_state.model_dump()
    refreshed_fields["tools"] = agent_state.tools + new_tools
    return PydanticAgentState(**refreshed_fields)
|
||
|
||
@enforce_types
@trace_method
async def detach_all_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
    """
    Remove every core file tool (type ``ToolType.LETTA_FILES_CORE``) from an agent.

    The file tools are identified from ``agent_state.tools`` (no extra lookup) and
    detached with a single ``bulk_detach_tools_async`` call.

    Args:
        agent_state: The current agent state, with its ``tools`` list already loaded.
        actor: User performing the action.

    Raises:
        NoResultFound: Propagated from ``bulk_detach_tools_async`` if the agent is not found.

    Returns:
        PydanticAgentState: A new agent state containing only the non-file tools. The input
        object is returned unchanged when it has no file tools.
    """
    file_tools = []
    kept_tools = []
    for tool in agent_state.tools:
        if tool.tool_type == ToolType.LETTA_FILES_CORE:
            file_tools.append(tool)
        else:
            kept_tools.append(tool)

    if not file_tools:
        # Nothing to detach.
        return agent_state

    await self.bulk_detach_tools_async(agent_id=agent_state.id, tool_ids=[t.id for t in file_tools], actor=actor)

    # Build a fresh state object instead of mutating the caller's instance.
    refreshed_fields = agent_state.model_dump()
    refreshed_fields["tools"] = kept_tools
    return PydanticAgentState(**refreshed_fields)
|
||
|
||
@enforce_types
@trace_method
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
    """
    Detach a single tool from an agent (synchronous ORM version).

    The agent is loaded through ``AgentModel.read`` with the actor, ``tool_id`` is
    dropped from its ``tools`` relationship, and the agent is saved via
    ``agent.update``. If the tool was not attached, a warning is logged and the
    agent is still saved.

    Args:
        agent_id: ID of the agent to detach the tool from.
        tool_id: ID of the tool to detach.
        actor: User performing the action.

    Raises:
        NoResultFound: If the agent cannot be read for this actor (from ``AgentModel.read``).
            An unattached ``tool_id`` does not raise; it only logs a warning.

    Returns:
        PydanticAgentState: The updated agent state.
    """
    with db_registry.session() as session:
        agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        was_attached = any(attached.id == tool_id for attached in agent.tools)
        if not was_attached:
            logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")

        # Reassign the relationship with everything except the target tool, then persist.
        agent.tools = [attached for attached in agent.tools if attached.id != tool_id]
        agent.update(session, actor=actor)
        return agent.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
    """
    Detach a tool from an agent by deleting its ``ToolsAgents`` association row.

    Deleting a non-existent association is not an error: it affects zero rows and
    only logs a warning.

    Args:
        agent_id: ID of the agent to detach the tool from.
        tool_id: ID of the tool to detach.
        actor: User performing the action.

    Raises:
        NoResultFound: If the agent is not found (raised by ``validate_agent_exists_async``).
    """
    async with db_registry.async_session() as session:
        # Access check before touching the association table.
        await validate_agent_exists_async(session, agent_id, actor)

        outcome = await session.execute(
            delete(ToolsAgents).where(
                ToolsAgents.agent_id == agent_id,
                ToolsAgents.tool_id == tool_id,
            )
        )

        if outcome.rowcount != 0:
            logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
        else:
            logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")

        await session.commit()
|
||
|
||
@enforce_types
@trace_method
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
    """
    Efficiently detach multiple tools from an agent with a single DELETE.

    IDs that are not currently attached are ignored; the log output reports how many
    were actually detached. Duplicate IDs in ``tool_ids`` are collapsed first, so the
    "not attached" count is accurate.

    Args:
        agent_id: ID of the agent to detach the tools from.
        tool_ids: List of tool IDs to detach. An empty list is a no-op.
        actor: User performing the action.

    Raises:
        NoResultFound: If the agent is not found (raised by ``validate_agent_exists_async``).
    """
    # dict.fromkeys de-duplicates while preserving the caller's order. Without this,
    # duplicates inflate len(tool_ids) and the "were not attached" count is wrong.
    unique_tool_ids = list(dict.fromkeys(tool_ids))
    if not unique_tool_ids:
        # No tools to detach, nothing to do.
        return

    async with db_registry.async_session() as session:
        # Verify the agent exists and the user has permission to access it.
        await validate_agent_exists_async(session, agent_id, actor)

        # Delete all associations in a single query.
        delete_query = delete(ToolsAgents).where(
            ToolsAgents.agent_id == agent_id,
            ToolsAgents.tool_id.in_(unique_tool_ids),
        )
        result = await session.execute(delete_query)

        detached_count = result.rowcount
        requested_count = len(unique_tool_ids)
        if detached_count == 0:
            logger.warning(f"No tools from list {tool_ids} were attached to agent id={agent_id}")
        elif detached_count < requested_count:
            logger.info(f"Detached {detached_count} tools from agent {agent_id} ({requested_count - detached_count} were not attached)")
        else:
            logger.info(f"Detached all {detached_count} tools from agent {agent_id}")

        await session.commit()
|
||
|
||
@enforce_types
@trace_method
async def modify_approvals_async(self, agent_id: str, tool_name: str, requires_approval: bool, actor: PydanticUser) -> None:
    """
    Add or remove the ``requires_approval`` tool rule for ``tool_name`` on an agent.

    - ``requires_approval=True`` with no matching rule: a ``RequiresApprovalToolRule``
      is appended.
    - ``requires_approval=False`` with one or more matching rules: every matching
      rule is removed.
    - Otherwise the agent is already in the requested state; nothing is written.

    A NULL ``tool_rules`` value on the agent is treated as "no rules".

    Args:
        agent_id: ID of the agent to modify.
        tool_name: Name of the tool whose approval requirement is toggled.
        requires_approval: Desired approval requirement for the tool.
        actor: User performing the action.
    """

    def is_target_rule(rule):
        return rule.tool_name == tool_name and rule.type == "requires_approval"

    async with db_registry.async_session() as session:
        agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

        # Previously this iterated agent.tool_rules directly, which raised TypeError when
        # the column was NULL. Copy into a fresh list so both branches assign a new object.
        current_rules = list(agent.tool_rules) if agent.tool_rules else []
        has_target_rule = any(is_target_rule(rule) for rule in current_rules)

        if has_target_rule and not requires_approval:
            # Remove *all* matching rules. The old "exactly one" check left duplicates in place forever.
            tool_rules = [rule for rule in current_rules if not is_target_rule(rule)]
        elif not has_target_rule and requires_approval:
            # Assign a new list: mutating a JSON column in place is not detected by SQLAlchemy.
            tool_rules = current_rules + [RequiresApprovalToolRule(tool_name=tool_name)]
        else:
            # Already in the requested state.
            return

        agent.tool_rules = tool_rules
        session.add(agent)
        await session.commit()
|
||
|
||
@enforce_types
@trace_method
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
    """
    Return the tools attached to an agent (synchronous ORM version).

    The agent is read through ``AgentModel.read`` with the actor. Its ``tools``
    relationship is converted to pydantic while the session is still open.

    Args:
        agent_id: ID of the agent whose tools are listed.
        actor: User performing the action.

    Returns:
        List[PydanticTool]: The agent's attached tools.
    """
    with db_registry.session() as session:
        owner = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
        pydantic_tools = [orm_tool.to_pydantic() for orm_tool in owner.tools]
        return pydantic_tools
|
||
|
||
@enforce_types
@trace_method
async def list_attached_tools_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
    """
    Return the tools attached to an agent (async version).

    It does not load the full agent. Instead it runs a lightweight existence/access
    check and then a single join of ``ToolModel`` against ``ToolsAgents``, restricted
    to tools in the actor's organization.

    Args:
        agent_id: ID of the agent whose tools are listed.
        actor: User performing the action.

    Returns:
        List[PydanticTool]: The agent's attached tools.
    """
    async with db_registry.async_session() as session:
        await validate_agent_exists_async(session, agent_id, actor)

        attached_tools_stmt = (
            select(ToolModel)
            .join(ToolsAgents, ToolModel.id == ToolsAgents.tool_id)
            .where(ToolsAgents.agent_id == agent_id)
            .where(ToolModel.organization_id == actor.organization_id)
        )
        orm_tools = (await session.execute(attached_tools_stmt)).scalars().all()
        return [orm_tool.to_pydantic() for orm_tool in orm_tools]
|
||
|
||
# ======================================================================================================================
|
||
# File Management
|
||
# ======================================================================================================================
|
||
async def insert_file_into_context_windows(
    self,
    source_id: str,
    file_metadata_with_content: PydanticFileMetadata,
    actor: PydanticUser,
    agent_states: Optional[List[PydanticAgentState]] = None,
) -> List[PydanticAgentState]:
    """
    Open an uploaded file in the context window of agents attached to a source.

    The file content is split into lines with ``LineChunker``. The file is then attached
    to every agent concurrently via ``file_agent_manager.attach_files_bulk``, which is
    bounded by each agent's ``max_files_open``. Files closed by LRU eviction are logged.

    Args:
        source_id: Source the file belongs to. Used to look up the attached agents when
            ``agent_states`` is ``None``.
        file_metadata_with_content: Metadata of the file, including its content.
        actor: User performing the action.
        agent_states: Agents whose context windows receive the file. ``None`` (the
            default) means "all agents attached to ``source_id``". An explicitly empty
            list means "no agents" and does no work.

    Returns:
        List[PydanticAgentState]: The agent states that were processed (empty if none).
    """
    # Use an explicit None check: with `or`, a caller-supplied [] triggered a lookup of all attached agents.
    if agent_states is None:
        agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)

    # Return early
    if not agent_states:
        return []

    logger.info(f"Inserting document into context window for source: {source_id}")
    logger.info(f"Attached agents: {[a.id for a in agent_states]}")

    # Generate visible content for the file
    line_chunker = LineChunker()
    content_lines = line_chunker.chunk_text(file_metadata=file_metadata_with_content)
    visible_content = "\n".join(content_lines)
    visible_content_map = {file_metadata_with_content.file_name: visible_content}

    # Attach file to each agent using bulk method (one file per agent, but atomic per agent)
    all_closed_files = await asyncio.gather(
        *(
            self.file_agent_manager.attach_files_bulk(
                agent_id=agent_state.id,
                files_metadata=[file_metadata_with_content],
                visible_content_map=visible_content_map,
                actor=actor,
                max_files_open=agent_state.max_files_open,
            )
            for agent_state in agent_states
        )
    )

    # Flatten and log if any files were closed
    closed_files = [file for closed_list in all_closed_files for file in closed_list]
    if closed_files:
        logger.info(f"LRU eviction closed {len(closed_files)} files during bulk attach: {closed_files}")

    return agent_states
|
||
|
||
async def insert_files_into_context_window(
    self, agent_state: PydanticAgentState, file_metadata_with_content: List[PydanticFileMetadata], actor: PydanticUser
) -> None:
    """
    Open a batch of uploaded files in a single agent's context window.

    Each file's content is split into lines with ``LineChunker``. The joined lines
    become that file's visible content, keyed by ``file_name``; if two files share a
    name, the later one's content wins. All files go through one
    ``file_agent_manager.attach_files_bulk`` call, bounded by
    ``agent_state.max_files_open``. This avoids separate, racing LRU eviction
    decisions per file. Files closed by eviction are logged.

    Args:
        agent_state: Agent whose context window receives the files.
        file_metadata_with_content: Metadata of the files to open, including their content.
        actor: User performing the action.
    """
    logger.info(f"Inserting {len(file_metadata_with_content)} documents into context window for agent_state: {agent_state.id}")

    chunker = LineChunker()
    visible_content_map = {
        metadata.file_name: "\n".join(chunker.chunk_text(file_metadata=metadata)) for metadata in file_metadata_with_content
    }

    evicted = await self.file_agent_manager.attach_files_bulk(
        agent_id=agent_state.id,
        files_metadata=file_metadata_with_content,
        visible_content_map=visible_content_map,
        actor=actor,
        max_files_open=agent_state.max_files_open,
    )

    if evicted:
        logger.info(f"LRU eviction closed {len(evicted)} files during bulk insert: {evicted}")
|
||
|
||
# ======================================================================================================================
|
||
# Tag Management
|
||
# ======================================================================================================================
|
||
@enforce_types
@trace_method
def list_tags(
    self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
) -> List[str]:
    """
    Get the distinct tags used by agents in the actor's organization, ordered alphabetically.

    Args:
        actor: User performing the action. Only tags on agents in ``actor.organization_id`` are returned.
        after: Cursor for forward pagination. Only tags sorting strictly after this value are returned.
        limit: Maximum number of tags to return.
        query_text: Case-insensitive substring to filter tags by. It is matched literally:
            ``%``, ``_`` and backslash are escaped instead of acting as LIKE wildcards.

    Returns:
        List[str]: The matching tags.
    """
    with db_registry.session() as session:
        query = (
            session.query(AgentsTags.tag)
            .join(AgentModel, AgentModel.id == AgentsTags.agent_id)
            .filter(AgentModel.organization_id == actor.organization_id)
            .distinct()
        )

        if query_text:
            # Escape LIKE metacharacters so that e.g. "my_tag" does not also match "myXtag".
            escaped_text = query_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            pattern = f"%{escaped_text}%"
            if settings.database_engine is DatabaseChoice.POSTGRES:
                # PostgreSQL: Use ILIKE for case-insensitive search
                query = query.filter(AgentsTags.tag.ilike(pattern, escape="\\"))
            else:
                # SQLite: Use LIKE with LOWER for case-insensitive search
                query = query.filter(func.lower(AgentsTags.tag).like(func.lower(pattern), escape="\\"))

        if after:
            query = query.filter(AgentsTags.tag > after)

        query = query.order_by(AgentsTags.tag).limit(limit)
        results = [tag[0] for tag in query.all()]
        return results
|
||
|
||
@enforce_types
@trace_method
async def list_tags_async(
    self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
) -> List[str]:
    """
    Get the distinct tags used by agents in the actor's organization, ordered alphabetically.

    Args:
        actor: User performing the action. Only tags on agents in ``actor.organization_id`` are returned.
        after: Cursor for forward pagination. Only tags sorting strictly after this value are returned.
        limit: Maximum number of tags to return.
        query_text: Case-insensitive substring to filter tags by. It is matched literally:
            ``%``, ``_`` and backslash are escaped instead of acting as LIKE wildcards.

    Returns:
        List[str]: The matching tags.
    """
    async with db_registry.async_session() as session:
        # Build the query using select() for async SQLAlchemy
        query = (
            select(AgentsTags.tag)
            .join(AgentModel, AgentModel.id == AgentsTags.agent_id)
            .where(AgentModel.organization_id == actor.organization_id)
            .distinct()
        )

        if query_text:
            # Escape LIKE metacharacters so that e.g. "my_tag" does not also match "myXtag".
            escaped_text = query_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            pattern = f"%{escaped_text}%"
            if settings.database_engine is DatabaseChoice.POSTGRES:
                # PostgreSQL: Use ILIKE for case-insensitive search
                query = query.where(AgentsTags.tag.ilike(pattern, escape="\\"))
            else:
                # SQLite: Use LIKE with LOWER for case-insensitive search
                query = query.where(func.lower(AgentsTags.tag).like(func.lower(pattern), escape="\\"))

        if after:
            query = query.where(AgentsTags.tag > after)

        query = query.order_by(AgentsTags.tag).limit(limit)

        # Execute the query asynchronously and extract the tag values
        result = await session.execute(query)
        results = [row[0] for row in result.all()]
        return results
|
||
|
||
@enforce_types
@trace_method
async def get_agent_files_config_async(self, agent_id: str, actor: PydanticUser) -> Tuple[int, int]:
    """Get per_file_view_window_char_limit and max_files_open for an agent.

    Only the two needed columns are selected. If either stored value is NULL, the
    agent's ``llm_config`` is fetched in a second query and
    ``calculate_file_defaults_based_on_context_window`` supplies the missing value(s)
    from its context window. ``DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT`` /
    ``DEFAULT_MAX_FILES_OPEN`` are the final fallback if a value is still None.

    Args:
        agent_id: The ID of the agent
        actor: The user making the request

    Raises:
        ValueError: If no non-deleted agent with this ID exists in the actor's organization.

    Returns:
        Tuple of per_file_view_window_char_limit, max_files_open values
    """
    async with db_registry.async_session() as session:
        # Same scoping for both queries: this agent, this org, not soft-deleted.
        agent_filter = (
            AgentModel.id == agent_id,
            AgentModel.organization_id == actor.organization_id,
            AgentModel.is_deleted == False,
        )

        limits_result = await session.execute(
            select(AgentModel.per_file_view_window_char_limit, AgentModel.max_files_open).where(*agent_filter)
        )
        limits = limits_result.one_or_none()
        if limits is None:
            raise ValueError(f"Agent {agent_id} not found")

        per_file_limit, max_files = limits[0], limits[1]

        if per_file_limit is None or max_files is None:
            # Derive defaults from the model's context window (None if it has no llm_config).
            config_result = await session.execute(select(AgentModel.llm_config).where(*agent_filter))
            config_row = config_result.one_or_none()
            context_window = config_row[0].context_window if config_row and config_row[0] else None

            default_max_files, default_per_file_limit = calculate_file_defaults_based_on_context_window(context_window)
            per_file_limit = default_per_file_limit if per_file_limit is None else per_file_limit
            max_files = default_max_files if max_files is None else max_files

        # Last-resort constants in case the computed defaults were themselves None.
        per_file_limit = DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT if per_file_limit is None else per_file_limit
        max_files = DEFAULT_MAX_FILES_OPEN if max_files is None else max_files

        return per_file_limit, max_files
|
||
|
||
@enforce_types
@trace_method
async def get_agent_max_files_open_async(self, agent_id: str, actor: PydanticUser) -> int:
    """Get max_files_open for an agent.

    The common case is a single query that fetches only this column. If the stored
    value is NULL, the value returned is the context-window-based default computed by
    ``get_agent_files_config_async``, so both methods agree.

    Args:
        agent_id: The ID of the agent
        actor: The user making the request

    Raises:
        ValueError: If no non-deleted agent with this ID exists in the actor's organization.

    Returns:
        max_files_open value
    """
    async with db_registry.async_session() as session:
        result = await session.execute(
            select(AgentModel.max_files_open)
            .where(AgentModel.id == agent_id)
            .where(AgentModel.organization_id == actor.organization_id)
            .where(AgentModel.is_deleted == False)
        )
        # Use one_or_none(), not scalar_one_or_none(): the latter returns None both for a
        # missing agent and for a NULL column, which misreported existing agents as "not found".
        row = result.one_or_none()

    if row is None:
        raise ValueError(f"Agent {agent_id} not found")

    if row[0] is not None:
        return row[0]

    # NULL column: fall back to the same defaults as the combined config lookup.
    _, max_files = await self.get_agent_files_config_async(agent_id=agent_id, actor=actor)
    return max_files
|
||
|
||
@enforce_types
@trace_method
async def get_agent_per_file_view_window_char_limit_async(self, agent_id: str, actor: PydanticUser) -> int:
    """Get per_file_view_window_char_limit for an agent.

    The common case is a single query that fetches only this column. If the stored
    value is NULL, the value returned is the context-window-based default computed by
    ``get_agent_files_config_async``, so both methods agree.

    Args:
        agent_id: The ID of the agent
        actor: The user making the request

    Raises:
        ValueError: If no non-deleted agent with this ID exists in the actor's organization.

    Returns:
        per_file_view_window_char_limit value
    """
    async with db_registry.async_session() as session:
        result = await session.execute(
            select(AgentModel.per_file_view_window_char_limit)
            .where(AgentModel.id == agent_id)
            .where(AgentModel.organization_id == actor.organization_id)
            .where(AgentModel.is_deleted == False)
        )
        # Use one_or_none(), not scalar_one_or_none(): the latter returns None both for a
        # missing agent and for a NULL column, which misreported existing agents as "not found".
        row = result.one_or_none()

    if row is None:
        raise ValueError(f"Agent {agent_id} not found")

    if row[0] is not None:
        return row[0]

    # NULL column: fall back to the same defaults as the combined config lookup.
    per_file_limit, _ = await self.get_agent_files_config_async(agent_id=agent_id, actor=actor)
    return per_file_limit
|
||
|
||
@trace_method
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
    """
    Compute a context-window overview for an agent.

    The system prompt is rebuilt in dry-run mode (``force=True, dry_run=True``) to
    obtain the compiled system message and the message / archival-memory counts.
    Token counting uses the Anthropic counter when running with
    ``settings.environment == "PRODUCTION"`` or when the agent's endpoint type is
    ``"anthropic"``. The model name is passed only for anthropic endpoints; otherwise
    ``None`` is passed. In all other cases a tiktoken counter for the agent's model is used.

    Args:
        agent_id: ID of the agent to inspect.
        actor: User performing the action.

    Returns:
        ContextWindowOverview: Result of ``ContextWindowCalculator.calculate_context_window``.
    """
    agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
        agent_id=agent_id, actor=actor, force=True, dry_run=True
    )
    calculator = ContextWindowCalculator()

    llm_config = agent_state.llm_config
    uses_anthropic_endpoint = llm_config.model_endpoint_type == "anthropic"

    if settings.environment == "PRODUCTION" or uses_anthropic_endpoint:
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
        counter_model = llm_config.model if uses_anthropic_endpoint else None
        token_counter = AnthropicTokenCounter(anthropic_client, counter_model)  # noqa
    else:
        token_counter = TiktokenCounter(llm_config.model)

    return await calculator.calculate_context_window(
        agent_state=agent_state,
        actor=actor,
        token_counter=token_counter,
        message_manager=self.message_manager,
        system_message_compiled=system_message,
        num_archival_memories=num_archival_memories,
        num_messages=num_messages,
    )
|
||
|
||
async def get_or_set_vector_db_namespace_async(
    self,
    agent_id: str,
    organization_id: str,
) -> str:
    """Return the agent's vector database namespace, assigning one if it has none yet.

    If the agent's ``_vector_db_namespace`` is already set (non-empty), it is returned
    unchanged. Otherwise an org-scoped name is generated: ``messages_<organization_id>``,
    with ``_<environment>`` appended (lowercased) when ``settings.environment`` is set.
    The name is written to the agent row, committed, and returned.

    NOTE: the UPDATE does not check how many rows it affected, so a name is returned
    even if no agent with ``agent_id`` exists.

    Args:
        agent_id: Agent ID to check/store namespace
        organization_id: Organization ID for namespace generation

    Returns:
        The org-scoped namespace name
    """
    from sqlalchemy import update

    from letta.settings import settings

    async with db_registry.async_session() as session:
        lookup = await session.execute(select(AgentModel._vector_db_namespace).where(AgentModel.id == agent_id))
        stored = lookup.fetchone()
        if stored and stored[0]:
            return stored[0]

        # TODO: In the future, we might use agent_id for sharding the namespace
        # For now, all messages in an org share the same namespace
        env_name = settings.environment
        namespace_name = f"messages_{organization_id}_{env_name.lower()}" if env_name else f"messages_{organization_id}"

        # Record the namespace on the agent (keeps agent-level tracking for future sharding).
        assign_namespace = update(AgentModel).where(AgentModel.id == agent_id).values(_vector_db_namespace=namespace_name)
        await session.execute(assign_namespace)
        await session.commit()

        return namespace_name
|