1001 lines · 46 KiB · Python
import json
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import List, Optional, Sequence
|
||
|
||
from sqlalchemy import delete, exists, func, select, text
|
||
|
||
from letta.log import get_logger
|
||
from letta.orm.agent import Agent as AgentModel
|
||
from letta.orm.errors import NoResultFound
|
||
from letta.orm.message import Message as MessageModel
|
||
from letta.otel.tracing import trace_method
|
||
from letta.schemas.embedding_config import EmbeddingConfig
|
||
from letta.schemas.enums import MessageRole
|
||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
|
||
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
|
||
from letta.schemas.user import User as PydanticUser
|
||
from letta.server.db import db_registry
|
||
from letta.services.file_manager import FileManager
|
||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||
from letta.settings import DatabaseChoice, settings
|
||
from letta.utils import enforce_types
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class MessageManager:
    """Manager class to handle business logic related to Messages."""

    def __init__(self):
        """Initialize the MessageManager."""
        # file manager is used by message-creation paths that reference stored files/images
        self.file_manager = FileManager()
|
||
|
||
def _extract_message_text(self, message: PydanticMessage) -> str:
    """Flatten a message's content into a single searchable text string.

    Only assistant, user, and tool messages are considered searchable; any
    other role yields an empty string.

    Args:
        message: The message whose content should be flattened to text.

    Returns:
        The concatenated text content, or "" when there is nothing to index.
    """
    searchable_roles = (MessageRole.assistant, MessageRole.user, MessageRole.tool)
    if message.role not in searchable_roles:
        return ""

    content = message.content
    if not content:
        return ""

    # legacy messages store a bare string instead of a list of content items
    if isinstance(content, str):
        return content

    # newer messages hold a list of content items; join their non-empty text parts
    parts = (item.to_text() for item in content)
    return " ".join(part for part in parts if part)
|
||
|
||
@enforce_types
@trace_method
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
    """Fetch a single message by ID; return None when it does not exist."""
    with db_registry.session() as session:
        try:
            orm_message = MessageModel.read(db_session=session, identifier=message_id, actor=actor)
        except NoResultFound:
            # missing or inaccessible messages are reported as None, not an error
            return None
        return orm_message.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
    """Fetch a single message by ID asynchronously; return None when it does not exist."""
    async with db_registry.async_session() as session:
        try:
            orm_message = await MessageModel.read_async(db_session=session, identifier=message_id, actor=actor)
        except NoResultFound:
            # missing or inaccessible messages are reported as None, not an error
            return None
        return orm_message.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
    """Fetch a batch of messages and return them in the order of `message_ids`."""
    with db_registry.session() as session:
        rows = MessageModel.read_multiple(db_session=session, identifiers=message_ids, actor=actor)
        # reorder inside the session so to_pydantic() can still touch lazy attributes
        return self._get_messages_by_id_postprocess(rows, message_ids)
|
||
|
||
@enforce_types
@trace_method
async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
    """Async variant: fetch a batch of messages, preserving the order of `message_ids`."""
    async with db_registry.async_session() as session:
        rows = await MessageModel.read_multiple_async(db_session=session, identifiers=message_ids, actor=actor)
        # reorder inside the session so to_pydantic() can still touch lazy attributes
        return self._get_messages_by_id_postprocess(rows, message_ids)
|
||
|
||
def _get_messages_by_id_postprocess(
|
||
self,
|
||
results: List[MessageModel],
|
||
message_ids: List[str],
|
||
) -> List[PydanticMessage]:
|
||
if len(results) != len(message_ids):
|
||
logger.warning(
|
||
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
|
||
)
|
||
# Sort results directly based on message_ids
|
||
result_dict = {msg.id: msg.to_pydantic() for msg in results}
|
||
return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids]))
|
||
|
||
@enforce_types
@trace_method
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
    """Persist a single new message, stamped with the actor's organization."""
    with db_registry.session() as session:
        # convert to ORM-shaped payload and attach ownership before insert
        payload = pydantic_msg.model_dump(to_orm=True)
        payload["organization_id"] = actor.organization_id
        orm_msg = MessageModel(**payload)
        orm_msg.create(session, actor=actor)  # Persist to database
        return orm_msg.to_pydantic()
