Files
letta-server/letta/services/agent_manager.py
2024-12-26 19:27:09 -08:00

876 lines
38 KiB
Python

from datetime import datetime
from typing import Dict, List, Optional
import numpy as np
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_relationship,
_process_tags,
check_supports_structured_output,
compile_system_message,
derive_system_message,
initialize_message_sequence,
package_initial_message_sequence,
)
from letta.services.message_manager import MessageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import settings
from letta.utils import enforce_types, get_utc_time, united_diff
from sqlalchemy import Select, func, literal, select, union_all
logger = get_logger(__name__)
# Agent Manager Class
class AgentManager:
"""Manager class to handle business logic related to Agents."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
self.block_manager = BlockManager()
self.tool_manager = ToolManager()
self.source_manager = SourceManager()
self.message_manager = MessageManager()
# ======================================================================================================================
# Basic CRUD operations
# ======================================================================================================================
@enforce_types
def create_agent(
self,
agent_create: CreateAgent,
actor: PydanticUser,
) -> PydanticAgentState:
system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")
# Check tool rules are valid
if agent_create.tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
# create blocks (note: cannot be linked into the agent_id is created)
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
for create_block in agent_create.memory_blocks:
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor)
block_ids.append(block.id)
# TODO: Remove this block once we deprecate the legacy `tools` field
# create passed in `tools`
tool_names = []
if agent_create.include_base_tools:
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
if agent_create.tools:
tool_names.extend(agent_create.tools)
# Remove duplicates
tool_names = list(set(tool_names))
tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
if tool:
tool_ids.append(tool.id)
# Remove duplicates
tool_ids = list(set(tool_ids))
# Create the agent
agent_state = self._create_agent(
name=agent_create.name,
system=system,
agent_type=agent_create.agent_type,
llm_config=agent_create.llm_config,
embedding_config=agent_create.embedding_config,
block_ids=block_ids,
tool_ids=tool_ids,
source_ids=agent_create.source_ids or [],
tags=agent_create.tags or [],
description=agent_create.description,
metadata_=agent_create.metadata_,
tool_rules=agent_create.tool_rules,
actor=actor,
)
# TODO: See if we can merge this into the above SQL create call for performance reasons
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
)
if agent_create.initial_message_sequence is not None:
# We always need the system prompt up front
system_message_obj = PydanticMessage.dict_to_message(
agent_id=agent_state.id,
user_id=agent_state.created_by_id,
model=agent_state.llm_config.model,
openai_message_dict=init_messages[0],
)
# Don't use anything else in the pregen sequence, instead use the provided sequence
init_messages = [system_message_obj]
init_messages.extend(
package_initial_message_sequence(agent_state.id, agent_create.initial_message_sequence, agent_state.llm_config.model, actor)
)
else:
init_messages = [
PydanticMessage.dict_to_message(
agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg
)
for msg in init_messages
]
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
@enforce_types
def _create_agent(
self,
actor: PydanticUser,
name: str,
system: str,
agent_type: AgentType,
llm_config: LLMConfig,
embedding_config: EmbeddingConfig,
block_ids: List[str],
tool_ids: List[str],
source_ids: List[str],
tags: List[str],
description: Optional[str] = None,
metadata_: Optional[Dict] = None,
tool_rules: Optional[List[PydanticToolRule]] = None,
) -> PydanticAgentState:
"""Create a new agent."""
with self.session_maker() as session:
# Prepare the agent data
data = {
"name": name,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config,
"embedding_config": embedding_config,
"organization_id": actor.organization_id,
"description": description,
"metadata_": metadata_,
"tool_rules": tool_rules,
}
# Create the new agent using SqlalchemyBase.create
new_agent = AgentModel(**data)
_process_relationship(session, new_agent, "tools", ToolModel, tool_ids, replace=True)
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
_process_tags(new_agent, tags, replace=True)
new_agent.create(session, actor=actor)
# Convert to PydanticAgentState and return
return new_agent.to_pydantic()
@enforce_types
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
# Rebuild the system prompt if it's different
if agent_update.system and agent_update.system != agent_state.system:
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
return agent_state
@enforce_types
def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
"""
Update an existing agent.
Args:
agent_id: The ID of the agent to update.
agent_update: UpdateAgent object containing the updated fields.
actor: User performing the action.
Returns:
PydanticAgentState: The updated agent as a Pydantic model.
"""
with self.session_maker() as session:
# Retrieve the existing agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Update scalar fields directly
scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"}
for field in scalar_fields:
value = getattr(agent_update, field, None)
if value is not None:
setattr(agent, field, value)
# Update relationships using _process_relationship and _process_tags
if agent_update.tool_ids is not None:
_process_relationship(session, agent, "tools", ToolModel, agent_update.tool_ids, replace=True)
if agent_update.source_ids is not None:
_process_relationship(session, agent, "sources", SourceModel, agent_update.source_ids, replace=True)
if agent_update.block_ids is not None:
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
if agent_update.tags is not None:
_process_tags(agent, agent_update.tags, replace=True)
# Commit and refresh the agent
agent.update(session, actor=actor)
# Convert to PydanticAgentState and return
return agent.to_pydantic()
@enforce_types
def list_agents(
self,
actor: PydanticUser,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
cursor: Optional[str] = None,
limit: Optional[int] = 50,
**kwargs,
) -> List[PydanticAgentState]:
"""
List agents that have the specified tags.
"""
with self.session_maker() as session:
agents = AgentModel.list(
db_session=session,
tags=tags,
match_all_tags=match_all_tags,
cursor=cursor,
limit=limit,
organization_id=actor.organization_id if actor else None,
**kwargs,
)
return [agent.to_pydantic() for agent in agents]
@enforce_types
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return agent.to_pydantic()
@enforce_types
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
return agent.to_pydantic()
@enforce_types
def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Deletes an agent and its associated relationships.
Ensures proper permission checks and cascades where applicable.
Args:
agent_id: ID of the agent to be deleted.
actor: User performing the action.
Returns:
PydanticAgentState: The deleted agent state
"""
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
agent_state = agent.to_pydantic()
agent.hard_delete(session)
return agent_state
# ======================================================================================================================
# In Context Messages Management
# ======================================================================================================================
# TODO: There are several assumptions here that are not explicitly checked
# TODO: 1) These message ids are valid
# TODO: 2) These messages are ordered from oldest to newest
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
@enforce_types
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
@enforce_types
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
@enforce_types
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
"""Rebuilds the system message with the latest memory object and any shared memory block updates
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
"""
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
curr_system_message = self.get_system_message(
agent_id=agent_id, actor=actor
) # this is the system + memory bank, not just the system prompt
curr_system_message_openai = curr_system_message.to_openai_dict()
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be someout out of date
curr_memory_str = agent_state.memory.compile()
if curr_memory_str in curr_system_message_openai["content"] and not force:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.info(
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
)
return agent_state
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
if update_timestamp:
memory_edit_timestamp = get_utc_time()
else:
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
memory_edit_timestamp = curr_system_message.created_at
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
)
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
if len(diff) > 0: # there was a diff
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Swap the system message out (only if there is a diff)
message = PydanticMessage.dict_to_message(
agent_id=agent_id,
user_id=actor.id,
model=agent_state.llm_config.model,
openai_message_dict={"role": "system", "content": new_system_message_str},
)
message = self.message_manager.create_message(message, actor=actor)
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
else:
return agent_state
@enforce_types
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
@enforce_types
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
@enforce_types
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
messages = self.message_manager.create_many_messages(messages, actor=actor)
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
message_ids += [m.id for m in messages]
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
# ======================================================================================================================
# Source Management
# ======================================================================================================================
@enforce_types
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
"""
Attaches a source to an agent.
Args:
agent_id: ID of the agent to attach the source to
source_id: ID of the source to attach
actor: User performing the action
Raises:
ValueError: If either agent or source doesn't exist
IntegrityError: If the source is already attached to the agent
"""
with self.session_maker() as session:
# Verify both agent and source exist and user has permission to access them
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# The _process_relationship helper already handles duplicate checking via unique constraint
_process_relationship(
session=session,
agent=agent,
relationship_name="sources",
model_class=SourceModel,
item_ids=[source_id],
allow_partial=False,
replace=False, # Extend existing sources rather than replace
)
# Commit the changes
agent.update(session, actor=actor)
@enforce_types
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
"""
Lists all sources attached to an agent.
Args:
agent_id: ID of the agent to list sources for
actor: User performing the action
Returns:
List[str]: List of source IDs attached to the agent
"""
with self.session_maker() as session:
# Verify agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Use the lazy-loaded relationship to get sources
return [source.to_pydantic() for source in agent.sources]
@enforce_types
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
"""
Detaches a source from an agent.
Args:
agent_id: ID of the agent to detach the source from
source_id: ID of the source to detach
actor: User performing the action
"""
with self.session_maker() as session:
# Verify agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Remove the source from the relationship
agent.sources = [s for s in agent.sources if s.id != source_id]
# Commit the changes
agent.update(session, actor=actor)
# ======================================================================================================================
# Block management
# ======================================================================================================================
@enforce_types
def get_block_with_label(
self,
agent_id: str,
block_label: str,
actor: PydanticUser,
) -> PydanticBlock:
"""Gets a block attached to an agent by its label."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
for block in agent.core_memory:
if block.label == block_label:
return block.to_pydantic()
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
@enforce_types
def update_block_with_label(
self,
agent_id: str,
block_label: str,
new_block_id: str,
actor: PydanticUser,
) -> PydanticAgentState:
"""Updates which block is assigned to a specific label for an agent."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)
if new_block.label != block_label:
raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'")
# Remove old block with this label if it exists
agent.core_memory = [b for b in agent.core_memory if b.label != block_label]
# Add new block
agent.core_memory.append(new_block)
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Attaches a block to an agent."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
agent.core_memory.append(block)
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def detach_block(
self,
agent_id: str,
block_id: str,
actor: PydanticUser,
) -> PydanticAgentState:
"""Detaches a block from an agent."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
original_length = len(agent.core_memory)
agent.core_memory = [b for b in agent.core_memory if b.id != block_id]
if len(agent.core_memory) == original_length:
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def detach_block_with_label(
self,
agent_id: str,
block_label: str,
actor: PydanticUser,
) -> PydanticAgentState:
"""Detaches a block with the specified label from an agent."""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
original_length = len(agent.core_memory)
agent.core_memory = [b for b in agent.core_memory if b.label != block_label]
if len(agent.core_memory) == original_length:
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'")
agent.update(session, actor=actor)
return agent.to_pydantic()
# ======================================================================================================================
# Passage Management
# ======================================================================================================================
def _build_passage_query(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
cursor: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> Select:
"""Helper function to build the base passage query with all filters applied.
Returns the query before any limit or count operations are applied.
"""
embedded_text = None
if embed_query:
assert embedding_config is not None, "embedding_config must be specified for vector search"
assert query_text is not None, "query_text must be specified for vector search"
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
embedded_text = np.array(embedded_text)
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
with self.session_maker() as session:
# Start with base query for source passages
source_passages = None
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(SourcePassage, literal(None).label("agent_id"))
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
SourcePassage.organization_id == actor.organization_id
)
if source_id:
source_passages = source_passages.where(SourcePassage.source_id == source_id)
if file_id:
source_passages = source_passages.where(SourcePassage.file_id == file_id)
# Add agent passages query
agent_passages = None
if agent_id is not None:
agent_passages = (
select(
AgentPassage.id,
AgentPassage.text,
AgentPassage.embedding_config,
AgentPassage.metadata_,
AgentPassage.embedding,
AgentPassage.created_at,
AgentPassage.updated_at,
AgentPassage.is_deleted,
AgentPassage._created_by_id,
AgentPassage._last_updated_by_id,
AgentPassage.organization_id,
literal(None).label("file_id"),
literal(None).label("source_id"),
AgentPassage.agent_id,
)
.where(AgentPassage.agent_id == agent_id)
.where(AgentPassage.organization_id == actor.organization_id)
)
# Combine queries
if source_passages is not None and agent_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
elif agent_passages is not None:
combined_query = agent_passages.cte("combined_passages")
elif source_passages is not None:
combined_query = source_passages.cte("combined_passages")
else:
raise ValueError("No passages found")
# Build main query from combined CTE
main_query = select(combined_query)
# Apply filters
if start_date:
main_query = main_query.where(combined_query.c.created_at >= start_date)
if end_date:
main_query = main_query.where(combined_query.c.created_at <= end_date)
if source_id:
main_query = main_query.where(combined_query.c.source_id == source_id)
if file_id:
main_query = main_query.where(combined_query.c.file_id == file_id)
# Vector search
if embedded_text:
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(embedded_text)
if ascending:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc(),
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.desc(),
combined_query.c.id.asc(),
)
else:
if query_text:
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
# Handle cursor-based pagination
if cursor:
cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery()
if ascending:
main_query = main_query.where(combined_query.c.created_at > cursor_query)
else:
main_query = main_query.where(combined_query.c.created_at < cursor_query)
# Add ordering if not already ordered by similarity
if not embed_query:
if ascending:
main_query = main_query.order_by(
combined_query.c.created_at.asc(),
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
combined_query.c.created_at.desc(),
combined_query.c.id.asc(),
)
return main_query
@enforce_types
def list_passages(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
cursor: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
with self.session_maker() as session:
main_query = self._build_passage_query(
actor=actor,
agent_id=agent_id,
file_id=file_id,
query_text=query_text,
start_date=start_date,
end_date=end_date,
cursor=cursor,
source_id=source_id,
embed_query=embed_query,
ascending=ascending,
embedding_config=embedding_config,
agent_only=agent_only,
)
# Add limit
if limit:
main_query = main_query.limit(limit)
# Execute query
results = list(session.execute(main_query))
passages = []
for row in results:
data = dict(row._mapping)
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)
return [p.to_pydantic() for p in passages]
@enforce_types
def passage_size(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
cursor: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> int:
"""Returns the count of passages matching the given criteria."""
with self.session_maker() as session:
main_query = self._build_passage_query(
actor=actor,
agent_id=agent_id,
file_id=file_id,
query_text=query_text,
start_date=start_date,
end_date=end_date,
cursor=cursor,
source_id=source_id,
embed_query=embed_query,
ascending=ascending,
embedding_config=embedding_config,
agent_only=agent_only,
)
# Convert to count query
count_query = select(func.count()).select_from(main_query.subquery())
return session.scalar(count_query) or 0
# ======================================================================================================================
# Tool Management
# ======================================================================================================================
@enforce_types
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a tool to an agent.
Args:
agent_id: ID of the agent to attach the tool to.
tool_id: ID of the tool to attach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
with self.session_maker() as session:
# Verify the agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Use the _process_relationship helper to attach the tool
_process_relationship(
session=session,
agent=agent,
relationship_name="tools",
model_class=ToolModel,
item_ids=[tool_id],
allow_partial=False, # Ensure the tool exists
replace=False, # Extend the existing tools
)
# Commit and refresh the agent
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a tool from an agent.
Args:
agent_id: ID of the agent to detach the tool from.
tool_id: ID of the tool to detach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
with self.session_maker() as session:
# Verify the agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Filter out the tool to be detached
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
# Update the tools relationship
agent.tools = remaining_tools
# Commit and refresh the agent
agent.update(session, actor=actor)
return agent.to_pydantic()