diff --git a/.skills/db-migrations-schema-changes/SKILL.md b/.skills/db-migrations-schema-changes/SKILL.md index 7241ae60..fb74fe91 100644 --- a/.skills/db-migrations-schema-changes/SKILL.md +++ b/.skills/db-migrations-schema-changes/SKILL.md @@ -90,6 +90,36 @@ Workflow for Postgres-targeted migration: - `uv run alembic upgrade head` - `uv run alembic revision --autogenerate -m "..."` +### 5. Resetting local Postgres for clean migration generation + +If your local Postgres database has drifted from main (e.g., applied migrations +that no longer exist, or has stale schema), you can reset it to generate a clean +migration. + +From the repo root of your local letta-cloud checkout: + +```bash +# 1. Remove postgres data directory +rm -rf ./data/postgres + +# 2. Stop the running postgres container +docker stop $(docker ps -q --filter ancestor=ankane/pgvector) + +# 3. Restart services (creates fresh postgres) +just start-services + +# 4. Wait a moment for postgres to be ready, then apply all migrations +cd apps/core +export LETTA_PG_URI=postgresql+pg8000://postgres:postgres@localhost:5432/letta-core +uv run alembic upgrade head + +# 5. Now generate your new migration +uv run alembic revision --autogenerate -m "your migration message" +``` + +This ensures the migration is generated against a clean database state matching +main, avoiding spurious diffs from local-only schema changes. + ## Troubleshooting - **"Target database is not up to date" when autogenerating** @@ -101,7 +131,7 @@ Workflow for Postgres-targeted migration: changed model is imported in Alembic env context. - **Autogenerated migration has unexpected drops/renames** - Review model changes; consider explicit operations instead of relying on - autogenerate. + autogenerate. Reset local Postgres (see workflow 5) to get a clean baseline. 
## References diff --git a/alembic/versions/27de0f58e076_add_conversations_tables_and_run_.py b/alembic/versions/27de0f58e076_add_conversations_tables_and_run_.py new file mode 100644 index 00000000..edaa4692 --- /dev/null +++ b/alembic/versions/27de0f58e076_add_conversations_tables_and_run_.py @@ -0,0 +1,97 @@ +"""add conversations tables and run conversation_id + +Revision ID: 27de0f58e076 +Revises: ee2b43eea55e +Create Date: 2026-01-01 20:36:09.101274 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "27de0f58e076" +down_revision: Union[str, None] = "ee2b43eea55e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "conversations", + sa.Column("id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("summary", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_conversations_agent_id", "conversations", ["agent_id"], unique=False) + op.create_index("ix_conversations_org_agent", "conversations", ["organization_id", "agent_id"], unique=False) + op.create_table( + "conversation_messages", + 
sa.Column("id", sa.String(), nullable=False), + sa.Column("conversation_id", sa.String(), nullable=True), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("message_id", sa.String(), nullable=False), + sa.Column("position", sa.Integer(), nullable=False), + sa.Column("in_context", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["message_id"], ["messages.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"), + ) + op.create_index("ix_conv_msg_agent_conversation", "conversation_messages", ["agent_id", "conversation_id"], unique=False) + op.create_index("ix_conv_msg_agent_id", "conversation_messages", ["agent_id"], unique=False) + op.create_index("ix_conv_msg_conversation_position", "conversation_messages", ["conversation_id", "position"], unique=False) + op.create_index("ix_conv_msg_message_id", "conversation_messages", ["message_id"], unique=False) + op.add_column("messages", sa.Column("conversation_id", sa.String(), nullable=True)) + op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"], unique=False) + op.create_foreign_key("fk_messages_conversation_id", "messages", "conversations", ["conversation_id"], ["id"], ondelete="SET 
NULL") + op.add_column("runs", sa.Column("conversation_id", sa.String(), nullable=True)) + op.create_index("ix_runs_conversation_id", "runs", ["conversation_id"], unique=False) + op.create_foreign_key("fk_runs_conversation_id", "runs", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("fk_runs_conversation_id", "runs", type_="foreignkey") + op.drop_index("ix_runs_conversation_id", table_name="runs") + op.drop_column("runs", "conversation_id") + op.drop_constraint("fk_messages_conversation_id", "messages", type_="foreignkey") + op.drop_index(op.f("ix_messages_conversation_id"), table_name="messages") + op.drop_column("messages", "conversation_id") + op.drop_index("ix_conv_msg_message_id", table_name="conversation_messages") + op.drop_index("ix_conv_msg_conversation_position", table_name="conversation_messages") + op.drop_index("ix_conv_msg_agent_id", table_name="conversation_messages") + op.drop_index("ix_conv_msg_agent_conversation", table_name="conversation_messages") + op.drop_table("conversation_messages") + op.drop_index("ix_conversations_org_agent", table_name="conversations") + op.drop_index("ix_conversations_agent_id", table_name="conversations") + op.drop_table("conversations") + # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index 42779f26..60bdeec6 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -8289,6 +8289,444 @@ } } }, + "/v1/conversations/": { + "post": { + "tags": ["conversations"], + "summary": "Create Conversation", + "description": "Create a new conversation for an agent.", + "operationId": "create_conversation", + "parameters": [ + { + "name": "agent_id", + "in": "query", + "required": true, + "schema": { + "type": "string", + "description": "The agent ID to create a conversation for", + "title": "Agent Id" + }, + "description": "The agent ID to create a conversation for" + } + ], + "requestBody": { + 
"content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateConversation" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Conversation" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["conversations"], + "summary": "List Conversations", + "description": "List all conversations for an agent.", + "operationId": "list_conversations", + "parameters": [ + { + "name": "agent_id", + "in": "query", + "required": true, + "schema": { + "type": "string", + "description": "The agent ID to list conversations for", + "title": "Agent Id" + }, + "description": "The agent ID to list conversations for" + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "description": "Maximum number of conversations to return", + "default": 50, + "title": "Limit" + }, + "description": "Maximum number of conversations to return" + }, + { + "name": "after", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Cursor for pagination (conversation ID)", + "title": "After" + }, + "description": "Cursor for pagination (conversation ID)" + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Conversation" + }, + "title": "Response List Conversations" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/conversations/{conversation_id}": { + "get": { + "tags": 
["conversations"], + "summary": "Retrieve Conversation", + "description": "Retrieve a specific conversation.", + "operationId": "retrieve_conversation", + "parameters": [ + { + "name": "conversation_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 41, + "maxLength": 41, + "pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + "description": "The ID of the conv in the format 'conv-'", + "examples": ["conv-123e4567-e89b-42d3-8456-426614174000"], + "title": "Conversation Id" + }, + "description": "The ID of the conv in the format 'conv-'" + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Conversation" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/conversations/{conversation_id}/messages": { + "get": { + "tags": ["conversations"], + "summary": "List Conversation Messages", + "description": "List all messages in a conversation.\n\nReturns LettaMessage objects (UserMessage, AssistantMessage, etc.) 
for all\nmessages in the conversation, ordered by position (oldest first),\nwith support for cursor-based pagination.", + "operationId": "list_conversation_messages", + "parameters": [ + { + "name": "conversation_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 41, + "maxLength": 41, + "pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + "description": "The ID of the conv in the format 'conv-'", + "examples": ["conv-123e4567-e89b-42d3-8456-426614174000"], + "title": "Conversation Id" + }, + "description": "The ID of the conv in the format 'conv-'" + }, + { + "name": "before", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Message ID cursor for pagination. Returns messages that come before this message ID in the conversation", + "title": "Before" + }, + "description": "Message ID cursor for pagination. Returns messages that come before this message ID in the conversation" + }, + { + "name": "after", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Message ID cursor for pagination. Returns messages that come after this message ID in the conversation", + "title": "After" + }, + "description": "Message ID cursor for pagination. 
Returns messages that come after this message ID in the conversation" + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "description": "Maximum number of messages to return", + "default": 100, + "title": "Limit" + }, + "description": "Maximum number of messages to return" + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LettaMessageUnion" + }, + "title": "Response List Conversation Messages" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "post": { + "tags": ["conversations"], + "summary": "Send Conversation Message", + "description": "Send a message to a conversation and get a streaming response.\n\nThis endpoint sends a message to an existing conversation and streams\nthe agent's response back.", + "operationId": "send_conversation_message", + "parameters": [ + { + "name": "conversation_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 41, + "maxLength": 41, + "pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + "description": "The ID of the conv in the format 'conv-'", + "examples": ["conv-123e4567-e89b-42d3-8456-426614174000"], + "title": "Conversation Id" + }, + "description": "The ID of the conv in the format 'conv-'" + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LettaStreamingRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LettaStreamingResponse" + } + }, + 
"text/event-stream": { + "description": "Server-Sent Events stream" + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/conversations/{conversation_id}/stream": { + "post": { + "tags": ["conversations"], + "summary": "Retrieve Conversation Stream", + "description": "Resume the stream for the most recent active run in a conversation.\n\nThis endpoint allows you to reconnect to an active background stream\nfor a conversation, enabling recovery from network interruptions.", + "operationId": "retrieve_conversation_stream", + "parameters": [ + { + "name": "conversation_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 41, + "maxLength": 41, + "pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + "description": "The ID of the conv in the format 'conv-'", + "examples": ["conv-123e4567-e89b-42d3-8456-426614174000"], + "title": "Conversation Id" + }, + "description": "The ID of the conv in the format 'conv-'" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RetrieveStreamRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": {} + }, + "text/event-stream": { + "description": "Server-Sent Events stream", + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/SystemMessage" + }, + { + "$ref": "#/components/schemas/UserMessage" + }, + { + "$ref": "#/components/schemas/ReasoningMessage" + }, + { + "$ref": "#/components/schemas/HiddenReasoningMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + }, + { + "$ref": "#/components/schemas/ToolReturnMessage" + }, + { + "$ref": "#/components/schemas/AssistantMessage" + }, + { + "$ref": "#/components/schemas/ApprovalRequestMessage" + }, + { 
+ "$ref": "#/components/schemas/ApprovalResponseMessage" + }, + { + "$ref": "#/components/schemas/LettaPing" + }, + { + "$ref": "#/components/schemas/LettaErrorMessage" + }, + { + "$ref": "#/components/schemas/LettaStopReason" + }, + { + "$ref": "#/components/schemas/LettaUsageStatistics" + } + ] + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/chat/completions": { "post": { "tags": ["chat"], @@ -27246,6 +27684,95 @@ "title": "ContinueToolRule", "description": "Represents a tool rule configuration where if this tool gets called, it must continue the agent loop." }, + "Conversation": { + "properties": { + "created_by_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Created By Id", + "description": "The id of the user that made this object." + }, + "last_updated_by_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Last Updated By Id", + "description": "The id of the user that made this object." + }, + "created_at": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Created At", + "description": "The timestamp when the object was created." + }, + "updated_at": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Updated At", + "description": "The timestamp when the object was last updated." + }, + "id": { + "type": "string", + "title": "Id", + "description": "The unique identifier of the conversation." + }, + "agent_id": { + "type": "string", + "title": "Agent Id", + "description": "The ID of the agent this conversation belongs to." + }, + "summary": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Summary", + "description": "A summary of the conversation." 
+ }, + "in_context_message_ids": { + "items": { + "type": "string" + }, + "type": "array", + "title": "In Context Message Ids", + "description": "The IDs of in-context messages for the conversation." + } + }, + "additionalProperties": false, + "type": "object", + "required": ["id", "agent_id"], + "title": "Conversation", + "description": "Represents a conversation on an agent for concurrent messaging." + }, "CoreMemoryBlockSchema": { "properties": { "created_at": { @@ -28255,6 +28782,25 @@ "title": "CreateBlock", "description": "Create a block" }, + "CreateConversation": { + "properties": { + "summary": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Summary", + "description": "A summary of the conversation." + } + }, + "type": "object", + "title": "CreateConversation", + "description": "Request model for creating a new conversation." + }, "CreateMCPServerRequest": { "properties": { "server_name": { @@ -37076,6 +37622,18 @@ "title": "Agent Id", "description": "The unique identifier of the agent associated with the run." }, + "conversation_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Conversation Id", + "description": "The unique identifier of the conversation associated with the run." 
+ }, "base_template_id": { "anyOf": [ { diff --git a/letta/agents/base_agent_v2.py b/letta/agents/base_agent_v2.py index 847db90f..d66f6f04 100644 --- a/letta/agents/base_agent_v2.py +++ b/letta/agents/base_agent_v2.py @@ -66,6 +66,7 @@ class BaseAgentV2(ABC): use_assistant_message: bool = True, include_return_message_types: list[MessageType] | None = None, request_start_timestamp_ns: int | None = None, + conversation_id: str | None = None, client_tools: list["ClientToolSchema"] | None = None, ) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]: """ diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 2490feb7..b17f4ae9 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -6,6 +6,7 @@ from uuid import UUID, uuid4 from letta.errors import PendingApprovalError from letta.helpers import ToolRulesSolver +from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState @@ -131,16 +132,21 @@ async def _prepare_in_context_messages_no_persist_async( message_manager: MessageManager, actor: User, run_id: Optional[str] = None, + conversation_id: Optional[str] = None, ) -> Tuple[List[Message], List[Message]]: """ Prepares in-context messages for an agent, based on the current state and a new user input. + When conversation_id is provided, messages are loaded from the conversation_messages + table instead of agent_state.message_ids. + Args: input_messages (List[MessageCreate]): The new user input messages to process. agent_state (AgentState): The current state of the agent, including message buffer config. message_manager (MessageManager): The manager used to retrieve and create messages. actor (User): The user performing the action, used for access control and attribution. run_id (str): The run ID associated with this message processing. 
+ conversation_id (str): Optional conversation ID to load messages from. Returns: Tuple[List[Message], List[Message]]: A tuple containing: @@ -148,12 +154,74 @@ async def _prepare_in_context_messages_no_persist_async( - The new in-context messages (messages created from the new input). """ - if agent_state.message_buffer_autoclear: - # If autoclear is enabled, only include the most recent system message (usually at index 0) - current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)] + if conversation_id: + # Conversation mode: load messages from conversation_messages table + from letta.services.conversation_manager import ConversationManager + + conversation_manager = ConversationManager() + message_ids = await conversation_manager.get_message_ids_for_conversation( + conversation_id=conversation_id, + actor=actor, + ) + + if agent_state.message_buffer_autoclear and message_ids: + # If autoclear is enabled, only include the system message + current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=message_ids[0], actor=actor)] + elif message_ids: + # Otherwise, include the full list of messages from the conversation + current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor) + else: + # No messages in conversation yet - compile a new system message for this conversation + # Each conversation gets its own system message (captures memory state at conversation start) + from letta.prompts.prompt_generator import PromptGenerator + from letta.services.passage_manager import PassageManager + + num_messages = await message_manager.size_async(actor=actor, agent_id=agent_state.id) + passage_manager = PassageManager() + num_archival_memories = await passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id) + + system_message_str = await PromptGenerator.compile_system_message_async( + 
system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=get_utc_time(), + timezone=agent_state.timezone, + user_defined_variables=None, + append_icm_if_missing=True, + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, + sources=agent_state.sources, + max_files_open=agent_state.max_files_open, + ) + system_message = Message.dict_to_message( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + openai_message_dict={"role": "system", "content": system_message_str}, + ) + + # Persist the new system message + persisted_messages = await message_manager.create_many_messages_async([system_message], actor=actor) + system_message = persisted_messages[0] + + # Add it to the conversation tracking + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation_id, + agent_id=agent_state.id, + message_ids=[system_message.id], + actor=actor, + starting_position=0, + ) + + current_in_context_messages = [system_message] else: - # Otherwise, include the full list of messages by ID for context - current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) + # Default mode: load messages from agent_state.message_ids + if agent_state.message_buffer_autoclear: + # If autoclear is enabled, only include the most recent system message (usually at index 0) + current_in_context_messages = [ + await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor) + ] + else: + # Otherwise, include the full list of messages by ID for context + current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) # Check for approval-related message validation if input_messages[0].type == "approval": diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 46974cec..e70d1189 100644 --- 
a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -254,6 +254,7 @@ class LettaAgentV2(BaseAgentV2): use_assistant_message: bool = True, include_return_message_types: list[MessageType] | None = None, request_start_timestamp_ns: int | None = None, + conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility client_tools: list[ClientToolSchema] | None = None, ) -> AsyncGenerator[str, None]: """ diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 5fb6dcaa..e9c9c729 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -47,6 +47,7 @@ from letta.server.rest_api.utils import ( create_parallel_tool_messages_from_llm_response, create_tool_returns_for_denials, ) +from letta.services.conversation_manager import ConversationManager from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema from letta.services.summarizer.summarizer_all import summarize_all from letta.services.summarizer.summarizer_config import CompactionSettings @@ -80,6 +81,8 @@ class LettaAgentV3(LettaAgentV2): # affecting step-level telemetry. 
self.context_token_estimate: int | None = None self.in_context_messages: list[Message] = [] # in-memory tracker + # Conversation mode: when set, messages are tracked per-conversation + self.conversation_id: str | None = None # Client-side tools passed in the request (executed by client, not server) self.client_tools: list[ClientToolSchema] = [] @@ -104,6 +107,7 @@ class LettaAgentV3(LettaAgentV2): use_assistant_message: bool = True, # NOTE: not used include_return_message_types: list[MessageType] | None = None, request_start_timestamp_ns: int | None = None, + conversation_id: str | None = None, client_tools: list[ClientToolSchema] | None = None, ) -> LettaResponse: """ @@ -116,6 +120,7 @@ class LettaAgentV3(LettaAgentV2): use_assistant_message: Whether to use assistant message format include_return_message_types: Filter for which message types to return request_start_timestamp_ns: Start time for tracking request duration + conversation_id: Optional conversation ID for conversation-scoped messaging client_tools: Optional list of client-side tools. When called, execution pauses for client to provide tool returns. 
@@ -123,12 +128,19 @@ class LettaAgentV3(LettaAgentV2): LettaResponse: Complete response with all messages and metadata """ self._initialize_state() + self.conversation_id = conversation_id self.client_tools = client_tools or [] request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns) response_letta_messages = [] + # Prepare in-context messages (conversation mode if conversation_id provided) curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async( - input_messages, self.agent_state, self.message_manager, self.actor, run_id + input_messages, + self.agent_state, + self.message_manager, + self.actor, + run_id, + conversation_id=conversation_id, ) follow_up_messages = [] if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval": @@ -241,6 +253,7 @@ class LettaAgentV3(LettaAgentV2): use_assistant_message: bool = True, # NOTE: not used include_return_message_types: list[MessageType] | None = None, request_start_timestamp_ns: int | None = None, + conversation_id: str | None = None, client_tools: list[ClientToolSchema] | None = None, ) -> AsyncGenerator[str, None]: """ @@ -259,6 +272,7 @@ class LettaAgentV3(LettaAgentV2): use_assistant_message: Whether to use assistant message format include_return_message_types: Filter for which message types to return request_start_timestamp_ns: Start time for tracking request duration + conversation_id: Optional conversation ID for conversation-scoped messaging client_tools: Optional list of client-side tools. When called, execution pauses for client to provide tool returns. 
@@ -266,6 +280,7 @@ class LettaAgentV3(LettaAgentV2): str: JSON-formatted SSE data chunks for each completed step """ self._initialize_state() + self.conversation_id = conversation_id self.client_tools = client_tools or [] request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns) response_letta_messages = [] @@ -284,8 +299,14 @@ class LettaAgentV3(LettaAgentV2): ) try: + # Prepare in-context messages (conversation mode if conversation_id provided) in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async( - input_messages, self.agent_state, self.message_manager, self.actor, run_id + input_messages, + self.agent_state, + self.message_manager, + self.actor, + run_id, + conversation_id=conversation_id, ) follow_up_messages = [] if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval": @@ -435,7 +456,7 @@ class LettaAgentV3(LettaAgentV2): This handles: - Persisting the new messages into the `messages` table - Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`) - - Updating the DB with the current in-context messages (`self.agent_state.message_ids`) + - Updating the DB with the current in-context messages (`self.agent_state.message_ids`) OR conversation_messages table Args: run_id: The run ID to associate with the messages @@ -457,14 +478,33 @@ class LettaAgentV3(LettaAgentV2): template_id=self.agent_state.template_id, ) - # persist the in-context messages - # TODO: somehow make sure all the message ids are already persisted - await self.agent_manager.update_message_ids_async( - agent_id=self.agent_state.id, - message_ids=[m.id for m in in_context_messages], - actor=self.actor, - ) - self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state + if self.conversation_id: + # Conversation mode: update conversation_messages table + # Add new 
messages to conversation tracking + new_message_ids = [m.id for m in new_messages] + if new_message_ids: + await ConversationManager().add_messages_to_conversation( + conversation_id=self.conversation_id, + agent_id=self.agent_state.id, + message_ids=new_message_ids, + actor=self.actor, + ) + + # Update which messages are in context + await ConversationManager().update_in_context_messages( + conversation_id=self.conversation_id, + in_context_message_ids=[m.id for m in in_context_messages], + actor=self.actor, + ) + else: + # Default mode: update agent.message_ids + await self.agent_manager.update_message_ids_async( + agent_id=self.agent_state.id, + message_ids=[m.id for m in in_context_messages], + actor=self.actor, + ) + self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state + self.in_context_messages = in_context_messages # update in-memory state @trace_method diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index b9fae451..c44ae3bd 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -6,6 +6,8 @@ from letta.orm.base import Base from letta.orm.block import Block from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents +from letta.orm.conversation import Conversation +from letta.orm.conversation_messages import ConversationMessage from letta.orm.file import FileMetadata from letta.orm.files_agents import FileAgent from letta.orm.group import Group diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 1c086cb5..cea425ef 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -27,6 +27,7 @@ from letta.utils import bounded_gather, calculate_file_defaults_based_on_context if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags from letta.orm.archives_agents import ArchivesAgents + from letta.orm.conversation import Conversation from letta.orm.files_agents import FileAgent from letta.orm.identity import Identity from letta.orm.organization import 
Organization @@ -186,6 +187,13 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin lazy="noload", doc="Archives accessible by this agent.", ) + conversations: Mapped[List["Conversation"]] = relationship( + "Conversation", + back_populates="agent", + cascade="all, delete-orphan", + lazy="raise", + doc="Conversations for concurrent messaging on this agent.", + ) def _get_per_file_view_window_char_limit(self) -> int: """Get the per_file_view_window_char_limit, calculating defaults if None.""" diff --git a/letta/orm/conversation.py b/letta/orm/conversation.py new file mode 100644 index 00000000..e425256a --- /dev/null +++ b/letta/orm/conversation.py @@ -0,0 +1,49 @@ +import uuid +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import ForeignKey, Index, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.conversation import Conversation as PydanticConversation + +if TYPE_CHECKING: + from letta.orm.agent import Agent + from letta.orm.conversation_messages import ConversationMessage + + +class Conversation(SqlalchemyBase, OrganizationMixin): + """Conversations that can be created on an agent for concurrent messaging.""" + + __tablename__ = "conversations" + __pydantic_model__ = PydanticConversation + __table_args__ = ( + Index("ix_conversations_agent_id", "agent_id"), + Index("ix_conversations_org_agent", "organization_id", "agent_id"), + ) + + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}") + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False) + summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation") + + # Relationships + agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise") + 
message_associations: Mapped[List["ConversationMessage"]] = relationship( + "ConversationMessage", + back_populates="conversation", + cascade="all, delete-orphan", + lazy="selectin", + ) + + def to_pydantic(self) -> PydanticConversation: + """Converts the SQLAlchemy model to its Pydantic counterpart.""" + return self.__pydantic_model__( + id=self.id, + agent_id=self.agent_id, + summary=self.summary, + created_at=self.created_at, + updated_at=self.updated_at, + created_by_id=self.created_by_id, + last_updated_by_id=self.last_updated_by_id, + ) diff --git a/letta/orm/conversation_messages.py b/letta/orm/conversation_messages.py new file mode 100644 index 00000000..c92ea500 --- /dev/null +++ b/letta/orm/conversation_messages.py @@ -0,0 +1,73 @@ +import uuid +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import Boolean, ForeignKey, Index, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase + +if TYPE_CHECKING: + from letta.orm.conversation import Conversation + from letta.orm.message import Message + + +class ConversationMessage(SqlalchemyBase, OrganizationMixin): + """ + Track in-context messages for a conversation. + + This replaces the message_ids JSON list on agents with proper relational modeling. 
+ - conversation_id=NULL represents the "default" conversation (backward compatible) + - conversation_id= represents a named conversation for concurrent messaging + """ + + __tablename__ = "conversation_messages" + __table_args__ = ( + Index("ix_conv_msg_conversation_position", "conversation_id", "position"), + Index("ix_conv_msg_message_id", "message_id"), + Index("ix_conv_msg_agent_id", "agent_id"), + Index("ix_conv_msg_agent_conversation", "agent_id", "conversation_id"), + UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"), + ) + + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv_msg-{uuid.uuid4()}") + conversation_id: Mapped[Optional[str]] = mapped_column( + String, + ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=True, + doc="NULL for default conversation, otherwise FK to conversation", + ) + agent_id: Mapped[str] = mapped_column( + String, + ForeignKey("agents.id", ondelete="CASCADE"), + nullable=False, + doc="The agent this message association belongs to", + ) + message_id: Mapped[str] = mapped_column( + String, + ForeignKey("messages.id", ondelete="CASCADE"), + nullable=False, + doc="The message being tracked", + ) + position: Mapped[int] = mapped_column( + Integer, + nullable=False, + doc="Position in conversation (for ordering)", + ) + in_context: Mapped[bool] = mapped_column( + Boolean, + default=True, + nullable=False, + doc="Whether message is currently in the agent's context window", + ) + + # Relationships + conversation: Mapped[Optional["Conversation"]] = relationship( + "Conversation", + back_populates="message_associations", + lazy="raise", + ) + message: Mapped["Message"] = relationship( + "Message", + lazy="selectin", + ) diff --git a/letta/orm/message.py b/letta/orm/message.py index dd800f32..578c7120 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -55,6 +55,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): nullable=True, 
doc="The id of the LLMBatchItem that this message is associated with", ) + conversation_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("conversations.id", ondelete="SET NULL"), + nullable=True, + index=True, + doc="The conversation this message belongs to (NULL = default conversation)", + ) is_err: Mapped[Optional[bool]] = mapped_column( nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes." ) diff --git a/letta/orm/run.py b/letta/orm/run.py index 45d3d956..b2444e54 100644 --- a/letta/orm/run.py +++ b/letta/orm/run.py @@ -30,6 +30,7 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin): Index("ix_runs_created_at", "created_at", "id"), Index("ix_runs_agent_id", "agent_id"), Index("ix_runs_organization_id", "organization_id"), + Index("ix_runs_conversation_id", "conversation_id"), ) # Generate run ID with run- prefix @@ -50,6 +51,11 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin): # Agent relationship - A run belongs to one agent agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), nullable=False, doc="The agent that owns this run.") + # Conversation relationship - Optional, a run may be associated with a conversation + conversation_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True, doc="The conversation this run belongs to." 
+ ) + # Callback related columns callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after run completion.") callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.") diff --git a/letta/schemas/conversation.py b/letta/schemas/conversation.py new file mode 100644 index 00000000..0027938f --- /dev/null +++ b/letta/schemas/conversation.py @@ -0,0 +1,28 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field + +from letta.schemas.letta_base import OrmMetadataBase + + +class Conversation(OrmMetadataBase): + """Represents a conversation on an agent for concurrent messaging.""" + + __id_prefix__ = "conv" + + id: str = Field(..., description="The unique identifier of the conversation.") + agent_id: str = Field(..., description="The ID of the agent this conversation belongs to.") + summary: Optional[str] = Field(None, description="A summary of the conversation.") + in_context_message_ids: List[str] = Field(default_factory=list, description="The IDs of in-context messages for the conversation.") + + +class CreateConversation(BaseModel): + """Request model for creating a new conversation.""" + + summary: Optional[str] = Field(None, description="A summary of the conversation.") + + +class UpdateConversation(BaseModel): + """Request model for updating a conversation.""" + + summary: Optional[str] = Field(None, description="A summary of the conversation.") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index e0a697b9..ecd66e18 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -26,6 +26,7 @@ class PrimitiveType(str, Enum): SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix STEP = "step" IDENTITY = "identity" + CONVERSATION = "conv" # Infrastructure types MCP_SERVER = "mcp_server" diff --git a/letta/schemas/run.py b/letta/schemas/run.py index 53d43e46..d72e06ae 100644 --- 
a/letta/schemas/run.py +++ b/letta/schemas/run.py @@ -27,6 +27,9 @@ class Run(RunBase): # Agent relationship agent_id: str = Field(..., description="The unique identifier of the agent associated with the run.") + # Conversation relationship + conversation_id: Optional[str] = Field(None, description="The unique identifier of the conversation associated with the run.") + # Template fields base_template_id: Optional[str] = Field(None, description="The base template ID that the run belongs to.") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 919675f7..c75f715a 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -3,6 +3,7 @@ from letta.server.rest_api.routers.v1.anthropic import router as anthropic_route from letta.server.rest_api.routers.v1.archives import router as archives_router from letta.server.rest_api.routers.v1.blocks import router as blocks_router from letta.server.rest_api.routers.v1.chat_completions import router as chat_completions_router, router as openai_chat_completions_router +from letta.server.rest_api.routers.v1.conversations import router as conversations_router from letta.server.rest_api.routers.v1.embeddings import router as embeddings_router from letta.server.rest_api.routers.v1.folders import router as folders_router from letta.server.rest_api.routers.v1.groups import router as groups_router @@ -36,6 +37,7 @@ ROUTERS = [ sources_router, folders_router, agents_router, + conversations_router, chat_completions_router, groups_router, identities_router, diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py new file mode 100644 index 00000000..c489f398 --- /dev/null +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -0,0 +1,273 @@ +from datetime import timedelta +from typing import Annotated, List, Optional + +from fastapi import APIRouter, Body, Depends, 
HTTPException, Query +from pydantic import Field +from starlette.responses import StreamingResponse + +from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client +from letta.errors import LettaExpiredError, LettaInvalidArgumentError +from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.conversation import Conversation, CreateConversation +from letta.schemas.enums import RunStatus +from letta.schemas.letta_message import LettaMessageUnion +from letta.schemas.letta_request import LettaStreamingRequest, RetrieveStreamRequest +from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse +from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server +from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator +from letta.server.rest_api.streaming_response import ( + StreamingResponseWithStatusCode, + add_keepalive_to_stream, + cancellation_aware_stream_wrapper, +) +from letta.server.server import SyncServer +from letta.services.conversation_manager import ConversationManager +from letta.services.run_manager import RunManager +from letta.services.streaming_service import StreamingService +from letta.settings import settings +from letta.validators import ConversationId + +router = APIRouter(prefix="/conversations", tags=["conversations"]) + +# Instantiate manager +conversation_manager = ConversationManager() + + +@router.post("/", response_model=Conversation, operation_id="create_conversation") +async def create_conversation( + agent_id: str = Query(..., description="The agent ID to create a conversation for"), + conversation_create: CreateConversation = Body(default_factory=CreateConversation), + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +): + """Create a new conversation for an agent.""" + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await 
conversation_manager.create_conversation( + agent_id=agent_id, + conversation_create=conversation_create, + actor=actor, + ) + + +@router.get("/", response_model=List[Conversation], operation_id="list_conversations") +async def list_conversations( + agent_id: str = Query(..., description="The agent ID to list conversations for"), + limit: int = Query(50, description="Maximum number of conversations to return"), + after: Optional[str] = Query(None, description="Cursor for pagination (conversation ID)"), + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +): + """List all conversations for an agent.""" + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await conversation_manager.list_conversations( + agent_id=agent_id, + actor=actor, + limit=limit, + after=after, + ) + + +@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation") +async def retrieve_conversation( + conversation_id: ConversationId, + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +): + """Retrieve a specific conversation.""" + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await conversation_manager.get_conversation_by_id( + conversation_id=conversation_id, + actor=actor, + ) + + +ConversationMessagesResponse = Annotated[ + List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}}) +] + + +@router.get( + "/{conversation_id}/messages", + response_model=ConversationMessagesResponse, + operation_id="list_conversation_messages", +) +async def list_conversation_messages( + conversation_id: ConversationId, + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), + before: Optional[str] = Query( + None, description="Message ID cursor for pagination. 
Returns messages that come before this message ID in the conversation" + ), + after: Optional[str] = Query( + None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the conversation" + ), + limit: Optional[int] = Query(100, description="Maximum number of messages to return"), +): + """ + List all messages in a conversation. + + Returns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all + messages in the conversation, ordered by position (oldest first), + with support for cursor-based pagination. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await conversation_manager.list_conversation_messages( + conversation_id=conversation_id, + actor=actor, + limit=limit, + before=before, + after=after, + ) + + +@router.post( + "/{conversation_id}/messages", + response_model=LettaStreamingResponse, + operation_id="send_conversation_message", + responses={ + 200: { + "description": "Successful response", + "content": { + "text/event-stream": {"description": "Server-Sent Events stream"}, + }, + } + }, +) +async def send_conversation_message( + conversation_id: ConversationId, + request: LettaStreamingRequest = Body(...), + server: SyncServer = Depends(get_letta_server), + headers: HeaderParams = Depends(get_headers), +) -> StreamingResponse | LettaResponse: + """ + Send a message to a conversation and get a streaming response. + + This endpoint sends a message to an existing conversation and streams + the agent's response back. 
+ """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + + # Get the conversation to find the agent_id + conversation = await conversation_manager.get_conversation_by_id( + conversation_id=conversation_id, + actor=actor, + ) + + # Force streaming mode for this endpoint + request.streaming = True + + # Use streaming service + streaming_service = StreamingService(server) + run, result = await streaming_service.create_agent_stream( + agent_id=conversation.agent_id, + actor=actor, + request=request, + run_type="send_conversation_message", + conversation_id=conversation_id, + ) + + return result + + +@router.post( + "/{conversation_id}/stream", + response_model=None, + operation_id="retrieve_conversation_stream", + responses={ + 200: { + "description": "Successful response", + "content": { + "text/event-stream": { + "description": "Server-Sent Events stream", + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/SystemMessage"}, + {"$ref": "#/components/schemas/UserMessage"}, + {"$ref": "#/components/schemas/ReasoningMessage"}, + {"$ref": "#/components/schemas/HiddenReasoningMessage"}, + {"$ref": "#/components/schemas/ToolCallMessage"}, + {"$ref": "#/components/schemas/ToolReturnMessage"}, + {"$ref": "#/components/schemas/AssistantMessage"}, + {"$ref": "#/components/schemas/ApprovalRequestMessage"}, + {"$ref": "#/components/schemas/ApprovalResponseMessage"}, + {"$ref": "#/components/schemas/LettaPing"}, + {"$ref": "#/components/schemas/LettaErrorMessage"}, + {"$ref": "#/components/schemas/LettaStopReason"}, + {"$ref": "#/components/schemas/LettaUsageStatistics"}, + ] + }, + }, + }, + } + }, +) +async def retrieve_conversation_stream( + conversation_id: ConversationId, + request: RetrieveStreamRequest = Body(None), + headers: HeaderParams = Depends(get_headers), + server: SyncServer = Depends(get_letta_server), +): + """ + Resume the stream for the most recent active run in a conversation. 
+ + This endpoint allows you to reconnect to an active background stream + for a conversation, enabling recovery from network interruptions. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + runs_manager = RunManager() + + # Find the most recent active run for this conversation + active_runs = await runs_manager.list_runs( + actor=actor, + conversation_id=conversation_id, + statuses=[RunStatus.created, RunStatus.running], + limit=1, + ascending=False, + ) + + if not active_runs: + raise LettaInvalidArgumentError("No active runs found for this conversation.") + + run = active_runs[0] + + if not run.background: + raise LettaInvalidArgumentError("Run was not created in background mode, so it cannot be retrieved.") + + if run.created_at < get_utc_time() - timedelta(hours=3): + raise LettaExpiredError("Run was created more than 3 hours ago, and is now expired.") + + redis_client = await get_redis_client() + + if isinstance(redis_client, NoopAsyncRedisClient): + raise HTTPException( + status_code=503, + detail=( + "Background streaming requires Redis to be running. " + "Please ensure Redis is properly configured. 
" + f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}" + ), + ) + + stream = redis_sse_stream_generator( + redis_client=redis_client, + run_id=run.id, + starting_after=request.starting_after if request else None, + poll_interval=request.poll_interval if request else None, + batch_size=request.batch_size if request else None, + ) + + if settings.enable_cancellation_aware_streaming: + stream = cancellation_aware_stream_wrapper( + stream_generator=stream, + run_manager=server.run_manager, + run_id=run.id, + actor=actor, + ) + + if request and request.include_pings and settings.enable_keepalive: + stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id) + + return StreamingResponseWithStatusCode( + stream, + media_type="text/event-stream", + ) diff --git a/letta/services/conversation_manager.py b/letta/services/conversation_manager.py new file mode 100644 index 00000000..dade99d4 --- /dev/null +++ b/letta/services/conversation_manager.py @@ -0,0 +1,357 @@ +from typing import List, Optional + +from sqlalchemy import func, select + +from letta.orm.conversation import Conversation as ConversationModel +from letta.orm.conversation_messages import ConversationMessage as ConversationMessageModel +from letta.orm.errors import NoResultFound +from letta.orm.message import Message as MessageModel +from letta.otel.tracing import trace_method +from letta.schemas.conversation import Conversation as PydanticConversation, CreateConversation, UpdateConversation +from letta.schemas.letta_message import LettaMessage +from letta.schemas.message import Message as PydanticMessage +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.utils import enforce_types + + +class ConversationManager: + """Manager class to handle business logic related to Conversations.""" + + @enforce_types + @trace_method + async def create_conversation( + self, + agent_id: str, + 
conversation_create: CreateConversation, + actor: PydanticUser, + ) -> PydanticConversation: + """Create a new conversation for an agent.""" + async with db_registry.async_session() as session: + conversation = ConversationModel( + agent_id=agent_id, + summary=conversation_create.summary, + organization_id=actor.organization_id, + ) + await conversation.create_async(session, actor=actor) + return conversation.to_pydantic() + + @enforce_types + @trace_method + async def get_conversation_by_id( + self, + conversation_id: str, + actor: PydanticUser, + ) -> PydanticConversation: + """Retrieve a conversation by its ID, including in-context message IDs.""" + async with db_registry.async_session() as session: + conversation = await ConversationModel.read_async( + db_session=session, + identifier=conversation_id, + actor=actor, + check_is_deleted=True, + ) + + # Get the in-context message IDs for this conversation + message_ids = await self.get_message_ids_for_conversation( + conversation_id=conversation_id, + actor=actor, + ) + + # Build the pydantic model with in_context_message_ids + pydantic_conversation = conversation.to_pydantic() + pydantic_conversation.in_context_message_ids = message_ids + return pydantic_conversation + + @enforce_types + @trace_method + async def list_conversations( + self, + agent_id: str, + actor: PydanticUser, + limit: int = 50, + after: Optional[str] = None, + ) -> List[PydanticConversation]: + """List conversations for an agent with cursor-based pagination.""" + async with db_registry.async_session() as session: + conversations = await ConversationModel.list_async( + db_session=session, + actor=actor, + agent_id=agent_id, + limit=limit, + after=after, + ascending=False, + ) + return [conv.to_pydantic() for conv in conversations] + + @enforce_types + @trace_method + async def update_conversation( + self, + conversation_id: str, + conversation_update: UpdateConversation, + actor: PydanticUser, + ) -> PydanticConversation: + """Update a 
conversation.""" + async with db_registry.async_session() as session: + conversation = await ConversationModel.read_async( + db_session=session, + identifier=conversation_id, + actor=actor, + ) + + # Set attributes on the model + update_data = conversation_update.model_dump(exclude_none=True) + for key, value in update_data.items(): + setattr(conversation, key, value) + + # Commit the update + updated_conversation = await conversation.update_async( + db_session=session, + actor=actor, + ) + return updated_conversation.to_pydantic() + + @enforce_types + @trace_method + async def delete_conversation( + self, + conversation_id: str, + actor: PydanticUser, + ) -> None: + """Soft delete a conversation.""" + async with db_registry.async_session() as session: + conversation = await ConversationModel.read_async( + db_session=session, + identifier=conversation_id, + actor=actor, + ) + # Soft delete by setting is_deleted flag + conversation.is_deleted = True + await conversation.update_async(db_session=session, actor=actor) + + # ==================== Message Management Methods ==================== + + @enforce_types + @trace_method + async def get_message_ids_for_conversation( + self, + conversation_id: str, + actor: PydanticUser, + ) -> List[str]: + """ + Get ordered message IDs for a conversation. + + Returns message IDs ordered by position in the conversation. + Only returns messages that are currently in_context. 
+ """ + async with db_registry.async_session() as session: + query = ( + select(ConversationMessageModel.message_id) + .where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.organization_id == actor.organization_id, + ConversationMessageModel.in_context == True, + ConversationMessageModel.is_deleted == False, + ) + .order_by(ConversationMessageModel.position) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + @enforce_types + @trace_method + async def get_messages_for_conversation( + self, + conversation_id: str, + actor: PydanticUser, + ) -> List[PydanticMessage]: + """ + Get ordered Message objects for a conversation. + + Returns full Message objects ordered by position in the conversation. + Only returns messages that are currently in_context. + """ + async with db_registry.async_session() as session: + query = ( + select(MessageModel) + .join( + ConversationMessageModel, + MessageModel.id == ConversationMessageModel.message_id, + ) + .where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.organization_id == actor.organization_id, + ConversationMessageModel.in_context == True, + ConversationMessageModel.is_deleted == False, + ) + .order_by(ConversationMessageModel.position) + ) + result = await session.execute(query) + return [msg.to_pydantic() for msg in result.scalars().all()] + + @enforce_types + @trace_method + async def add_messages_to_conversation( + self, + conversation_id: str, + agent_id: str, + message_ids: List[str], + actor: PydanticUser, + starting_position: Optional[int] = None, + ) -> None: + """ + Add messages to a conversation's tracking table. + + Creates ConversationMessage entries with auto-incrementing positions. 
+ + Args: + conversation_id: The conversation to add messages to + agent_id: The agent ID + message_ids: List of message IDs to add + actor: The user performing the action + starting_position: Optional starting position (defaults to next available) + """ + if not message_ids: + return + + async with db_registry.async_session() as session: + # Get starting position if not provided + if starting_position is None: + query = select(func.coalesce(func.max(ConversationMessageModel.position), -1)).where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.organization_id == actor.organization_id, + ) + result = await session.execute(query) + max_position = result.scalar() + # Use explicit None check instead of `or` to handle position=0 correctly + if max_position is None: + max_position = -1 + starting_position = max_position + 1 + + # Create ConversationMessage entries + for i, message_id in enumerate(message_ids): + conv_msg = ConversationMessageModel( + conversation_id=conversation_id, + agent_id=agent_id, + message_id=message_id, + position=starting_position + i, + in_context=True, + organization_id=actor.organization_id, + ) + session.add(conv_msg) + + await session.commit() + + @enforce_types + @trace_method + async def update_in_context_messages( + self, + conversation_id: str, + in_context_message_ids: List[str], + actor: PydanticUser, + ) -> None: + """ + Update which messages are in context for a conversation. + + Sets in_context=True for messages in the list, False for others. 
+ + Args: + conversation_id: The conversation to update + in_context_message_ids: List of message IDs that should be in context + actor: The user performing the action + """ + async with db_registry.async_session() as session: + # Get all conversation messages for this conversation + query = select(ConversationMessageModel).where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.organization_id == actor.organization_id, + ConversationMessageModel.is_deleted == False, + ) + result = await session.execute(query) + conv_messages = result.scalars().all() + + # Update in_context status + in_context_set = set(in_context_message_ids) + for conv_msg in conv_messages: + conv_msg.in_context = conv_msg.message_id in in_context_set + + await session.commit() + + @enforce_types + @trace_method + async def list_conversation_messages( + self, + conversation_id: str, + actor: PydanticUser, + limit: Optional[int] = 100, + before: Optional[str] = None, + after: Optional[str] = None, + ) -> List[LettaMessage]: + """ + List all messages in a conversation with pagination support. + + Unlike get_messages_for_conversation, this returns ALL messages + (not just in_context) and supports cursor-based pagination. + Messages are always ordered by position (oldest first). 
+ + Args: + conversation_id: The conversation to list messages for + actor: The user performing the action + limit: Maximum number of messages to return + before: Return messages before this message ID + after: Return messages after this message ID + + Returns: + List of LettaMessage objects + """ + async with db_registry.async_session() as session: + # Build base query joining Message with ConversationMessage + query = ( + select(MessageModel) + .join( + ConversationMessageModel, + MessageModel.id == ConversationMessageModel.message_id, + ) + .where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.organization_id == actor.organization_id, + ConversationMessageModel.is_deleted == False, + ) + ) + + # Handle cursor-based pagination + if before: + # Get the position of the cursor message + cursor_query = select(ConversationMessageModel.position).where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.message_id == before, + ) + cursor_result = await session.execute(cursor_query) + cursor_position = cursor_result.scalar_one_or_none() + if cursor_position is not None: + query = query.where(ConversationMessageModel.position < cursor_position) + + if after: + # Get the position of the cursor message + cursor_query = select(ConversationMessageModel.position).where( + ConversationMessageModel.conversation_id == conversation_id, + ConversationMessageModel.message_id == after, + ) + cursor_result = await session.execute(cursor_query) + cursor_position = cursor_result.scalar_one_or_none() + if cursor_position is not None: + query = query.where(ConversationMessageModel.position > cursor_position) + + # Order by position (oldest first) + query = query.order_by(ConversationMessageModel.position.asc()) + + # Apply limit + if limit is not None: + query = query.limit(limit) + + result = await session.execute(query) + messages = [msg.to_pydantic() for msg in result.scalars().all()] + + # Convert to 
LettaMessages + return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True) diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index d76e4f2f..bea5ca3f 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -151,6 +151,7 @@ class RunManager: step_count_operator: ComparisonOperator = ComparisonOperator.EQ, tools_used: Optional[List[str]] = None, project_id: Optional[str] = None, + conversation_id: Optional[str] = None, order_by: Literal["created_at", "duration"] = "created_at", duration_percentile: Optional[int] = None, duration_filter: Optional[dict] = None, @@ -190,6 +191,10 @@ class RunManager: if background is not None: query = query.filter(RunModel.background == background) + # Filter by conversation_id + if conversation_id is not None: + query = query.filter(RunModel.conversation_id == conversation_id) + # Filter by template_family (base_template_id) if template_family: query = query.filter(RunModel.base_template_id == template_family) diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index f39f30a1..41593e3e 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -72,6 +72,7 @@ class StreamingService: actor: User, request: LettaStreamingRequest, run_type: str = "streaming", + conversation_id: Optional[str] = None, ) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]: """ Create a streaming response for an agent. 
@@ -81,6 +82,7 @@ class StreamingService: actor: The user making the request request: The LettaStreamingRequest containing all request parameters run_type: Type of run for tracking + conversation_id: Optional conversation ID for conversation-scoped messaging Returns: Tuple of (run object or None, streaming response) @@ -104,7 +106,7 @@ class StreamingService: run = None run_update_metadata = None if settings.track_agent_run: - run = await self._create_run(agent_id, request, run_type, actor) + run = await self._create_run(agent_id, request, run_type, actor, conversation_id=conversation_id) await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None) try: @@ -123,6 +125,7 @@ class StreamingService: request_start_timestamp_ns=request_start_timestamp_ns, include_return_message_types=request.include_return_message_types, actor=actor, + conversation_id=conversation_id, client_tools=request.client_tools, ) @@ -288,6 +291,7 @@ class StreamingService: request_start_timestamp_ns: int, include_return_message_types: Optional[list[MessageType]], actor: User, + conversation_id: Optional[str] = None, client_tools: Optional[list[ClientToolSchema]] = None, ) -> AsyncIterator: """ @@ -315,6 +319,7 @@ class StreamingService: use_assistant_message=use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, include_return_message_types=include_return_message_types, + conversation_id=conversation_id, client_tools=client_tools, ) @@ -477,11 +482,14 @@ class StreamingService: ] return base_compatible or google_letta_v1 - async def _create_run(self, agent_id: str, request: LettaStreamingRequest, run_type: str, actor: User) -> PydanticRun: + async def _create_run( + self, agent_id: str, request: LettaStreamingRequest, run_type: str, actor: User, conversation_id: Optional[str] = None + ) -> PydanticRun: """Create a run for tracking execution.""" run = await self.runs_manager.create_run( pydantic_run=PydanticRun( agent_id=agent_id, + 
conversation_id=conversation_id, background=request.background or False, metadata={ "run_type": run_type, diff --git a/letta/validators.py b/letta/validators.py index f4e489ea..4b2f16ee 100644 --- a/letta/validators.py +++ b/letta/validators.py @@ -62,6 +62,7 @@ ProviderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER.value]()] SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.value]()] StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()] IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()] +ConversationId = Annotated[str, PATH_VALIDATORS[PrimitiveType.CONVERSATION.value]()] # Infrastructure types McpServerId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_SERVER.value]()] diff --git a/tests/integration_test_conversations_sdk.py b/tests/integration_test_conversations_sdk.py new file mode 100644 index 00000000..c512880f --- /dev/null +++ b/tests/integration_test_conversations_sdk.py @@ -0,0 +1,271 @@ +""" +Integration tests for the Conversations API using the SDK. 
+""" + +import uuid + +import pytest +from letta_client import Letta + + +@pytest.fixture +def client(server_url: str) -> Letta: + """Create a Letta client.""" + return Letta(base_url=server_url) + + +@pytest.fixture +def agent(client: Letta): + """Create a test agent.""" + agent_state = client.agents.create( + name=f"test_conversations_{uuid.uuid4().hex[:8]}", + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + memory_blocks=[ + {"label": "human", "value": "Test user"}, + {"label": "persona", "value": "You are a helpful assistant."}, + ], + ) + yield agent_state + # Cleanup + client.agents.delete(agent_id=agent_state.id) + + +class TestConversationsSDK: + """Test conversations using the SDK client.""" + + def test_create_conversation(self, client: Letta, agent): + """Test creating a conversation for an agent.""" + conversation = client.conversations.create(agent_id=agent.id) + + assert conversation.id is not None + assert conversation.id.startswith("conv-") + assert conversation.agent_id == agent.id + + def test_list_conversations(self, client: Letta, agent): + """Test listing conversations for an agent.""" + # Create multiple conversations + conv1 = client.conversations.create(agent_id=agent.id) + conv2 = client.conversations.create(agent_id=agent.id) + + # List conversations + conversations = client.conversations.list(agent_id=agent.id) + + assert len(conversations) >= 2 + conv_ids = [c.id for c in conversations] + assert conv1.id in conv_ids + assert conv2.id in conv_ids + + def test_retrieve_conversation(self, client: Letta, agent): + """Test retrieving a specific conversation.""" + # Create a conversation + created = client.conversations.create(agent_id=agent.id) + + # Retrieve it (should have empty in_context_message_ids initially) + retrieved = client.conversations.retrieve(conversation_id=created.id) + + assert retrieved.id == created.id + assert retrieved.agent_id == created.agent_id + assert retrieved.in_context_message_ids == [] 
+
+        # Send a message to the conversation
+        list(
+            client.conversations.messages.create(
+                conversation_id=created.id,
+                messages=[{"role": "user", "content": "Hello!"}],
+            )
+        )
+
+        # Retrieve again and check in_context_message_ids is populated
+        retrieved_with_messages = client.conversations.retrieve(conversation_id=created.id)
+
+        # System message + user + assistant messages should be in the conversation
+        assert len(retrieved_with_messages.in_context_message_ids) >= 3  # system + user + assistant
+        # All IDs should be strings starting with "message-"
+        for msg_id in retrieved_with_messages.in_context_message_ids:
+            assert isinstance(msg_id, str)
+            assert msg_id.startswith("message-")
+
+        # Verify message ordering by listing messages
+        messages = client.conversations.messages.list(conversation_id=created.id)
+        assert len(messages) >= 3  # system + user + assistant
+        # First message should be system message (each conversation gets its own; see test_conversation_isolation)
+        assert messages[0].message_type == "system_message", f"First message should be system_message, got {messages[0].message_type}"
+        # Second message should be user message
+        assert messages[1].message_type == "user_message", f"Second message should be user_message, got {messages[1].message_type}"
+
+    def test_send_message_to_conversation(self, client: Letta, agent):
+        """Test sending a message to a conversation."""
+        # Create a conversation
+        conversation = client.conversations.create(agent_id=agent.id)
+
+        # Send a message (returns a stream)
+        stream = client.conversations.messages.create(
+            conversation_id=conversation.id,
+            messages=[{"role": "user", "content": "Hello, how are you?"}],
+        )
+
+        # Consume the stream to get messages
+        messages = list(stream)
+
+        # Check response contains messages
+        assert len(messages) > 0
+        # Should have at least an assistant message
+        message_types = [m.message_type for m in messages if hasattr(m, "message_type")]
+        assert "assistant_message" in message_types
+
+    def 
test_list_conversation_messages(self, client: Letta, agent): + """Test listing messages from a conversation.""" + # Create a conversation + conversation = client.conversations.create(agent_id=agent.id) + + # Send a message to create some history (consume the stream) + stream = client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "Say 'test response' back to me."}], + ) + list(stream) # Consume stream + + # List messages + messages = client.conversations.messages.list(conversation_id=conversation.id) + + assert len(messages) >= 2 # At least user + assistant + message_types = [m.message_type for m in messages] + assert "user_message" in message_types + assert "assistant_message" in message_types + + # Send another message and check that old and new messages are both listed + first_message_count = len(messages) + stream = client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "This is a follow-up message."}], + ) + list(stream) # Consume stream + + # List messages again + updated_messages = client.conversations.messages.list(conversation_id=conversation.id) + + # Should have more messages now (at least 2 more: user + assistant) + assert len(updated_messages) >= first_message_count + 2 + + def test_conversation_isolation(self, client: Letta, agent): + """Test that conversations are isolated from each other.""" + # Create two conversations + conv1 = client.conversations.create(agent_id=agent.id) + conv2 = client.conversations.create(agent_id=agent.id) + + # Send different messages to each (consume streams) + list( + client.conversations.messages.create( + conversation_id=conv1.id, + messages=[{"role": "user", "content": "Remember the word: APPLE"}], + ) + ) + list( + client.conversations.messages.create( + conversation_id=conv2.id, + messages=[{"role": "user", "content": "Remember the word: BANANA"}], + ) + ) + + # List messages from each conversation + 
conv1_messages = client.conversations.messages.list(conversation_id=conv1.id) + conv2_messages = client.conversations.messages.list(conversation_id=conv2.id) + + # Check messages are separate + conv1_content = " ".join([m.content for m in conv1_messages if hasattr(m, "content") and m.content]) + conv2_content = " ".join([m.content for m in conv2_messages if hasattr(m, "content") and m.content]) + + assert "APPLE" in conv1_content + assert "BANANA" in conv2_content + # Each conversation should only have its own word + assert "BANANA" not in conv1_content or "APPLE" not in conv2_content + + # Ask what word was remembered and make sure it's different for each conversation + conv1_recall = list( + client.conversations.messages.create( + conversation_id=conv1.id, + messages=[{"role": "user", "content": "What word did I ask you to remember? Reply with just the word."}], + ) + ) + conv2_recall = list( + client.conversations.messages.create( + conversation_id=conv2.id, + messages=[{"role": "user", "content": "What word did I ask you to remember? 
Reply with just the word."}], + ) + ) + + # Get the assistant responses + conv1_response = " ".join([m.content for m in conv1_recall if hasattr(m, "message_type") and m.message_type == "assistant_message"]) + conv2_response = " ".join([m.content for m in conv2_recall if hasattr(m, "message_type") and m.message_type == "assistant_message"]) + + assert "APPLE" in conv1_response.upper(), f"Conv1 should remember APPLE, got: {conv1_response}" + assert "BANANA" in conv2_response.upper(), f"Conv2 should remember BANANA, got: {conv2_response}" + + # Each conversation has its own system message (created on first message) + conv1_system_id = conv1_messages[0].id + conv2_system_id = conv2_messages[0].id + assert conv1_system_id != conv2_system_id, "System messages should have different IDs for different conversations" + + def test_conversation_messages_pagination(self, client: Letta, agent): + """Test pagination when listing conversation messages.""" + # Create a conversation + conversation = client.conversations.create(agent_id=agent.id) + + # Send multiple messages to create history (consume streams) + for i in range(3): + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": f"Message number {i}"}], + ) + ) + + # List with limit + messages = client.conversations.messages.list( + conversation_id=conversation.id, + limit=2, + ) + + # Should respect the limit + assert len(messages) <= 2 + + def test_retrieve_conversation_stream_no_active_run(self, client: Letta, agent): + """Test that retrieve_conversation_stream returns error when no active run exists.""" + from letta_client import BadRequestError + + # Create a conversation + conversation = client.conversations.create(agent_id=agent.id) + + # Try to retrieve stream when no run exists (should fail) + with pytest.raises(BadRequestError) as exc_info: + # Use the SDK's stream method + stream = client.conversations.messages.stream(conversation_id=conversation.id) 
+ list(stream) # Consume the stream to trigger the error + + # Should return 400 because no active run exists + assert "No active runs found" in str(exc_info.value) + + def test_retrieve_conversation_stream_after_completed_run(self, client: Letta, agent): + """Test that retrieve_conversation_stream returns error when run is completed.""" + from letta_client import BadRequestError + + # Create a conversation + conversation = client.conversations.create(agent_id=agent.id) + + # Send a message (this creates a run that completes) + list( + client.conversations.messages.create( + conversation_id=conversation.id, + messages=[{"role": "user", "content": "Hello"}], + ) + ) + + # Try to retrieve stream after the run has completed (should fail) + with pytest.raises(BadRequestError) as exc_info: + # Use the SDK's stream method + stream = client.conversations.messages.stream(conversation_id=conversation.id) + list(stream) # Consume the stream to trigger the error + + # Should return 400 because no active run exists (run is completed) + assert "No active runs found" in str(exc_info.value) diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index d5e47ece..249ecaf0 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -917,6 +917,92 @@ async def test_tool_call( assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed") +@pytest.mark.parametrize( + "model_config", + TESTED_MODEL_CONFIGS, + ids=[handle for handle, _ in TESTED_MODEL_CONFIGS], +) +@pytest.mark.asyncio(loop_scope="function") +async def test_conversation_streaming_raw_http( + disable_e2b_api_key: Any, + client: AsyncLetta, + server_url: str, + agent_state: AgentState, + model_config: Tuple[str, dict], +) -> None: + """ + Test conversation-based streaming functionality using raw HTTP requests. + + This test verifies that: + 1. A conversation can be created for an agent + 2. 
Messages can be sent to the conversation via streaming + 3. The streaming response contains the expected message types + 4. Messages are properly persisted in the conversation + + Uses raw HTTP requests instead of SDK until SDK is regenerated with conversations support. + """ + import httpx + + model_handle, model_settings = model_config + agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings) + + async with httpx.AsyncClient(base_url=server_url, timeout=60.0) as http_client: + # Create a conversation for the agent + create_response = await http_client.post( + "/v1/conversations/", + params={"agent_id": agent_state.id}, + json={}, + ) + assert create_response.status_code == 200, f"Failed to create conversation: {create_response.text}" + conversation = create_response.json() + assert conversation["id"] is not None + assert conversation["agent_id"] == agent_state.id + + # Send a message to the conversation using streaming + stream_response = await http_client.post( + f"/v1/conversations/{conversation['id']}/messages", + json={ + "messages": [{"role": "user", "content": f"Reply with the message '{USER_MESSAGE_RESPONSE}'."}], + "stream_tokens": True, + }, + ) + assert stream_response.status_code == 200, f"Failed to send message: {stream_response.text}" + + # Parse SSE response and accumulate messages + messages = await accumulate_chunks(stream_response.text) + print("MESSAGES:", messages) + + # Verify the response contains expected message types + assert_greeting_response(messages, model_handle, model_settings, streaming=True, token_streaming=True) + + # Verify the conversation can be retrieved + retrieve_response = await http_client.get(f"/v1/conversations/{conversation['id']}") + assert retrieve_response.status_code == 200, f"Failed to retrieve conversation: {retrieve_response.text}" + retrieved_conversation = retrieve_response.json() + assert retrieved_conversation["id"] == conversation["id"] + 
print("RETRIEVED CONVERSATION:", retrieved_conversation) + + # Verify conversations can be listed for the agent + list_response = await http_client.get("/v1/conversations/", params={"agent_id": agent_state.id}) + assert list_response.status_code == 200, f"Failed to list conversations: {list_response.text}" + conversations_list = list_response.json() + assert any(c["id"] == conversation["id"] for c in conversations_list) + + # Verify messages can be listed from the conversation + messages_response = await http_client.get(f"/v1/conversations/{conversation['id']}/messages") + assert messages_response.status_code == 200, f"Failed to list conversation messages: {messages_response.text}" + conversation_messages = messages_response.json() + print("CONVERSATION MESSAGES:", conversation_messages) + + # Verify we have at least the user message and assistant message + assert len(conversation_messages) >= 2, f"Expected at least 2 messages, got {len(conversation_messages)}" + + # Check message types are present + message_types = [msg.get("message_type") for msg in conversation_messages] + assert "user_message" in message_types, f"Expected user_message in {message_types}" + assert "assistant_message" in message_types, f"Expected assistant_message in {message_types}" + + @pytest.mark.parametrize( "model_handle,provider_type", [ diff --git a/tests/managers/test_conversation_manager.py b/tests/managers/test_conversation_manager.py new file mode 100644 index 00000000..4fc3d4e4 --- /dev/null +++ b/tests/managers/test_conversation_manager.py @@ -0,0 +1,506 @@ +""" +Tests for ConversationManager. 
+""" + +import pytest + +from letta.orm.errors import NoResultFound +from letta.schemas.conversation import CreateConversation, UpdateConversation +from letta.server.server import SyncServer +from letta.services.conversation_manager import ConversationManager + +# ====================================================================================================================== +# ConversationManager Tests +# ====================================================================================================================== + + +@pytest.fixture +def conversation_manager(): + """Create a ConversationManager instance.""" + return ConversationManager() + + +@pytest.mark.asyncio +async def test_create_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test creating a conversation.""" + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test conversation"), + actor=default_user, + ) + + assert conversation.id is not None + assert conversation.agent_id == sarah_agent.id + assert conversation.summary == "Test conversation" + assert conversation.id.startswith("conv-") + + +@pytest.mark.asyncio +async def test_create_conversation_no_summary(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test creating a conversation without summary.""" + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(), + actor=default_user, + ) + + assert conversation.id is not None + assert conversation.agent_id == sarah_agent.id + assert conversation.summary is None + + +@pytest.mark.asyncio +async def test_get_conversation_by_id(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test retrieving a conversation by ID.""" + # Create a conversation + created = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + 
conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Retrieve it + retrieved = await conversation_manager.get_conversation_by_id( + conversation_id=created.id, + actor=default_user, + ) + + assert retrieved.id == created.id + assert retrieved.agent_id == created.agent_id + assert retrieved.summary == created.summary + + +@pytest.mark.asyncio +async def test_get_conversation_not_found(conversation_manager, server: SyncServer, default_user): + """Test retrieving a non-existent conversation raises error.""" + with pytest.raises(NoResultFound): + await conversation_manager.get_conversation_by_id( + conversation_id="conv-nonexistent", + actor=default_user, + ) + + +@pytest.mark.asyncio +async def test_list_conversations(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test listing conversations for an agent.""" + # Create multiple conversations + for i in range(3): + await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary=f"Conversation {i}"), + actor=default_user, + ) + + # List them + conversations = await conversation_manager.list_conversations( + agent_id=sarah_agent.id, + actor=default_user, + ) + + assert len(conversations) == 3 + + +@pytest.mark.asyncio +async def test_list_conversations_with_limit(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test listing conversations with a limit.""" + # Create multiple conversations + for i in range(5): + await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary=f"Conversation {i}"), + actor=default_user, + ) + + # List with limit + conversations = await conversation_manager.list_conversations( + agent_id=sarah_agent.id, + actor=default_user, + limit=2, + ) + + assert len(conversations) == 2 + + +@pytest.mark.asyncio +async def test_update_conversation(conversation_manager, server: SyncServer, sarah_agent, 
default_user): + """Test updating a conversation.""" + # Create a conversation + created = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Original"), + actor=default_user, + ) + + # Update it + updated = await conversation_manager.update_conversation( + conversation_id=created.id, + conversation_update=UpdateConversation(summary="Updated summary"), + actor=default_user, + ) + + assert updated.id == created.id + assert updated.summary == "Updated summary" + + +@pytest.mark.asyncio +async def test_delete_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test soft deleting a conversation.""" + # Create a conversation + created = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="To delete"), + actor=default_user, + ) + + # Delete it + await conversation_manager.delete_conversation( + conversation_id=created.id, + actor=default_user, + ) + + # Verify it's no longer accessible + with pytest.raises(NoResultFound): + await conversation_manager.get_conversation_by_id( + conversation_id=created.id, + actor=default_user, + ) + + +@pytest.mark.asyncio +async def test_conversation_isolation_by_agent(conversation_manager, server: SyncServer, sarah_agent, charles_agent, default_user): + """Test that conversations are isolated by agent.""" + # Create conversation for sarah_agent + await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Sarah's conversation"), + actor=default_user, + ) + + # Create conversation for charles_agent + await conversation_manager.create_conversation( + agent_id=charles_agent.id, + conversation_create=CreateConversation(summary="Charles's conversation"), + actor=default_user, + ) + + # List for sarah_agent + sarah_convos = await conversation_manager.list_conversations( + agent_id=sarah_agent.id, + 
actor=default_user, + ) + assert len(sarah_convos) == 1 + assert sarah_convos[0].summary == "Sarah's conversation" + + # List for charles_agent + charles_convos = await conversation_manager.list_conversations( + agent_id=charles_agent.id, + actor=default_user, + ) + assert len(charles_convos) == 1 + assert charles_convos[0].summary == "Charles's conversation" + + +@pytest.mark.asyncio +async def test_conversation_isolation_by_organization( + conversation_manager, server: SyncServer, sarah_agent, default_user, other_user_different_org +): + """Test that conversations are isolated by organization.""" + # Create conversation + created = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Other org user should not be able to access it + with pytest.raises(NoResultFound): + await conversation_manager.get_conversation_by_id( + conversation_id=created.id, + actor=other_user_different_org, + ) + + +# ====================================================================================================================== +# Conversation Message Management Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_add_messages_to_conversation( + conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture +): + """Test adding messages to a conversation.""" + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Add the message to the conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[hello_world_message_fixture.id], + actor=default_user, + ) + + # Verify message is in 
conversation + message_ids = await conversation_manager.get_message_ids_for_conversation( + conversation_id=conversation.id, + actor=default_user, + ) + + assert len(message_ids) == 1 + assert message_ids[0] == hello_world_message_fixture.id + + +@pytest.mark.asyncio +async def test_get_messages_for_conversation( + conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture +): + """Test getting full message objects from a conversation.""" + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Add the message + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[hello_world_message_fixture.id], + actor=default_user, + ) + + # Get full messages + messages = await conversation_manager.get_messages_for_conversation( + conversation_id=conversation.id, + actor=default_user, + ) + + assert len(messages) == 1 + assert messages[0].id == hello_world_message_fixture.id + + +@pytest.mark.asyncio +async def test_message_ordering_in_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test that messages maintain their order in a conversation.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create multiple messages + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text=f"Message {i}")], + ) + for i in range(3) + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # 
Add messages in order + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # Verify order is maintained + retrieved_ids = await conversation_manager.get_message_ids_for_conversation( + conversation_id=conversation.id, + actor=default_user, + ) + + assert retrieved_ids == [m.id for m in messages] + + +@pytest.mark.asyncio +async def test_update_in_context_messages(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test updating which messages are in context.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text=f"Message {i}")], + ) + for i in range(3) + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # Add all messages + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # Update to only keep first and last in context + await conversation_manager.update_in_context_messages( + conversation_id=conversation.id, + in_context_message_ids=[messages[0].id, messages[2].id], + actor=default_user, + ) + + # Verify only the selected messages are in context + in_context_ids = await conversation_manager.get_message_ids_for_conversation( + conversation_id=conversation.id, + actor=default_user, + ) + + assert len(in_context_ids) == 2 + assert messages[0].id in in_context_ids + assert messages[2].id in 
in_context_ids + assert messages[1].id not in in_context_ids + + +@pytest.mark.asyncio +async def test_empty_conversation_message_ids(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test getting message IDs from an empty conversation.""" + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Empty"), + actor=default_user, + ) + + # Get message IDs (should be empty) + message_ids = await conversation_manager.get_message_ids_for_conversation( + conversation_id=conversation.id, + actor=default_user, + ) + + assert message_ids == [] + + +@pytest.mark.asyncio +async def test_list_conversation_messages(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test listing messages from a conversation as LettaMessages.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create messages with different roles + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text="Hello!")], + ), + PydanticMessage( + agent_id=sarah_agent.id, + role="assistant", + content=[TextContent(text="Hi there!")], + ), + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # Add messages to conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # List conversation messages (returns LettaMessages) + letta_messages = await conversation_manager.list_conversation_messages( + 
conversation_id=conversation.id, + actor=default_user, + ) + + assert len(letta_messages) == 2 + # Check message types + message_types = [m.message_type for m in letta_messages] + assert "user_message" in message_types + assert "assistant_message" in message_types + + +@pytest.mark.asyncio +async def test_list_conversation_messages_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user): + """Test pagination when listing conversation messages.""" + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + # Create a conversation + conversation = await conversation_manager.create_conversation( + agent_id=sarah_agent.id, + conversation_create=CreateConversation(summary="Test"), + actor=default_user, + ) + + # Create multiple messages + pydantic_messages = [ + PydanticMessage( + agent_id=sarah_agent.id, + role="user", + content=[TextContent(text=f"Message {i}")], + ) + for i in range(5) + ] + messages = await server.message_manager.create_many_messages_async( + pydantic_messages, + actor=default_user, + ) + + # Add messages to conversation + await conversation_manager.add_messages_to_conversation( + conversation_id=conversation.id, + agent_id=sarah_agent.id, + message_ids=[m.id for m in messages], + actor=default_user, + ) + + # List with limit + letta_messages = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + limit=2, + ) + assert len(letta_messages) == 2 + + # List with after cursor (get messages after the first one) + letta_messages_after = await conversation_manager.list_conversation_messages( + conversation_id=conversation.id, + actor=default_user, + after=messages[0].id, + ) + assert len(letta_messages_after) == 4 # Should get messages 1-4