|
||
|
||
def _create_many_preprocess(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[MessageModel]:
|
||
# Create ORM model instances for all messages
|
||
orm_messages = []
|
||
for pydantic_msg in pydantic_msgs:
|
||
# Set the organization id of the Pydantic message
|
||
msg_data = pydantic_msg.model_dump(to_orm=True)
|
||
msg_data["organization_id"] = actor.organization_id
|
||
orm_messages.append(MessageModel(**msg_data))
|
||
return orm_messages
|
||
|
||
@enforce_types
@trace_method
def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
    """
    Create multiple messages in a single database transaction.

    Args:
        pydantic_msgs: List of Pydantic message models to create
        actor: User performing the action

    Returns:
        List of created Pydantic message models
    """
    # nothing to insert — avoid opening a session at all
    if not pydantic_msgs:
        return []

    orm_batch = self._create_many_preprocess(pydantic_msgs, actor)
    with db_registry.session() as session:
        persisted = MessageModel.batch_create(orm_batch, session, actor=actor)
        return [row.to_pydantic() for row in persisted]
|
||
|
||
@enforce_types
@trace_method
async def create_many_messages_async(
    self,
    pydantic_msgs: List[PydanticMessage],
    actor: PydanticUser,
    embedding_config: Optional[EmbeddingConfig] = None,
    strict_mode: bool = False,
) -> List[PydanticMessage]:
    """
    Create multiple messages in a single database transaction asynchronously.

    Args:
        pydantic_msgs: List of Pydantic message models to create
        actor: User performing the action
        embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
        strict_mode: If True, re-raise Turbopuffer embedding failures instead of only
            logging them (the database insert has already committed either way)

    Returns:
        List of created Pydantic message models
    """
    if not pydantic_msgs:
        return []

    # rewrite inline base64 image payloads into LettaImage entries with a placeholder file id
    for message in pydantic_msgs:
        if isinstance(message.content, list):
            for content in message.content:
                if content.type == MessageContentType.image and content.source.type == ImageSourceType.base64:
                    # TODO: actually persist image files in db
                    # file = await self.file_manager.create_file( # TODO: use batch create to prevent multiple db round trips
                    #     db_session=session,
                    #     image_create=FileMetadata(
                    #         user_id=actor.id, # TODO: add field
                    #         source_id= '' # TODO: make optional
                    #         organization_id=actor.organization_id,
                    #         file_type=content.source.media_type,
                    #         processing_status=FileProcessingStatus.COMPLETED,
                    #         content= '' # TODO: should content be added here or in top level text field?
                    #     ),
                    #     actor=actor,
                    #     text=content.source.data,
                    # )
                    file_id_placeholder = "file-" + str(uuid.uuid4())
                    content.source = LettaImage(
                        file_id=file_id_placeholder,
                        data=content.source.data,
                        media_type=content.source.media_type,
                        detail=content.source.detail,
                    )
    orm_messages = self._create_many_preprocess(pydantic_msgs, actor)
    async with db_registry.async_session() as session:
        # no_commit/no_refresh: convert to pydantic first, then commit once below
        created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
        result = [msg.to_pydantic() for msg in created_messages]
        await session.commit()

    # embed messages in turbopuffer if enabled and embedding_config provided
    from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages

    if should_use_tpuf_for_messages() and embedding_config and result:
        try:
            # extract agent_id from the first message (all should have same agent_id)
            agent_id = result[0].agent_id
            if agent_id:
                # extract text content from each message
                message_texts = []
                message_ids = []
                roles = []
                created_ats = []

                for msg in result:
                    text = self._extract_message_text(msg)
                    if text:  # only embed messages with text content (role filtering is handled in _extract_message_text)
                        message_texts.append(text)
                        message_ids.append(msg.id)
                        roles.append(msg.role)
                        created_ats.append(msg.created_at)

                if message_texts:
                    # generate embeddings using provided config
                    from letta.llm_api.llm_client import LLMClient

                    embedding_client = LLMClient.create(
                        provider_type=embedding_config.embedding_endpoint_type,
                        actor=actor,
                    )
                    embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)

                    # insert to turbopuffer
                    tpuf_client = TurbopufferClient()
                    await tpuf_client.insert_messages(
                        agent_id=agent_id,
                        message_texts=message_texts,
                        embeddings=embeddings,
                        message_ids=message_ids,
                        organization_id=actor.organization_id,
                        roles=roles,
                        created_ats=created_ats,
                    )
                    logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
        except Exception as e:
            # best-effort: DB rows are already committed, so embedding failures are non-fatal by default
            logger.error(f"Failed to embed messages in Turbopuffer: {e}")
            if strict_mode:
                raise  # Re-raise the exception in strict mode

    return result
