444 lines
18 KiB
Python
444 lines
18 KiB
Python
import datetime
|
|
from typing import List, Literal, Optional
|
|
|
|
from sqlalchemy import and_, asc, desc, func, literal, or_, select
|
|
|
|
from letta import system
|
|
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
|
|
from letta.helpers import ToolRulesSolver
|
|
from letta.helpers.datetime_helpers import get_local_time, get_local_time_fast
|
|
from letta.orm.agent import Agent as AgentModel
|
|
from letta.orm.agents_tags import AgentsTags
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.orm.identity import Identity
|
|
from letta.prompts import gpt_system
|
|
from letta.schemas.agent import AgentState, AgentType
|
|
from letta.schemas.enums import MessageRole
|
|
from letta.schemas.letta_message_content import TextContent
|
|
from letta.schemas.memory import Memory
|
|
from letta.schemas.message import Message, MessageCreate
|
|
from letta.schemas.passage import Passage as PydanticPassage
|
|
from letta.schemas.tool_rule import ToolRule
|
|
from letta.schemas.user import User
|
|
from letta.system import get_initial_boot_messages, get_login_event
|
|
from letta.tracing import trace_method
|
|
|
|
|
|
# Static methods
|
|
@trace_method
|
|
def _process_relationship(
|
|
session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
|
|
):
|
|
"""
|
|
Generalized function to handle relationships like tools, sources, and blocks using item IDs.
|
|
|
|
Args:
|
|
session: The database session.
|
|
agent: The AgentModel instance.
|
|
relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
|
|
model_class: The ORM class corresponding to the related items.
|
|
item_ids: List of IDs to set or update.
|
|
allow_partial: If True, allows missing items without raising errors.
|
|
replace: If True, replaces the entire relationship; otherwise, extends it.
|
|
|
|
Raises:
|
|
ValueError: If `allow_partial` is False and some IDs are missing.
|
|
"""
|
|
current_relationship = getattr(agent, relationship_name, [])
|
|
if not item_ids:
|
|
if replace:
|
|
setattr(agent, relationship_name, [])
|
|
return
|
|
|
|
# Retrieve models for the provided IDs
|
|
found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
|
|
|
|
# Validate all items are found if allow_partial is False
|
|
if not allow_partial and len(found_items) != len(item_ids):
|
|
missing = set(item_ids) - {item.id for item in found_items}
|
|
raise NoResultFound(f"Items not found in {relationship_name}: {missing}")
|
|
|
|
if replace:
|
|
# Replace the relationship
|
|
setattr(agent, relationship_name, found_items)
|
|
else:
|
|
# Extend the relationship (only add new items)
|
|
current_ids = {item.id for item in current_relationship}
|
|
new_items = [item for item in found_items if item.id not in current_ids]
|
|
current_relationship.extend(new_items)
|
|
|
|
|
|
def _process_tags(agent: AgentModel, tags: List[str], replace=True):
|
|
"""
|
|
Handles tags for an agent.
|
|
|
|
Args:
|
|
agent: The AgentModel instance.
|
|
tags: List of tags to set or update.
|
|
replace: If True, replaces all tags; otherwise, extends them.
|
|
"""
|
|
if not tags:
|
|
if replace:
|
|
agent.tags = []
|
|
return
|
|
|
|
# Ensure tags are unique and prepare for replacement/extension
|
|
new_tags = {AgentsTags(agent_id=agent.id, tag=tag) for tag in set(tags)}
|
|
if replace:
|
|
agent.tags = list(new_tags)
|
|
else:
|
|
existing_tags = {t.tag for t in agent.tags}
|
|
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
|
|
|
|
|
|
def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool] = None, system: Optional[str] = None):
|
|
if system is None:
|
|
# TODO: don't hardcode
|
|
if agent_type == AgentType.voice_convo_agent:
|
|
system = gpt_system.get_system_text("voice_chat")
|
|
elif agent_type == AgentType.voice_sleeptime_agent:
|
|
system = gpt_system.get_system_text("voice_sleeptime")
|
|
elif agent_type == AgentType.memgpt_agent and not enable_sleeptime:
|
|
system = gpt_system.get_system_text("memgpt_chat")
|
|
elif agent_type == AgentType.memgpt_agent and enable_sleeptime:
|
|
system = gpt_system.get_system_text("memgpt_sleeptime_chat")
|
|
elif agent_type == AgentType.sleeptime_agent:
|
|
system = gpt_system.get_system_text("sleeptime")
|
|
else:
|
|
raise ValueError(f"Invalid agent type: {agent_type}")
|
|
|
|
return system
|
|
|
|
|
|
# TODO: This code is kind of wonky and deserves a rewrite
|
|
def compile_memory_metadata_block(
|
|
memory_edit_timestamp: datetime.datetime,
|
|
previous_message_count: int = 0,
|
|
archival_memory_size: int = 0,
|
|
recent_passages: List[PydanticPassage] = None,
|
|
) -> str:
|
|
# Put the timestamp in the local timezone (mimicking get_local_time())
|
|
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
|
|
|
|
# Create a metadata block of info so the agent knows about the metadata of out-of-context memories
|
|
memory_metadata_block = "\n".join(
|
|
[
|
|
f"### Current Time: {get_local_time_fast()}" f"### Memory [last modified: {timestamp_str}]",
|
|
f"{previous_message_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
|
f"{archival_memory_size} total memories you created are stored in archival memory (use functions to access them)",
|
|
(
|
|
f"Most recent archival passages {len(recent_passages)} recent passages: {[passage.text for passage in recent_passages]}"
|
|
if recent_passages is not None
|
|
else ""
|
|
),
|
|
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
|
]
|
|
)
|
|
return memory_metadata_block
|
|
|
|
|
|
class PreserveMapping(dict):
|
|
"""Used to preserve (do not modify) undefined variables in the system prompt"""
|
|
|
|
def __missing__(self, key):
|
|
return "{" + key + "}"
|
|
|
|
|
|
def safe_format(template: str, variables: dict) -> str:
|
|
"""
|
|
Safely formats a template string, preserving empty {} and {unknown_vars}
|
|
while substituting known variables.
|
|
|
|
If we simply use {} in format_map, it'll be treated as a positional field
|
|
"""
|
|
# First escape any empty {} by doubling them
|
|
escaped = template.replace("{}", "{{}}")
|
|
|
|
# Now use format_map with our custom mapping
|
|
return escaped.format_map(PreserveMapping(variables))
|
|
|
|
|
|
def compile_system_message(
|
|
system_prompt: str,
|
|
in_context_memory: Memory,
|
|
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
|
user_defined_variables: Optional[dict] = None,
|
|
append_icm_if_missing: bool = True,
|
|
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
|
previous_message_count: int = 0,
|
|
archival_memory_size: int = 0,
|
|
recent_passages: Optional[List[PydanticPassage]] = None,
|
|
) -> str:
|
|
"""Prepare the final/full system message that will be fed into the LLM API
|
|
|
|
The base system message may be templated, in which case we need to render the variables.
|
|
|
|
The following are reserved variables:
|
|
- CORE_MEMORY: the in-context memory of the LLM
|
|
"""
|
|
|
|
if user_defined_variables is not None:
|
|
# TODO eventually support the user defining their own variables to inject
|
|
raise NotImplementedError
|
|
else:
|
|
variables = {}
|
|
|
|
# Add the protected memory variable
|
|
if IN_CONTEXT_MEMORY_KEYWORD in variables:
|
|
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
|
else:
|
|
# TODO should this all put into the memory.__repr__ function?
|
|
memory_metadata_string = compile_memory_metadata_block(
|
|
memory_edit_timestamp=in_context_memory_last_edit,
|
|
previous_message_count=previous_message_count,
|
|
archival_memory_size=archival_memory_size,
|
|
recent_passages=recent_passages,
|
|
)
|
|
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
|
|
|
# Add to the variables list to inject
|
|
variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
|
|
|
|
if template_format == "f-string":
|
|
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
|
|
# Catch the special case where the system prompt is unformatted
|
|
if append_icm_if_missing:
|
|
if memory_variable_string not in system_prompt:
|
|
# In this case, append it to the end to make sure memory is still injected
|
|
# warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
|
|
system_prompt += "\n" + memory_variable_string
|
|
|
|
# render the variables using the built-in templater
|
|
try:
|
|
if user_defined_variables:
|
|
formatted_prompt = safe_format(system_prompt, variables)
|
|
else:
|
|
formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
|
|
except Exception as e:
|
|
raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
|
|
|
|
else:
|
|
# TODO support for mustache and jinja2
|
|
raise NotImplementedError(template_format)
|
|
|
|
return formatted_prompt
|
|
|
|
|
|
def initialize_message_sequence(
|
|
agent_state: AgentState,
|
|
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
|
include_initial_boot_message: bool = True,
|
|
previous_message_count: int = 0,
|
|
archival_memory_size: int = 0,
|
|
) -> List[dict]:
|
|
if memory_edit_timestamp is None:
|
|
memory_edit_timestamp = get_local_time()
|
|
|
|
full_system_message = compile_system_message(
|
|
system_prompt=agent_state.system,
|
|
in_context_memory=agent_state.memory,
|
|
in_context_memory_last_edit=memory_edit_timestamp,
|
|
user_defined_variables=None,
|
|
append_icm_if_missing=True,
|
|
previous_message_count=previous_message_count,
|
|
archival_memory_size=archival_memory_size,
|
|
)
|
|
first_user_message = get_login_event() # event letting Letta know the user just logged in
|
|
|
|
if include_initial_boot_message:
|
|
if agent_state.agent_type == AgentType.sleeptime_agent:
|
|
initial_boot_messages = []
|
|
elif agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
|
|
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
|
|
else:
|
|
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
|
|
messages = (
|
|
[
|
|
{"role": "system", "content": full_system_message},
|
|
]
|
|
+ initial_boot_messages
|
|
+ [
|
|
{"role": "user", "content": first_user_message},
|
|
]
|
|
)
|
|
|
|
else:
|
|
messages = [
|
|
{"role": "system", "content": full_system_message},
|
|
{"role": "user", "content": first_user_message},
|
|
]
|
|
|
|
return messages
|
|
|
|
|
|
def package_initial_message_sequence(
|
|
agent_id: str, initial_message_sequence: List[MessageCreate], model: str, actor: User
|
|
) -> List[Message]:
|
|
# create the agent object
|
|
init_messages = []
|
|
for message_create in initial_message_sequence:
|
|
|
|
if message_create.role == MessageRole.user:
|
|
packed_message = system.package_user_message(
|
|
user_message=message_create.content,
|
|
)
|
|
elif message_create.role == MessageRole.system:
|
|
packed_message = system.package_system_message(
|
|
system_message=message_create.content,
|
|
)
|
|
else:
|
|
raise ValueError(f"Invalid message role: {message_create.role}")
|
|
|
|
init_messages.append(
|
|
Message(
|
|
role=message_create.role,
|
|
content=[TextContent(text=packed_message)],
|
|
name=message_create.name,
|
|
organization_id=actor.organization_id,
|
|
agent_id=agent_id,
|
|
model=model,
|
|
)
|
|
)
|
|
return init_messages
|
|
|
|
|
|
def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool:
|
|
if model not in STRUCTURED_OUTPUT_MODELS:
|
|
if len(ToolRulesSolver(tool_rules=tool_rules).init_tool_rules) > 1:
|
|
raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
|
|
def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool):
|
|
"""
|
|
Returns a SQLAlchemy filter expression for cursor-based pagination.
|
|
|
|
If `forward` is True, returns records after the reference.
|
|
If `forward` is False, returns records before the reference.
|
|
"""
|
|
if forward:
|
|
return or_(
|
|
created_at_col > ref_created_at,
|
|
and_(created_at_col == ref_created_at, id_col > ref_id),
|
|
)
|
|
else:
|
|
return or_(
|
|
created_at_col < ref_created_at,
|
|
and_(created_at_col == ref_created_at, id_col < ref_id),
|
|
)
|
|
|
|
|
|
def _apply_pagination(query, before: Optional[str], after: Optional[str], session, ascending: bool = True) -> any:
|
|
if after:
|
|
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
|
|
if result:
|
|
after_created_at, after_id = result
|
|
query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, after_created_at, after_id, forward=ascending))
|
|
|
|
if before:
|
|
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
|
|
if result:
|
|
before_created_at, before_id = result
|
|
query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, before_created_at, before_id, forward=not ascending))
|
|
|
|
# Apply ordering
|
|
order_fn = asc if ascending else desc
|
|
query = query.order_by(order_fn(AgentModel.created_at), order_fn(AgentModel.id))
|
|
return query
|
|
|
|
|
|
def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
|
|
"""
|
|
Apply tag-based filtering to the agent query.
|
|
|
|
This helper function creates a subquery that groups agent IDs by their tags.
|
|
If `match_all_tags` is True, it filters agents that have all of the specified tags.
|
|
Otherwise, it filters agents that have any of the tags.
|
|
|
|
Args:
|
|
query: The SQLAlchemy query object to be modified.
|
|
tags (Optional[List[str]]): A list of tags to filter agents.
|
|
match_all_tags (bool): If True, only return agents that match all provided tags.
|
|
|
|
Returns:
|
|
The modified query with tag filters applied.
|
|
"""
|
|
if tags:
|
|
# Build a subquery to select agent IDs that have the specified tags.
|
|
subquery = select(AgentsTags.agent_id).where(AgentsTags.tag.in_(tags)).group_by(AgentsTags.agent_id)
|
|
# If all tags must match, add a HAVING clause to ensure the count of tags equals the number provided.
|
|
if match_all_tags:
|
|
subquery = subquery.having(func.count(AgentsTags.tag) == literal(len(tags)))
|
|
# Filter the main query to include only agents present in the subquery.
|
|
query = query.where(AgentModel.id.in_(subquery))
|
|
return query
|
|
|
|
|
|
def _apply_identity_filters(query, identity_id: Optional[str], identifier_keys: Optional[List[str]]):
|
|
"""
|
|
Apply identity-related filters to the agent query.
|
|
|
|
This helper function joins the identities relationship and filters the agents based on
|
|
a specific identity ID and/or a list of identifier keys.
|
|
|
|
Args:
|
|
query: The SQLAlchemy query object to be modified.
|
|
identity_id (Optional[str]): The identity ID to filter by.
|
|
identifier_keys (Optional[List[str]]): A list of identifier keys to filter agents.
|
|
|
|
Returns:
|
|
The modified query with identity filters applied.
|
|
"""
|
|
# Join the identities relationship and filter by a specific identity ID.
|
|
if identity_id:
|
|
query = query.join(AgentModel.identities).where(Identity.id == identity_id)
|
|
# Join the identities relationship and filter by a set of identifier keys.
|
|
if identifier_keys:
|
|
query = query.join(AgentModel.identities).where(Identity.identifier_key.in_(identifier_keys))
|
|
return query
|
|
|
|
|
|
def _apply_filters(
|
|
query,
|
|
name: Optional[str],
|
|
query_text: Optional[str],
|
|
project_id: Optional[str],
|
|
template_id: Optional[str],
|
|
base_template_id: Optional[str],
|
|
):
|
|
"""
|
|
Apply basic filtering criteria to the agent query.
|
|
|
|
This helper function adds WHERE clauses based on provided parameters such as
|
|
exact name, partial name match (using ILIKE), project ID, template ID, and base template ID.
|
|
|
|
Args:
|
|
query: The SQLAlchemy query object to be modified.
|
|
name (Optional[str]): Exact name to filter by.
|
|
query_text (Optional[str]): Partial text to search in the agent's name (case-insensitive).
|
|
project_id (Optional[str]): Filter for agents belonging to a specific project.
|
|
template_id (Optional[str]): Filter for agents using a specific template.
|
|
base_template_id (Optional[str]): Filter for agents using a specific base template.
|
|
|
|
Returns:
|
|
The modified query with the applied filters.
|
|
"""
|
|
# Filter by exact agent name if provided.
|
|
if name:
|
|
query = query.where(AgentModel.name == name)
|
|
# Apply a case-insensitive partial match for the agent's name.
|
|
if query_text:
|
|
query = query.where(AgentModel.name.ilike(f"%{query_text}%"))
|
|
# Filter agents by project ID.
|
|
if project_id:
|
|
query = query.where(AgentModel.project_id == project_id)
|
|
# Filter agents by template ID.
|
|
if template_id:
|
|
query = query.where(AgentModel.template_id == template_id)
|
|
# Filter agents by base template ID.
|
|
if base_template_id:
|
|
query = query.where(AgentModel.base_template_id == base_template_id)
|
|
return query
|