358 lines
14 KiB
Python
358 lines
14 KiB
Python
from typing import List, Optional
|
|
|
|
from sqlalchemy import func, select
|
|
|
|
from letta.orm.conversation import Conversation as ConversationModel
|
|
from letta.orm.conversation_messages import ConversationMessage as ConversationMessageModel
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.orm.message import Message as MessageModel
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.conversation import Conversation as PydanticConversation, CreateConversation, UpdateConversation
|
|
from letta.schemas.letta_message import LettaMessage
|
|
from letta.schemas.message import Message as PydanticMessage
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.utils import enforce_types
|
|
|
|
|
|
class ConversationManager:
    """Manager class to handle business logic related to Conversations."""

    @enforce_types
    @trace_method
    async def create_conversation(
        self,
        agent_id: str,
        conversation_create: CreateConversation,
        actor: PydanticUser,
    ) -> PydanticConversation:
        """
        Create a new conversation for an agent.

        Args:
            agent_id: ID of the agent the conversation belongs to
            conversation_create: Creation payload (only ``summary`` is read here)
            actor: The user performing the action; supplies the organization ID

        Returns:
            The newly created conversation as a Pydantic model
        """
        async with db_registry.async_session() as session:
            conversation = ConversationModel(
                agent_id=agent_id,
                summary=conversation_create.summary,
                organization_id=actor.organization_id,
            )
            await conversation.create_async(session, actor=actor)
            return conversation.to_pydantic()

    @enforce_types
    @trace_method
    async def get_conversation_by_id(
        self,
        conversation_id: str,
        actor: PydanticUser,
    ) -> PydanticConversation:
        """
        Retrieve a non-deleted conversation by its ID, including in-context message IDs.

        Args:
            conversation_id: The conversation to fetch
            actor: The user performing the action

        Returns:
            The conversation with ``in_context_message_ids`` populated in position order

        Raises:
            Whatever ``ConversationModel.read_async`` raises when the conversation is
            missing or soft-deleted (presumably ``NoResultFound`` — confirm).
        """
        async with db_registry.async_session() as session:
            conversation = await ConversationModel.read_async(
                db_session=session,
                identifier=conversation_id,
                actor=actor,
                check_is_deleted=True,
            )
            pydantic_conversation = conversation.to_pydantic()

        # get_message_ids_for_conversation opens its own session; calling it after the
        # session above has closed avoids holding two pooled connections at once.
        pydantic_conversation.in_context_message_ids = await self.get_message_ids_for_conversation(
            conversation_id=conversation_id,
            actor=actor,
        )
        return pydantic_conversation

    @enforce_types
    @trace_method
    async def list_conversations(
        self,
        agent_id: str,
        actor: PydanticUser,
        limit: int = 50,
        after: Optional[str] = None,
    ) -> List[PydanticConversation]:
        """
        List an agent's non-deleted conversations with cursor-based pagination.

        Args:
            agent_id: The agent whose conversations to list
            actor: The user performing the action
            limit: Maximum number of conversations to return
            after: Pagination cursor (conversation ID) passed through to ``list_async``

        Returns:
            Conversations in descending order (``ascending=False``)
        """
        async with db_registry.async_session() as session:
            conversations = await ConversationModel.list_async(
                db_session=session,
                actor=actor,
                agent_id=agent_id,
                limit=limit,
                after=after,
                ascending=False,
                # delete_conversation only soft-deletes, so exclude those rows here to
                # stay consistent with get_conversation_by_id.
                check_is_deleted=True,
            )
            return [conv.to_pydantic() for conv in conversations]

    @enforce_types
    @trace_method
    async def update_conversation(
        self,
        conversation_id: str,
        conversation_update: UpdateConversation,
        actor: PydanticUser,
    ) -> PydanticConversation:
        """
        Update a non-deleted conversation.

        Only fields whose value is not None in ``conversation_update`` are applied.

        Args:
            conversation_id: The conversation to update
            conversation_update: Partial update payload
            actor: The user performing the action

        Returns:
            The updated conversation as a Pydantic model
        """
        async with db_registry.async_session() as session:
            conversation = await ConversationModel.read_async(
                db_session=session,
                identifier=conversation_id,
                actor=actor,
                # A soft-deleted conversation is not retrievable via get_conversation_by_id,
                # so it must not be updatable either.
                check_is_deleted=True,
            )

            update_data = conversation_update.model_dump(exclude_none=True)
            for key, value in update_data.items():
                setattr(conversation, key, value)

            updated_conversation = await conversation.update_async(
                db_session=session,
                actor=actor,
            )
            return updated_conversation.to_pydantic()

    @enforce_types
    @trace_method
    async def delete_conversation(
        self,
        conversation_id: str,
        actor: PydanticUser,
    ) -> None:
        """
        Soft delete a conversation by setting its ``is_deleted`` flag.

        Args:
            conversation_id: The conversation to delete
            actor: The user performing the action
        """
        async with db_registry.async_session() as session:
            # NOTE(review): unlike get/update, this read does not pass check_is_deleted=True,
            # so deleting an already-deleted conversation presumably succeeds silently — confirm.
            conversation = await ConversationModel.read_async(
                db_session=session,
                identifier=conversation_id,
                actor=actor,
            )
            conversation.is_deleted = True
            await conversation.update_async(db_session=session, actor=actor)

    # ==================== Message Management Methods ====================

    @enforce_types
    @trace_method
    async def get_message_ids_for_conversation(
        self,
        conversation_id: str,
        actor: PydanticUser,
    ) -> List[str]:
        """
        Get ordered message IDs for a conversation.

        Returns message IDs ordered by position in the conversation.
        Only returns non-deleted tracking rows that are currently in_context.

        Args:
            conversation_id: The conversation to read
            actor: The user performing the action

        Returns:
            Message IDs in ascending position order
        """
        async with db_registry.async_session() as session:
            query = (
                select(ConversationMessageModel.message_id)
                .where(
                    ConversationMessageModel.conversation_id == conversation_id,
                    ConversationMessageModel.organization_id == actor.organization_id,
                    ConversationMessageModel.in_context == True,  # noqa: E712 (SQL expression)
                    ConversationMessageModel.is_deleted == False,  # noqa: E712 (SQL expression)
                )
                .order_by(ConversationMessageModel.position)
            )
            result = await session.execute(query)
            return list(result.scalars().all())

    @enforce_types
    @trace_method
    async def get_messages_for_conversation(
        self,
        conversation_id: str,
        actor: PydanticUser,
    ) -> List[PydanticMessage]:
        """
        Get ordered Message objects for a conversation.

        Returns full Message objects ordered by position in the conversation.
        Only returns non-deleted tracking rows that are currently in_context.

        Args:
            conversation_id: The conversation to read
            actor: The user performing the action

        Returns:
            Pydantic messages in ascending position order
        """
        async with db_registry.async_session() as session:
            query = (
                select(MessageModel)
                .join(
                    ConversationMessageModel,
                    MessageModel.id == ConversationMessageModel.message_id,
                )
                .where(
                    ConversationMessageModel.conversation_id == conversation_id,
                    ConversationMessageModel.organization_id == actor.organization_id,
                    ConversationMessageModel.in_context == True,  # noqa: E712 (SQL expression)
                    ConversationMessageModel.is_deleted == False,  # noqa: E712 (SQL expression)
                )
                .order_by(ConversationMessageModel.position)
            )
            result = await session.execute(query)
            return [msg.to_pydantic() for msg in result.scalars().all()]

    @enforce_types
    @trace_method
    async def add_messages_to_conversation(
        self,
        conversation_id: str,
        agent_id: str,
        message_ids: List[str],
        actor: PydanticUser,
        starting_position: Optional[int] = None,
    ) -> None:
        """
        Add messages to a conversation's tracking table.

        Creates ConversationMessage entries (marked in_context) with consecutive
        positions, in the order given by ``message_ids``.

        Args:
            conversation_id: The conversation to add messages to
            agent_id: The agent ID
            message_ids: List of message IDs to add; an empty list is a no-op
            actor: The user performing the action
            starting_position: Optional starting position (defaults to one past the
                current maximum position, or 0 for an empty conversation)
        """
        if not message_ids:
            return

        async with db_registry.async_session() as session:
            if starting_position is None:
                # NOTE(review): the MAX lookup and the inserts are not atomic, so concurrent
                # callers could compute the same starting position.
                query = select(func.coalesce(func.max(ConversationMessageModel.position), -1)).where(
                    ConversationMessageModel.conversation_id == conversation_id,
                    ConversationMessageModel.organization_id == actor.organization_id,
                )
                result = await session.execute(query)
                max_position = result.scalar()
                # Explicit None check (not `or`) so a legitimate max position of 0 is kept.
                if max_position is None:
                    max_position = -1
                starting_position = max_position + 1

            session.add_all(
                [
                    ConversationMessageModel(
                        conversation_id=conversation_id,
                        agent_id=agent_id,
                        message_id=message_id,
                        position=starting_position + offset,
                        in_context=True,
                        organization_id=actor.organization_id,
                    )
                    for offset, message_id in enumerate(message_ids)
                ]
            )
            await session.commit()

    @enforce_types
    @trace_method
    async def update_in_context_messages(
        self,
        conversation_id: str,
        in_context_message_ids: List[str],
        actor: PydanticUser,
    ) -> None:
        """
        Update which messages are in context for a conversation.

        Sets in_context=True for tracked, non-deleted messages in the list and False
        for all other tracked, non-deleted messages. IDs in the list that are not
        tracked for this conversation are ignored (no rows are created).

        Args:
            conversation_id: The conversation to update
            in_context_message_ids: List of message IDs that should be in context
            actor: The user performing the action
        """
        async with db_registry.async_session() as session:
            query = select(ConversationMessageModel).where(
                ConversationMessageModel.conversation_id == conversation_id,
                ConversationMessageModel.organization_id == actor.organization_id,
                ConversationMessageModel.is_deleted == False,  # noqa: E712 (SQL expression)
            )
            result = await session.execute(query)
            conv_messages = result.scalars().all()

            # Set for O(1) membership checks per tracked row.
            in_context_set = set(in_context_message_ids)
            for conv_msg in conv_messages:
                conv_msg.in_context = conv_msg.message_id in in_context_set

            await session.commit()

    async def _get_message_position(
        self,
        session,
        conversation_id: str,
        message_id: str,
        actor: PydanticUser,
    ) -> Optional[int]:
        """
        Return the position of ``message_id`` within a conversation, or None if not tracked.

        Used to resolve pagination cursors. Scoped to the actor's organization, like the
        other queries in this manager. Raises if the message is tracked more than once in
        the conversation (``scalar_one_or_none``).
        """
        query = select(ConversationMessageModel.position).where(
            ConversationMessageModel.conversation_id == conversation_id,
            ConversationMessageModel.organization_id == actor.organization_id,
            ConversationMessageModel.message_id == message_id,
        )
        result = await session.execute(query)
        return result.scalar_one_or_none()

    @enforce_types
    @trace_method
    async def list_conversation_messages(
        self,
        conversation_id: str,
        actor: PydanticUser,
        limit: Optional[int] = 100,
        before: Optional[str] = None,
        after: Optional[str] = None,
    ) -> List[LettaMessage]:
        """
        List all messages in a conversation with pagination support.

        Unlike get_messages_for_conversation, this returns ALL messages
        (not just in_context) and supports cursor-based pagination.
        Messages are always returned ordered by position (oldest first).

        When only ``before`` is given (and resolves), the page is the ``limit``
        messages immediately preceding that cursor. A cursor whose message is not
        tracked in this conversation is ignored.

        Args:
            conversation_id: The conversation to list messages for
            actor: The user performing the action
            limit: Maximum number of messages to return (None for no limit)
            before: Return messages before this message ID
            after: Return messages after this message ID

        Returns:
            List of LettaMessage objects
        """
        async with db_registry.async_session() as session:
            query = (
                select(MessageModel)
                .join(
                    ConversationMessageModel,
                    MessageModel.id == ConversationMessageModel.message_id,
                )
                .where(
                    ConversationMessageModel.conversation_id == conversation_id,
                    ConversationMessageModel.organization_id == actor.organization_id,
                    ConversationMessageModel.is_deleted == False,  # noqa: E712 (SQL expression)
                )
            )

            before_position = None
            if before:
                before_position = await self._get_message_position(session, conversation_id, before, actor)
                if before_position is not None:
                    query = query.where(ConversationMessageModel.position < before_position)

            after_position = None
            if after:
                after_position = await self._get_message_position(session, conversation_id, after, actor)
                if after_position is not None:
                    query = query.where(ConversationMessageModel.position > after_position)

            # Paging backwards: ordering ascending and then limiting would return the OLDEST
            # `limit` messages before the cursor. Select newest-first so the limit keeps the
            # messages adjacent to the cursor, then restore oldest-first order below.
            page_backwards = before_position is not None and after_position is None and limit is not None

            if page_backwards:
                query = query.order_by(ConversationMessageModel.position.desc())
            else:
                query = query.order_by(ConversationMessageModel.position.asc())

            if limit is not None:
                query = query.limit(limit)

            result = await session.execute(query)
            rows = list(result.scalars().all())
            if page_backwards:
                rows.reverse()
            messages = [msg.to_pydantic() for msg in rows]

            return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True)