|
||
|
||
@enforce_types
@trace_method
def update_message_by_letta_message(
    self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
) -> PydanticMessage:
    """
    Update the underlying messages table given an update specified against the
    user-facing LettaMessage representation.

    NOTE: this method was previously defined twice, verbatim, in this class;
    the second definition silently shadowed the first, so the duplicate has
    been removed.

    Args:
        message_id: ID of the underlying message to modify
        letta_message_update: Update payload keyed by user-facing message type
        actor: User performing the action

    Returns:
        The updated message, converted back to the matching LettaMessage type.

    Raises:
        ValueError: if the message type is unsupported for modification, or if
            the update changed the message's user-facing type.
    """
    message = self.get_message_by_id(message_id=message_id, actor=actor)
    if letta_message_update.message_type == "assistant_message":
        # modify the tool call for send_message
        # TODO: fix this if we add parallel tool calls
        # TODO: note this only works if the AssistantMessage is generated by the standard send_message
        assert message.tool_calls[0].function.name == "send_message", (
            f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
        )
        original_args = json.loads(message.tool_calls[0].function.arguments)
        original_args["message"] = letta_message_update.content  # override the assistant message
        update_tool_call = message.tool_calls[0].__deepcopy__()
        update_tool_call.function.arguments = json.dumps(original_args)

        update_message = MessageUpdate(tool_calls=[update_tool_call])
    elif letta_message_update.message_type == "reasoning_message":
        update_message = MessageUpdate(content=letta_message_update.reasoning)
    elif letta_message_update.message_type == "user_message" or letta_message_update.message_type == "system_message":
        update_message = MessageUpdate(content=letta_message_update.content)
    else:
        raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}")

    message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)

    # convert back to LettaMessage
    for letta_msg in message.to_letta_messages(use_assistant_message=True):
        if letta_msg.message_type == letta_message_update.message_type:
            return letta_msg

    # raise error if message type got modified
    raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
|
||
|
||
@enforce_types
@trace_method
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
    """Apply `message_update` to an existing message and return the updated copy.

    Raises:
        NoResultFound: if the message does not exist or the actor lacks access.
    """
    with db_registry.session() as session:
        # load the row first so ownership/permission checks run
        existing = MessageModel.read(
            db_session=session,
            identifier=message_id,
            actor=actor,
        )

        # apply the shared field-copy/validation logic, then persist
        updated = self._update_message_by_id_impl(message_id, message_update, actor, existing)
        updated.update(db_session=session, actor=actor)
        return updated.to_pydantic()
|
||
|
||
@enforce_types
@trace_method
async def update_message_by_id_async(
    self,
    message_id: str,
    message_update: MessageUpdate,
    actor: PydanticUser,
    embedding_config: Optional[EmbeddingConfig] = None,
    strict_mode: bool = False,
) -> PydanticMessage:
    """
    Updates an existing record in the database with values from the provided record object.
    Async version of the function above.

    Args:
        message_id: ID of the message to update
        message_update: Fields to apply to the stored message
        actor: User performing the action
        embedding_config: Optional embedding configuration; when provided (and
            Turbopuffer is enabled) the updated message is re-embedded
        strict_mode: If True, re-raise Turbopuffer failures instead of only logging them

    Returns:
        The updated message as a Pydantic model
    """
    async with db_registry.async_session() as session:
        # Fetch existing message from database
        message = await MessageModel.read_async(
            db_session=session,
            identifier=message_id,
            actor=actor,
        )

        message = self._update_message_by_id_impl(message_id, message_update, actor, message)
        # no_commit/no_refresh: convert to pydantic first, then commit once below
        await message.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
        pydantic_message = message.to_pydantic()
        await session.commit()

    # update message in turbopuffer if enabled (delete and re-insert)
    from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages

    if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id:
        try:
            # extract text content from updated message
            text = self._extract_message_text(pydantic_message)

            # only update in turbopuffer if there's text content (role filtering is handled in _extract_message_text)
            if text:
                tpuf_client = TurbopufferClient()

                # delete old message from turbopuffer
                await tpuf_client.delete_messages(agent_id=pydantic_message.agent_id, message_ids=[message_id])

                # generate new embedding
                from letta.llm_api.llm_client import LLMClient

                embedding_client = LLMClient.create(
                    provider_type=embedding_config.embedding_endpoint_type,
                    actor=actor,
                )
                embeddings = await embedding_client.request_embeddings([text], embedding_config)

                # re-insert with updated content
                await tpuf_client.insert_messages(
                    agent_id=pydantic_message.agent_id,
                    message_texts=[text],
                    embeddings=embeddings,
                    message_ids=[message_id],
                    organization_id=actor.organization_id,
                    roles=[pydantic_message.role],
                    created_ats=[pydantic_message.created_at],
                )
                logger.info(f"Successfully updated message {message_id} in Turbopuffer")
        except Exception as e:
            # best-effort: the DB update already committed, so index failures are non-fatal by default
            logger.error(f"Failed to update message in Turbopuffer: {e}")
            if strict_mode:
                raise  # Re-raise the exception in strict mode

    return pydantic_message
