Files
letta-server/letta/services/message_manager.py
2024-12-19 11:33:00 -08:00

199 lines
7.5 KiB
Python

from datetime import datetime
from typing import Dict, List, Optional
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class MessageManager:
"""Manager class to handle business logic related to Messages."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
with self.session_maker() as session:
try:
message = MessageModel.read(db_session=session, identifier=message_id, actor=actor)
return message.to_pydantic()
except NoResultFound:
return None
@enforce_types
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
"""Create a new message."""
with self.session_maker() as session:
# Set the organization id of the Pydantic message
pydantic_msg.organization_id = actor.organization_id
msg_data = pydantic_msg.model_dump()
msg = MessageModel(**msg_data)
msg.create(session, actor=actor) # Persist to database
return msg.to_pydantic()
@enforce_types
def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
"""Create multiple messages."""
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
@enforce_types
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
"""
with self.session_maker() as session:
# Fetch existing message from database
message = MessageModel.read(
db_session=session,
identifier=message_id,
actor=actor,
)
# Some safety checks specific to messages
if message_update.tool_calls and message.role != MessageRole.assistant:
raise ValueError(
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
)
if message_update.tool_call_id and message.role != MessageRole.tool:
raise ValueError(
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
)
# get update dictionary
update_data = message_update.model_dump(exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
for key, value in update_data.items():
setattr(message, key, value)
message.update(db_session=session, actor=actor)
return message.to_pydantic()
@enforce_types
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
"""Delete a message."""
with self.session_maker() as session:
try:
msg = MessageModel.read(
db_session=session,
identifier=message_id,
actor=actor,
)
msg.hard_delete(session, actor=actor)
except NoResultFound:
raise ValueError(f"Message with id {message_id} not found.")
@enforce_types
def size(
self,
actor: PydanticUser,
role: Optional[MessageRole] = None,
agent_id: Optional[str] = None,
) -> int:
"""Get the total count of messages with optional filters.
Args:
actor: The user requesting the count
role: The role of the message
"""
with self.session_maker() as session:
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)
@enforce_types
def list_user_messages_for_agent(
self,
agent_id: str,
actor: Optional[PydanticUser] = None,
cursor: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
filters: Optional[Dict] = None,
query_text: Optional[str] = None,
ascending: bool = True,
) -> List[PydanticMessage]:
"""List user messages with flexible filtering and pagination options.
Args:
cursor: Cursor-based pagination - return records after this ID (exclusive)
start_date: Filter records created after this date
end_date: Filter records created before this date
limit: Maximum number of records to return
filters: Additional filters to apply
query_text: Optional text to search for in message content
Returns:
List[PydanticMessage] - List of messages matching the criteria
"""
message_filters = {"role": "user"}
if filters:
message_filters.update(filters)
return self.list_messages_for_agent(
agent_id=agent_id,
actor=actor,
cursor=cursor,
start_date=start_date,
end_date=end_date,
limit=limit,
filters=message_filters,
query_text=query_text,
ascending=ascending,
)
@enforce_types
def list_messages_for_agent(
self,
agent_id: str,
actor: Optional[PydanticUser] = None,
cursor: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
filters: Optional[Dict] = None,
query_text: Optional[str] = None,
ascending: bool = True,
) -> List[PydanticMessage]:
"""List messages with flexible filtering and pagination options.
Args:
cursor: Cursor-based pagination - return records after this ID (exclusive)
start_date: Filter records created after this date
end_date: Filter records created before this date
limit: Maximum number of records to return
filters: Additional filters to apply
query_text: Optional text to search for in message content
Returns:
List[PydanticMessage] - List of messages matching the criteria
"""
with self.session_maker() as session:
# Start with base filters
message_filters = {"agent_id": agent_id}
if actor:
message_filters.update({"organization_id": actor.organization_id})
if filters:
message_filters.update(filters)
results = MessageModel.list(
db_session=session,
cursor=cursor,
start_date=start_date,
end_date=end_date,
limit=limit,
query_text=query_text,
ascending=ascending,
**message_filters,
)
return [msg.to_pydantic() for msg in results]