|
||
|
||
def _update_message_by_id_impl(
|
||
self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel
|
||
) -> MessageModel:
|
||
"""
|
||
Modifies the existing message object to update the database in the sync/async functions.
|
||
"""
|
||
# Some safety checks specific to messages
|
||
if message_update.tool_calls and message.role != MessageRole.assistant:
|
||
raise ValueError(
|
||
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
|
||
)
|
||
if message_update.tool_call_id and message.role != MessageRole.tool:
|
||
raise ValueError(
|
||
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
|
||
)
|
||
|
||
# get update dictionary
|
||
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||
# Remove redundant update fields
|
||
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
|
||
|
||
for key, value in update_data.items():
|
||
setattr(message, key, value)
|
||
return message
|
||
|
||
@enforce_types
@trace_method
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
    """Hard-delete a message by ID.

    Args:
        message_id: ID of the message to delete
        actor: User performing the action

    Returns:
        True when the message was deleted.

    Raises:
        ValueError: if no message with `message_id` exists for this actor.
    """
    with db_registry.session() as session:
        try:
            msg = MessageModel.read(
                db_session=session,
                identifier=message_id,
                actor=actor,
            )
            msg.hard_delete(session, actor=actor)
            # Note: Turbopuffer deletion requires async, use delete_message_by_id_async for full deletion
        except NoResultFound:
            raise ValueError(f"Message with id {message_id} not found.")
        # BUG FIX: the signature promises bool but the original fell off the end
        # and returned None; return True to mirror delete_message_by_id_async.
        return True
|
||
|
||
@enforce_types
@trace_method
async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool:
    """Delete a message (async version with turbopuffer support).

    Args:
        message_id: ID of the message to hard-delete
        actor: User performing the action
        strict_mode: If True, re-raise Turbopuffer deletion failures instead of only logging them

    Returns:
        True when the message was deleted from the primary database.

    Raises:
        ValueError: if no message with `message_id` exists for this actor.
    """
    async with db_registry.async_session() as session:
        try:
            msg = await MessageModel.read_async(
                db_session=session,
                identifier=message_id,
                actor=actor,
            )
            # capture agent_id before the row is gone; needed for Turbopuffer cleanup below
            agent_id = msg.agent_id
            await msg.hard_delete_async(session, actor=actor)

            # delete from turbopuffer if enabled
            from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages

            if should_use_tpuf_for_messages() and agent_id:
                try:
                    tpuf_client = TurbopufferClient()
                    await tpuf_client.delete_messages(agent_id=agent_id, message_ids=[message_id])
                    logger.info(f"Successfully deleted message {message_id} from Turbopuffer")
                except Exception as e:
                    # best-effort: DB deletion already happened, so index failures are non-fatal by default
                    logger.error(f"Failed to delete message from Turbopuffer: {e}")
                    if strict_mode:
                        raise  # Re-raise the exception in strict mode

            return True

        except NoResultFound:
            raise ValueError(f"Message with id {message_id} not found.")
|
||
|
||
@enforce_types
@trace_method
def size(
    self,
    actor: PydanticUser,
    role: Optional[MessageRole] = None,
    agent_id: Optional[str] = None,
) -> int:
    """Count messages visible to `actor`, with optional filters.

    Args:
        actor: The user requesting the count
        role: Restrict the count to messages with this role
        agent_id: Restrict the count to messages belonging to this agent
    """
    filters = {"role": role, "agent_id": agent_id}
    with db_registry.session() as session:
        return MessageModel.size(db_session=session, actor=actor, **filters)
|
||
|
||
@enforce_types
@trace_method
async def size_async(
    self,
    actor: PydanticUser,
    role: Optional[MessageRole] = None,
    agent_id: Optional[str] = None,
) -> int:
    """Asynchronously count messages visible to `actor`, with optional filters.

    Args:
        actor: The user requesting the count
        role: Restrict the count to messages with this role
        agent_id: Restrict the count to messages belonging to this agent
    """
    filters = {"role": role, "agent_id": agent_id}
    async with db_registry.async_session() as session:
        return await MessageModel.size_async(db_session=session, actor=actor, **filters)
|
||
|
||
@enforce_types
@trace_method
def list_user_messages_for_agent(
    self,
    agent_id: str,
    actor: PydanticUser,
    after: Optional[str] = None,
    before: Optional[str] = None,
    query_text: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = True,
) -> List[PydanticMessage]:
    """List an agent's messages restricted to the user role.

    Thin wrapper over list_messages_for_agent with roles pinned to [MessageRole.user].
    """
    forwarded = dict(
        agent_id=agent_id,
        actor=actor,
        after=after,
        before=before,
        query_text=query_text,
        roles=[MessageRole.user],
        limit=limit,
        ascending=ascending,
    )
    return self.list_messages_for_agent(**forwarded)
|
||
|
||
@enforce_types
@trace_method
async def list_user_messages_for_agent_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    after: Optional[str] = None,
    before: Optional[str] = None,
    query_text: Optional[str] = None,
    limit: Optional[int] = 50,
    ascending: bool = True,
) -> List[PydanticMessage]:
    """Async variant: list an agent's messages restricted to the user role.

    Thin wrapper over list_messages_for_agent_async with roles pinned to [MessageRole.user].
    """
    forwarded = dict(
        agent_id=agent_id,
        actor=actor,
        after=after,
        before=before,
        query_text=query_text,
        roles=[MessageRole.user],
        limit=limit,
        ascending=ascending,
    )
    return await self.list_messages_for_agent_async(**forwarded)
|
||
|
||
@enforce_types
@trace_method
def list_messages_for_agent(
    self,
    agent_id: str,
    actor: PydanticUser,
    after: Optional[str] = None,
    before: Optional[str] = None,
    query_text: Optional[str] = None,
    roles: Optional[Sequence[MessageRole]] = None,
    limit: Optional[int] = 50,
    ascending: bool = True,
    group_id: Optional[str] = None,
) -> List[PydanticMessage]:
    """
    Most performant query to list messages for an agent by directly querying the Message table.

    This function filters by the agent_id (leveraging the index on messages.agent_id)
    and applies pagination using sequence_id as the cursor.
    If query_text is provided, it will filter messages whose text content partially matches the query.
    If role is provided, it will filter messages by the specified role.

    Args:
        agent_id: The ID of the agent whose messages are queried.
        actor: The user performing the action (used for permission checks).
        after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
        before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
        query_text: Optional string to partially match the message text content.
        roles: Optional MessageRole to filter messages by role.
        limit: Maximum number of messages to return.
        ascending: If True, sort by sequence_id ascending; if False, sort descending.
        group_id: Optional group ID to filter messages by group_id.

    Returns:
        List[PydanticMessage]: A list of messages (converted via .to_pydantic()).

    Raises:
        NoResultFound: If the provided after/before message IDs do not exist.
    """

    with db_registry.session() as session:
        # Permission check: raise if the agent doesn't exist or actor is not allowed.
        AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

        # Build a query that directly filters the Message table by agent_id.
        query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)

        # If group_id is provided, filter messages by group_id.
        if group_id:
            query = query.filter(MessageModel.group_id == group_id)

        # If query_text is provided, filter messages using database-specific JSON search.
        if query_text:
            if settings.database_engine is DatabaseChoice.POSTGRES:
                # PostgreSQL: Use json_array_elements and ILIKE
                content_element = func.json_array_elements(MessageModel.content).alias("content_element")
                query = query.filter(
                    exists(
                        select(1)
                        .select_from(content_element)
                        .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
                        .params(query_text=f"%{query_text}%")
                    )
                )
            else:
                # SQLite: Use JSON_EXTRACT with individual array indices for case-insensitive search
                # Since SQLite doesn't support $[*] syntax, we'll use a different approach
                # NOTE(review): SQLite LIKE is case-insensitive only for ASCII by default — confirm this matches the intended semantics
                query = query.filter(text("JSON_EXTRACT(content, '$') LIKE :query_text")).params(query_text=f"%{query_text}%")

        # If role(s) are provided, filter messages by those roles.
        if roles:
            role_values = [r.value for r in roles]
            query = query.filter(MessageModel.role.in_(role_values))

        # Apply 'after' pagination if specified.
        if after:
            # resolve the cursor message id to its sequence_id
            after_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == after).one_or_none()
            if not after_ref:
                raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
            # Filter out any messages with a sequence_id <= after_ref.sequence_id
            query = query.filter(MessageModel.sequence_id > after_ref.sequence_id)

        # Apply 'before' pagination if specified.
        if before:
            # resolve the cursor message id to its sequence_id
            before_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == before).one_or_none()
            if not before_ref:
                raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
            # Filter out any messages with a sequence_id >= before_ref.sequence_id
            query = query.filter(MessageModel.sequence_id < before_ref.sequence_id)

        # Apply ordering based on the ascending flag.
        if ascending:
            query = query.order_by(MessageModel.sequence_id.asc())
        else:
            query = query.order_by(MessageModel.sequence_id.desc())

        # Limit the number of results.
        query = query.limit(limit)

        # Execute and convert each Message to its Pydantic representation.
        results = query.all()
        return [msg.to_pydantic() for msg in results]
|
||
|
||
@enforce_types
@trace_method
async def list_messages_for_agent_async(
    self,
    agent_id: str,
    actor: PydanticUser,
    after: Optional[str] = None,
    before: Optional[str] = None,
    query_text: Optional[str] = None,
    roles: Optional[Sequence[MessageRole]] = None,
    limit: Optional[int] = 50,
    ascending: bool = True,
    group_id: Optional[str] = None,
    include_err: Optional[bool] = None,
) -> List[PydanticMessage]:
    """
    Most performant query to list messages for an agent by directly querying the Message table.

    This function filters by the agent_id (leveraging the index on messages.agent_id)
    and applies pagination using sequence_id as the cursor.
    If query_text is provided, it will filter messages whose text content partially matches the query.
    If role is provided, it will filter messages by the specified role.

    Args:
        agent_id: The ID of the agent whose messages are queried.
        actor: The user performing the action (used for permission checks).
        after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
        before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
        query_text: Optional string to partially match the message text content.
        roles: Optional MessageRole to filter messages by role.
        limit: Maximum number of messages to return.
        ascending: If True, sort by sequence_id ascending; if False, sort descending.
        group_id: Optional group ID to filter messages by group_id.
        include_err: Optional boolean to include errors and error statuses. Used for debugging only.

    Returns:
        List[PydanticMessage]: A list of messages (converted via .to_pydantic()).

    Raises:
        NoResultFound: If the provided after/before message IDs do not exist.
    """

    async with db_registry.async_session() as session:
        # Permission check: raise if the agent doesn't exist or actor is not allowed.
        await validate_agent_exists_async(session, agent_id, actor)

        # Build a query that directly filters the Message table by agent_id.
        query = select(MessageModel).where(MessageModel.agent_id == agent_id)

        # If group_id is provided, filter messages by group_id.
        if group_id:
            query = query.where(MessageModel.group_id == group_id)

        # Exclude error messages unless explicitly requested; is_err may be NULL, which counts as "not an error".
        if not include_err:
            query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))

        # If query_text is provided, filter messages using database-specific JSON search.
        if query_text:
            if settings.database_engine is DatabaseChoice.POSTGRES:
                # PostgreSQL: Use json_array_elements and ILIKE
                content_element = func.json_array_elements(MessageModel.content).alias("content_element")
                query = query.where(
                    exists(
                        select(1)
                        .select_from(content_element)
                        .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
                        .params(query_text=f"%{query_text}%")
                    )
                )
            else:
                # SQLite: Use JSON_EXTRACT with individual array indices for case-insensitive search
                # Since SQLite doesn't support $[*] syntax, we'll use a different approach
                # NOTE(review): SQLite LIKE is case-insensitive only for ASCII by default — confirm this matches the intended semantics
                query = query.where(text("JSON_EXTRACT(content, '$') LIKE :query_text")).params(query_text=f"%{query_text}%")

        # If role(s) are provided, filter messages by those roles.
        if roles:
            role_values = [r.value for r in roles]
            query = query.where(MessageModel.role.in_(role_values))

        # Apply 'after' pagination if specified.
        if after:
            # resolve the cursor message id to its sequence_id
            after_query = select(MessageModel.sequence_id).where(MessageModel.id == after)
            after_result = await session.execute(after_query)
            after_ref = after_result.one_or_none()
            if not after_ref:
                raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
            # Filter out any messages with a sequence_id <= after_ref.sequence_id
            query = query.where(MessageModel.sequence_id > after_ref.sequence_id)

        # Apply 'before' pagination if specified.
        if before:
            # resolve the cursor message id to its sequence_id
            before_query = select(MessageModel.sequence_id).where(MessageModel.id == before)
            before_result = await session.execute(before_query)
            before_ref = before_result.one_or_none()
            if not before_ref:
                raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
            # Filter out any messages with a sequence_id >= before_ref.sequence_id
            query = query.where(MessageModel.sequence_id < before_ref.sequence_id)

        # Apply ordering based on the ascending flag.
        if ascending:
            query = query.order_by(MessageModel.sequence_id.asc())
        else:
            query = query.order_by(MessageModel.sequence_id.desc())

        # Limit the number of results.
        query = query.limit(limit)

        # Execute and convert each Message to its Pydantic representation.
        result = await session.execute(query)
        results = result.scalars().all()
        return [msg.to_pydantic() for msg in results]
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def delete_all_messages_for_agent_async(
|
||
self, agent_id: str, actor: PydanticUser, exclude_ids: Optional[List[str]] = None, strict_mode: bool = False
|
||
) -> int:
|
||
"""
|
||
Efficiently deletes all messages associated with a given agent_id,
|
||
while enforcing permission checks and avoiding any ORM‑level loads.
|
||
Optionally excludes specific message IDs from deletion.
|
||
"""
|
||
async with db_registry.async_session() as session:
|
||
# 1) verify the agent exists and the actor has access
|
||
await validate_agent_exists_async(session, agent_id, actor)
|
||
|
||
# 2) issue a CORE DELETE against the mapped class
|
||
stmt = (
|
||
delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id)
|
||
)
|
||
|
||
# 3) exclude specific message IDs if provided
|
||
if exclude_ids:
|
||
stmt = stmt.where(~MessageModel.id.in_(exclude_ids))
|
||
|
||
result = await session.execute(stmt)
|
||
|
||
# 4) commit once
|
||
await session.commit()
|
||
|
||
# 5) delete from turbopuffer if enabled
|
||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||
|
||
if should_use_tpuf_for_messages():
|
||
try:
|
||
tpuf_client = TurbopufferClient()
|
||
if exclude_ids:
|
||
# if we're excluding some IDs, we can't use delete_all
|
||
# would need to query all messages first then delete specific ones
|
||
# for now, log a warning
|
||
logger.warning(f"Turbopuffer deletion with exclude_ids not fully supported, using delete_all for agent {agent_id}")
|
||
# delete all messages for the agent from turbopuffer
|
||
await tpuf_client.delete_all_messages(agent_id)
|
||
logger.info(f"Successfully deleted all messages for agent {agent_id} from Turbopuffer")
|
||
except Exception as e:
|
||
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
|
||
if strict_mode:
|
||
raise # Re-raise the exception in strict mode
|
||
|
||
# 6) return the number of rows deleted
|
||
return result.rowcount
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def delete_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser, strict_mode: bool = False) -> int:
|
||
"""
|
||
Efficiently deletes messages by their specific IDs,
|
||
while enforcing permission checks.
|
||
"""
|
||
if not message_ids:
|
||
return 0
|
||
|
||
async with db_registry.async_session() as session:
|
||
# get agent_ids BEFORE deleting (for turbopuffer)
|
||
agent_ids = []
|
||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||
|
||
if should_use_tpuf_for_messages():
|
||
agent_query = (
|
||
select(MessageModel.agent_id)
|
||
.where(MessageModel.id.in_(message_ids))
|
||
.where(MessageModel.organization_id == actor.organization_id)
|
||
.distinct()
|
||
)
|
||
agent_result = await session.execute(agent_query)
|
||
agent_ids = [row[0] for row in agent_result.fetchall() if row[0]]
|
||
|
||
# issue a CORE DELETE against the mapped class for specific message IDs
|
||
stmt = delete(MessageModel).where(MessageModel.id.in_(message_ids)).where(MessageModel.organization_id == actor.organization_id)
|
||
result = await session.execute(stmt)
|
||
|
||
# commit once
|
||
await session.commit()
|
||
|
||
# delete from turbopuffer if enabled
|
||
if should_use_tpuf_for_messages() and agent_ids:
|
||
try:
|
||
tpuf_client = TurbopufferClient()
|
||
# delete from each affected agent's namespace
|
||
for agent_id in agent_ids:
|
||
await tpuf_client.delete_messages(agent_id=agent_id, message_ids=message_ids)
|
||
logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer")
|
||
except Exception as e:
|
||
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
|
||
if strict_mode:
|
||
raise # Re-raise the exception in strict mode
|
||
|
||
# return the number of rows deleted
|
||
return result.rowcount
|
||
|
||
@enforce_types
|
||
@trace_method
|
||
async def search_messages_async(
|
||
self,
|
||
agent_id: str,
|
||
actor: PydanticUser,
|
||
query_text: Optional[str] = None,
|
||
query_embedding: Optional[List[float]] = None,
|
||
search_mode: str = "hybrid",
|
||
roles: Optional[List[MessageRole]] = None,
|
||
limit: int = 50,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
embedding_config: Optional[EmbeddingConfig] = None,
|
||
) -> List[PydanticMessage]:
|
||
"""
|
||
Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.
|
||
|
||
Args:
|
||
agent_id: ID of the agent whose messages to search
|
||
actor: User performing the search
|
||
query_text: Text query for full-text search
|
||
query_embedding: Optional pre-computed embedding for vector search
|
||
search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
|
||
roles: Optional list of message roles to filter by
|
||
limit: Maximum number of results to return
|
||
start_date: Optional filter for messages created after this date
|
||
end_date: Optional filter for messages created before this date
|
||
embedding_config: Optional embedding configuration for generating query embedding
|
||
|
||
Returns:
|
||
List of matching messages
|
||
"""
|
||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||
|
||
# check if we should use turbopuffer
|
||
if should_use_tpuf_for_messages():
|
||
try:
|
||
# generate embedding if needed and not provided
|
||
if search_mode in ["vector", "hybrid"] and query_embedding is None and query_text:
|
||
if not embedding_config:
|
||
# fall back to SQL search if no embedding config
|
||
logger.warning("No embedding config provided for vector search, falling back to SQL")
|
||
return await self.list_messages_for_agent_async(
|
||
agent_id=agent_id,
|
||
actor=actor,
|
||
query_text=query_text,
|
||
roles=roles,
|
||
limit=limit,
|
||
ascending=False,
|
||
)
|
||
|
||
# generate embedding from query text
|
||
from letta.llm_api.llm_client import LLMClient
|
||
|
||
embedding_client = LLMClient.create(
|
||
provider_type=embedding_config.embedding_endpoint_type,
|
||
actor=actor,
|
||
)
|
||
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
|
||
query_embedding = embeddings[0]
|
||
|
||
# use turbopuffer for search
|
||
tpuf_client = TurbopufferClient()
|
||
results = await tpuf_client.query_messages(
|
||
agent_id=agent_id,
|
||
query_embedding=query_embedding,
|
||
query_text=query_text,
|
||
search_mode=search_mode,
|
||
top_k=limit,
|
||
roles=roles,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
)
|
||
|
||
# fetch full message objects from database using the IDs
|
||
message_ids = [msg_dict["id"] for msg_dict, _ in results]
|
||
if message_ids:
|
||
messages = await self.get_messages_by_ids_async(message_ids, actor)
|
||
# maintain the order from turbopuffer results
|
||
message_dict = {msg.id: msg for msg in messages}
|
||
return [message_dict[msg_id] for msg_id in message_ids if msg_id in message_dict]
|
||
else:
|
||
return []
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to search messages with Turbopuffer, falling back to SQL: {e}")
|
||
# fall back to SQL search
|
||
return await self.list_messages_for_agent_async(
|
||
agent_id=agent_id,
|
||
actor=actor,
|
||
query_text=query_text,
|
||
roles=roles,
|
||
limit=limit,
|
||
ascending=False,
|
||
)
|
||
else:
|
||
# use sql-based search
|
||
return await self.list_messages_for_agent_async(
|
||
agent_id=agent_id,
|
||
actor=actor,
|
||
query_text=query_text,
|
||
roles=roles,
|
||
limit=limit,
|
||
ascending=False,
|
||
)
|