feat: add conversation and conversation_messages tables for concurrent messaging (#8182)
This commit is contained in:
committed by
Caren Thomas
parent
c66b852978
commit
87d920782f
@@ -90,6 +90,36 @@ Workflow for Postgres-targeted migration:
|
||||
- `uv run alembic upgrade head`
|
||||
- `uv run alembic revision --autogenerate -m "..."`
|
||||
|
||||
### 5. Resetting local Postgres for clean migration generation
|
||||
|
||||
If your local Postgres database has drifted from main (e.g., applied migrations
|
||||
that no longer exist, or has stale schema), you can reset it to generate a clean
|
||||
migration.
|
||||
|
||||
From the repo root of your checkout (e.g. `~/repos/letta-cloud`):
|
||||
|
||||
```bash
|
||||
# 1. Remove postgres data directory
|
||||
rm -rf ./data/postgres
|
||||
|
||||
# 2. Stop the running postgres container
|
||||
docker stop $(docker ps -q --filter ancestor=ankane/pgvector)
|
||||
|
||||
# 3. Restart services (creates fresh postgres)
|
||||
just start-services
|
||||
|
||||
# 4. Wait a moment for postgres to be ready, then apply all migrations
|
||||
cd apps/core
|
||||
export LETTA_PG_URI=postgresql+pg8000://postgres:postgres@localhost:5432/letta-core
|
||||
uv run alembic upgrade head
|
||||
|
||||
# 5. Now generate your new migration
|
||||
uv run alembic revision --autogenerate -m "your migration message"
|
||||
```
|
||||
|
||||
This ensures the migration is generated against a clean database state matching
|
||||
main, avoiding spurious diffs from local-only schema changes.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **"Target database is not up to date" when autogenerating**
|
||||
@@ -101,7 +131,7 @@ Workflow for Postgres-targeted migration:
|
||||
changed model is imported in Alembic env context.
|
||||
- **Autogenerated migration has unexpected drops/renames**
|
||||
- Review model changes; consider explicit operations instead of relying on
    autogenerate. Reset local Postgres (see workflow 5) to get a clean baseline.
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""add conversations tables and run conversation_id
|
||||
|
||||
Revision ID: 27de0f58e076
|
||||
Revises: ee2b43eea55e
|
||||
Create Date: 2026-01-01 20:36:09.101274
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "27de0f58e076"
|
||||
down_revision: Union[str, None] = "ee2b43eea55e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"conversations",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("summary", sa.String(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_conversations_agent_id", "conversations", ["agent_id"], unique=False)
|
||||
op.create_index("ix_conversations_org_agent", "conversations", ["organization_id", "agent_id"], unique=False)
|
||||
op.create_table(
|
||||
"conversation_messages",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("conversation_id", sa.String(), nullable=True),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("message_id", sa.String(), nullable=False),
|
||||
sa.Column("position", sa.Integer(), nullable=False),
|
||||
sa.Column("in_context", sa.Boolean(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["message_id"], ["messages.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"),
|
||||
)
|
||||
op.create_index("ix_conv_msg_agent_conversation", "conversation_messages", ["agent_id", "conversation_id"], unique=False)
|
||||
op.create_index("ix_conv_msg_agent_id", "conversation_messages", ["agent_id"], unique=False)
|
||||
op.create_index("ix_conv_msg_conversation_position", "conversation_messages", ["conversation_id", "position"], unique=False)
|
||||
op.create_index("ix_conv_msg_message_id", "conversation_messages", ["message_id"], unique=False)
|
||||
op.add_column("messages", sa.Column("conversation_id", sa.String(), nullable=True))
|
||||
op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"], unique=False)
|
||||
op.create_foreign_key(None, "messages", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL")
|
||||
op.add_column("runs", sa.Column("conversation_id", sa.String(), nullable=True))
|
||||
op.create_index("ix_runs_conversation_id", "runs", ["conversation_id"], unique=False)
|
||||
op.create_foreign_key(None, "runs", "conversations", ["conversation_id"], ["id"], ondelete="SET NULL")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, "runs", type_="foreignkey")
|
||||
op.drop_index("ix_runs_conversation_id", table_name="runs")
|
||||
op.drop_column("runs", "conversation_id")
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.drop_index(op.f("ix_messages_conversation_id"), table_name="messages")
|
||||
op.drop_column("messages", "conversation_id")
|
||||
op.drop_index("ix_conv_msg_message_id", table_name="conversation_messages")
|
||||
op.drop_index("ix_conv_msg_conversation_position", table_name="conversation_messages")
|
||||
op.drop_index("ix_conv_msg_agent_id", table_name="conversation_messages")
|
||||
op.drop_index("ix_conv_msg_agent_conversation", table_name="conversation_messages")
|
||||
op.drop_table("conversation_messages")
|
||||
op.drop_index("ix_conversations_org_agent", table_name="conversations")
|
||||
op.drop_index("ix_conversations_agent_id", table_name="conversations")
|
||||
op.drop_table("conversations")
|
||||
# ### end Alembic commands ###
|
||||
@@ -8289,6 +8289,444 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/conversations/": {
|
||||
"post": {
|
||||
"tags": ["conversations"],
|
||||
"summary": "Create Conversation",
|
||||
"description": "Create a new conversation for an agent.",
|
||||
"operationId": "create_conversation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "query",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "The agent ID to create a conversation for",
|
||||
"title": "Agent Id"
|
||||
},
|
||||
"description": "The agent ID to create a conversation for"
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreateConversation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Conversation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"get": {
|
||||
"tags": ["conversations"],
|
||||
"summary": "List Conversations",
|
||||
"description": "List all conversations for an agent.",
|
||||
"operationId": "list_conversations",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "query",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "The agent ID to list conversations for",
|
||||
"title": "Agent Id"
|
||||
},
|
||||
"description": "The agent ID to list conversations for"
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of conversations to return",
|
||||
"default": 50,
|
||||
"title": "Limit"
|
||||
},
|
||||
"description": "Maximum number of conversations to return"
|
||||
},
|
||||
{
|
||||
"name": "after",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Cursor for pagination (conversation ID)",
|
||||
"title": "After"
|
||||
},
|
||||
"description": "Cursor for pagination (conversation ID)"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Conversation"
|
||||
},
|
||||
"title": "Response List Conversations"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/conversations/{conversation_id}": {
|
||||
"get": {
|
||||
"tags": ["conversations"],
|
||||
"summary": "Retrieve Conversation",
|
||||
"description": "Retrieve a specific conversation.",
|
||||
"operationId": "retrieve_conversation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "conversation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"minLength": 41,
|
||||
"maxLength": 41,
|
||||
"pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'",
|
||||
"examples": ["conv-123e4567-e89b-42d3-8456-426614174000"],
|
||||
"title": "Conversation Id"
|
||||
},
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Conversation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/conversations/{conversation_id}/messages": {
|
||||
"get": {
|
||||
"tags": ["conversations"],
|
||||
"summary": "List Conversation Messages",
|
||||
"description": "List all messages in a conversation.\n\nReturns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all\nmessages in the conversation, ordered by position (oldest first),\nwith support for cursor-based pagination.",
|
||||
"operationId": "list_conversation_messages",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "conversation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"minLength": 41,
|
||||
"maxLength": 41,
|
||||
"pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'",
|
||||
"examples": ["conv-123e4567-e89b-42d3-8456-426614174000"],
|
||||
"title": "Conversation Id"
|
||||
},
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'"
|
||||
},
|
||||
{
|
||||
"name": "before",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Message ID cursor for pagination. Returns messages that come before this message ID in the conversation",
|
||||
"title": "Before"
|
||||
},
|
||||
"description": "Message ID cursor for pagination. Returns messages that come before this message ID in the conversation"
|
||||
},
|
||||
{
|
||||
"name": "after",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Message ID cursor for pagination. Returns messages that come after this message ID in the conversation",
|
||||
"title": "After"
|
||||
},
|
||||
"description": "Message ID cursor for pagination. Returns messages that come after this message ID in the conversation"
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Maximum number of messages to return",
|
||||
"default": 100,
|
||||
"title": "Limit"
|
||||
},
|
||||
"description": "Maximum number of messages to return"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaMessageUnion"
|
||||
},
|
||||
"title": "Response List Conversation Messages"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"tags": ["conversations"],
|
||||
"summary": "Send Conversation Message",
|
||||
"description": "Send a message to a conversation and get a streaming response.\n\nThis endpoint sends a message to an existing conversation and streams\nthe agent's response back.",
|
||||
"operationId": "send_conversation_message",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "conversation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"minLength": 41,
|
||||
"maxLength": 41,
|
||||
"pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'",
|
||||
"examples": ["conv-123e4567-e89b-42d3-8456-426614174000"],
|
||||
"title": "Conversation Id"
|
||||
},
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'"
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/LettaStreamingRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/LettaStreamingResponse"
|
||||
}
|
||||
},
|
||||
"text/event-stream": {
|
||||
"description": "Server-Sent Events stream"
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/conversations/{conversation_id}/stream": {
|
||||
"post": {
|
||||
"tags": ["conversations"],
|
||||
"summary": "Retrieve Conversation Stream",
|
||||
"description": "Resume the stream for the most recent active run in a conversation.\n\nThis endpoint allows you to reconnect to an active background stream\nfor a conversation, enabling recovery from network interruptions.",
|
||||
"operationId": "retrieve_conversation_stream",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "conversation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"minLength": 41,
|
||||
"maxLength": 41,
|
||||
"pattern": "^conv-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'",
|
||||
"examples": ["conv-123e4567-e89b-42d3-8456-426614174000"],
|
||||
"title": "Conversation Id"
|
||||
},
|
||||
"description": "The ID of the conv in the format 'conv-<uuid4>'"
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RetrieveStreamRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {}
|
||||
},
|
||||
"text/event-stream": {
|
||||
"description": "Server-Sent Events stream",
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/SystemMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/UserMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ReasoningMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/HiddenReasoningMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolCallMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolReturnMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/AssistantMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ApprovalRequestMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ApprovalResponseMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/LettaPing"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/LettaErrorMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/LettaStopReason"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/LettaUsageStatistics"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/chat/completions": {
|
||||
"post": {
|
||||
"tags": ["chat"],
|
||||
@@ -27246,6 +27684,95 @@
|
||||
"title": "ContinueToolRule",
|
||||
"description": "Represents a tool rule configuration where if this tool gets called, it must continue the agent loop."
|
||||
},
|
||||
"Conversation": {
|
||||
"properties": {
|
||||
"created_by_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Created By Id",
|
||||
"description": "The id of the user that made this object."
|
||||
},
|
||||
"last_updated_by_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Last Updated By Id",
|
||||
"description": "The id of the user that made this object."
|
||||
},
|
||||
"created_at": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Created At",
|
||||
"description": "The timestamp when the object was created."
|
||||
},
|
||||
"updated_at": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Updated At",
|
||||
"description": "The timestamp when the object was last updated."
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"title": "Id",
|
||||
"description": "The unique identifier of the conversation."
|
||||
},
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"title": "Agent Id",
|
||||
"description": "The ID of the agent this conversation belongs to."
|
||||
},
|
||||
"summary": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Summary",
|
||||
"description": "A summary of the conversation."
|
||||
},
|
||||
"in_context_message_ids": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "In Context Message Ids",
|
||||
"description": "The IDs of in-context messages for the conversation."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["id", "agent_id"],
|
||||
"title": "Conversation",
|
||||
"description": "Represents a conversation on an agent for concurrent messaging."
|
||||
},
|
||||
"CoreMemoryBlockSchema": {
|
||||
"properties": {
|
||||
"created_at": {
|
||||
@@ -28255,6 +28782,25 @@
|
||||
"title": "CreateBlock",
|
||||
"description": "Create a block"
|
||||
},
|
||||
"CreateConversation": {
|
||||
"properties": {
|
||||
"summary": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Summary",
|
||||
"description": "A summary of the conversation."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "CreateConversation",
|
||||
"description": "Request model for creating a new conversation."
|
||||
},
|
||||
"CreateMCPServerRequest": {
|
||||
"properties": {
|
||||
"server_name": {
|
||||
@@ -37076,6 +37622,18 @@
|
||||
"title": "Agent Id",
|
||||
"description": "The unique identifier of the agent associated with the run."
|
||||
},
|
||||
"conversation_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Conversation Id",
|
||||
"description": "The unique identifier of the conversation associated with the run."
|
||||
},
|
||||
"base_template_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -66,6 +66,7 @@ class BaseAgentV2(ABC):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
from letta.errors import PendingApprovalError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -131,16 +132,21 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
message_manager: MessageManager,
|
||||
actor: User,
|
||||
run_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], List[Message]]:
|
||||
"""
|
||||
Prepares in-context messages for an agent, based on the current state and a new user input.
|
||||
|
||||
When conversation_id is provided, messages are loaded from the conversation_messages
|
||||
table instead of agent_state.message_ids.
|
||||
|
||||
Args:
|
||||
input_messages (List[MessageCreate]): The new user input messages to process.
|
||||
agent_state (AgentState): The current state of the agent, including message buffer config.
|
||||
message_manager (MessageManager): The manager used to retrieve and create messages.
|
||||
actor (User): The user performing the action, used for access control and attribution.
|
||||
run_id (str): The run ID associated with this message processing.
|
||||
conversation_id (str): Optional conversation ID to load messages from.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Message], List[Message]]: A tuple containing:
|
||||
@@ -148,12 +154,74 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
- The new in-context messages (messages created from the new input).
|
||||
"""
|
||||
|
||||
if agent_state.message_buffer_autoclear:
|
||||
# If autoclear is enabled, only include the most recent system message (usually at index 0)
|
||||
current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)]
|
||||
if conversation_id:
|
||||
# Conversation mode: load messages from conversation_messages table
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
|
||||
conversation_manager = ConversationManager()
|
||||
message_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if agent_state.message_buffer_autoclear and message_ids:
|
||||
# If autoclear is enabled, only include the system message
|
||||
current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=message_ids[0], actor=actor)]
|
||||
elif message_ids:
|
||||
# Otherwise, include the full list of messages from the conversation
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
|
||||
else:
|
||||
# No messages in conversation yet - compile a new system message for this conversation
|
||||
# Each conversation gets its own system message (captures memory state at conversation start)
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.services.passage_manager import PassageManager
|
||||
|
||||
num_messages = await message_manager.size_async(actor=actor, agent_id=agent_state.id)
|
||||
passage_manager = PassageManager()
|
||||
num_archival_memories = await passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id)
|
||||
|
||||
system_message_str = await PromptGenerator.compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=get_utc_time(),
|
||||
timezone=agent_state.timezone,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
system_message = Message.dict_to_message(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": system_message_str},
|
||||
)
|
||||
|
||||
# Persist the new system message
|
||||
persisted_messages = await message_manager.create_many_messages_async([system_message], actor=actor)
|
||||
system_message = persisted_messages[0]
|
||||
|
||||
# Add it to the conversation tracking
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation_id,
|
||||
agent_id=agent_state.id,
|
||||
message_ids=[system_message.id],
|
||||
actor=actor,
|
||||
starting_position=0,
|
||||
)
|
||||
|
||||
current_in_context_messages = [system_message]
|
||||
else:
|
||||
# Otherwise, include the full list of messages by ID for context
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
# Default mode: load messages from agent_state.message_ids
|
||||
if agent_state.message_buffer_autoclear:
|
||||
# If autoclear is enabled, only include the most recent system message (usually at index 0)
|
||||
current_in_context_messages = [
|
||||
await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
]
|
||||
else:
|
||||
# Otherwise, include the full list of messages by ID for context
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
|
||||
# Check for approval-related message validation
|
||||
if input_messages[0].type == "approval":
|
||||
|
||||
@@ -254,6 +254,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
|
||||
@@ -47,6 +47,7 @@ from letta.server.rest_api.utils import (
|
||||
create_parallel_tool_messages_from_llm_response,
|
||||
create_tool_returns_for_denials,
|
||||
)
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.summarizer.summarizer_all import summarize_all
|
||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||
@@ -80,6 +81,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# affecting step-level telemetry.
|
||||
self.context_token_estimate: int | None = None
|
||||
self.in_context_messages: list[Message] = [] # in-memory tracker
|
||||
# Conversation mode: when set, messages are tracked per-conversation
|
||||
self.conversation_id: str | None = None
|
||||
# Client-side tools passed in the request (executed by client, not server)
|
||||
self.client_tools: list[ClientToolSchema] = []
|
||||
|
||||
@@ -104,6 +107,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: bool = True, # NOTE: not used
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
@@ -116,6 +120,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: Whether to use assistant message format
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
conversation_id: Optional conversation ID for conversation-scoped messaging
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
|
||||
@@ -123,12 +128,19 @@ class LettaAgentV3(LettaAgentV2):
|
||||
LettaResponse: Complete response with all messages and metadata
|
||||
"""
|
||||
self._initialize_state()
|
||||
self.conversation_id = conversation_id
|
||||
self.client_tools = client_tools or []
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
response_letta_messages = []
|
||||
|
||||
# Prepare in-context messages (conversation mode if conversation_id provided)
|
||||
curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
input_messages,
|
||||
self.agent_state,
|
||||
self.message_manager,
|
||||
self.actor,
|
||||
run_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
follow_up_messages = []
|
||||
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
|
||||
@@ -241,6 +253,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: bool = True, # NOTE: not used
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
@@ -259,6 +272,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
use_assistant_message: Whether to use assistant message format
|
||||
include_return_message_types: Filter for which message types to return
|
||||
request_start_timestamp_ns: Start time for tracking request duration
|
||||
conversation_id: Optional conversation ID for conversation-scoped messaging
|
||||
client_tools: Optional list of client-side tools. When called, execution pauses
|
||||
for client to provide tool returns.
|
||||
|
||||
@@ -266,6 +280,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
str: JSON-formatted SSE data chunks for each completed step
|
||||
"""
|
||||
self._initialize_state()
|
||||
self.conversation_id = conversation_id
|
||||
self.client_tools = client_tools or []
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
response_letta_messages = []
|
||||
@@ -284,8 +299,14 @@ class LettaAgentV3(LettaAgentV2):
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare in-context messages (conversation mode if conversation_id provided)
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
input_messages,
|
||||
self.agent_state,
|
||||
self.message_manager,
|
||||
self.actor,
|
||||
run_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
follow_up_messages = []
|
||||
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
|
||||
@@ -435,7 +456,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
This handles:
|
||||
- Persisting the new messages into the `messages` table
|
||||
- Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`)
|
||||
- Updating the DB with the current in-context messages (`self.agent_state.message_ids`)
|
||||
- Updating the DB with the current in-context messages (`self.agent_state.message_ids`) OR conversation_messages table
|
||||
|
||||
Args:
|
||||
run_id: The run ID to associate with the messages
|
||||
@@ -457,14 +478,33 @@ class LettaAgentV3(LettaAgentV2):
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
|
||||
# persist the in-context messages
|
||||
# TODO: somehow make sure all the message ids are already persisted
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=[m.id for m in in_context_messages],
|
||||
actor=self.actor,
|
||||
)
|
||||
self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state
|
||||
if self.conversation_id:
|
||||
# Conversation mode: update conversation_messages table
|
||||
# Add new messages to conversation tracking
|
||||
new_message_ids = [m.id for m in new_messages]
|
||||
if new_message_ids:
|
||||
await ConversationManager().add_messages_to_conversation(
|
||||
conversation_id=self.conversation_id,
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=new_message_ids,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
# Update which messages are in context
|
||||
await ConversationManager().update_in_context_messages(
|
||||
conversation_id=self.conversation_id,
|
||||
in_context_message_ids=[m.id for m in in_context_messages],
|
||||
actor=self.actor,
|
||||
)
|
||||
else:
|
||||
# Default mode: update agent.message_ids
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=[m.id for m in in_context_messages],
|
||||
actor=self.actor,
|
||||
)
|
||||
self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state
|
||||
|
||||
self.in_context_messages = in_context_messages # update in-memory state
|
||||
|
||||
@trace_method
|
||||
|
||||
@@ -6,6 +6,8 @@ from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.conversation_messages import ConversationMessage
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.group import Group
|
||||
|
||||
@@ -27,6 +27,7 @@ from letta.utils import bounded_gather, calculate_file_defaults_based_on_context
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.archives_agents import ArchivesAgents
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.organization import Organization
|
||||
@@ -186,6 +187,13 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
lazy="noload",
|
||||
doc="Archives accessible by this agent.",
|
||||
)
|
||||
conversations: Mapped[List["Conversation"]] = relationship(
|
||||
"Conversation",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="raise",
|
||||
doc="Conversations for concurrent messaging on this agent.",
|
||||
)
|
||||
|
||||
def _get_per_file_view_window_char_limit(self) -> int:
|
||||
"""Get the per_file_view_window_char_limit, calculating defaults if None."""
|
||||
|
||||
49
letta/orm/conversation.py
Normal file
49
letta/orm/conversation.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.conversation import Conversation as PydanticConversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.conversation_messages import ConversationMessage
|
||||
|
||||
|
||||
class Conversation(SqlalchemyBase, OrganizationMixin):
|
||||
"""Conversations that can be created on an agent for concurrent messaging."""
|
||||
|
||||
__tablename__ = "conversations"
|
||||
__pydantic_model__ = PydanticConversation
|
||||
__table_args__ = (
|
||||
Index("ix_conversations_agent_id", "agent_id"),
|
||||
Index("ix_conversations_org_agent", "organization_id", "agent_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv-{uuid.uuid4()}")
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
|
||||
summary: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="Summary of the conversation")
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="conversations", lazy="raise")
|
||||
message_associations: Mapped[List["ConversationMessage"]] = relationship(
|
||||
"ConversationMessage",
|
||||
back_populates="conversation",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> PydanticConversation:
|
||||
"""Converts the SQLAlchemy model to its Pydantic counterpart."""
|
||||
return self.__pydantic_model__(
|
||||
id=self.id,
|
||||
agent_id=self.agent_id,
|
||||
summary=self.summary,
|
||||
created_at=self.created_at,
|
||||
updated_at=self.updated_at,
|
||||
created_by_id=self.created_by_id,
|
||||
last_updated_by_id=self.last_updated_by_id,
|
||||
)
|
||||
73
letta/orm/conversation_messages.py
Normal file
73
letta/orm/conversation_messages.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.message import Message
|
||||
|
||||
|
||||
class ConversationMessage(SqlalchemyBase, OrganizationMixin):
|
||||
"""
|
||||
Track in-context messages for a conversation.
|
||||
|
||||
This replaces the message_ids JSON list on agents with proper relational modeling.
|
||||
- conversation_id=NULL represents the "default" conversation (backward compatible)
|
||||
- conversation_id=<id> represents a named conversation for concurrent messaging
|
||||
"""
|
||||
|
||||
__tablename__ = "conversation_messages"
|
||||
__table_args__ = (
|
||||
Index("ix_conv_msg_conversation_position", "conversation_id", "position"),
|
||||
Index("ix_conv_msg_message_id", "message_id"),
|
||||
Index("ix_conv_msg_agent_id", "agent_id"),
|
||||
Index("ix_conv_msg_agent_conversation", "agent_id", "conversation_id"),
|
||||
UniqueConstraint("conversation_id", "message_id", name="unique_conversation_message"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"conv_msg-{uuid.uuid4()}")
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(
|
||||
String,
|
||||
ForeignKey("conversations.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
doc="NULL for default conversation, otherwise FK to conversation",
|
||||
)
|
||||
agent_id: Mapped[str] = mapped_column(
|
||||
String,
|
||||
ForeignKey("agents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
doc="The agent this message association belongs to",
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(
|
||||
String,
|
||||
ForeignKey("messages.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
doc="The message being tracked",
|
||||
)
|
||||
position: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
doc="Position in conversation (for ordering)",
|
||||
)
|
||||
in_context: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
doc="Whether message is currently in the agent's context window",
|
||||
)
|
||||
|
||||
# Relationships
|
||||
conversation: Mapped[Optional["Conversation"]] = relationship(
|
||||
"Conversation",
|
||||
back_populates="message_associations",
|
||||
lazy="raise",
|
||||
)
|
||||
message: Mapped["Message"] = relationship(
|
||||
"Message",
|
||||
lazy="selectin",
|
||||
)
|
||||
@@ -55,6 +55,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
nullable=True,
|
||||
doc="The id of the LLMBatchItem that this message is associated with",
|
||||
)
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("conversations.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
doc="The conversation this message belongs to (NULL = default conversation)",
|
||||
)
|
||||
is_err: Mapped[Optional[bool]] = mapped_column(
|
||||
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
Index("ix_runs_created_at", "created_at", "id"),
|
||||
Index("ix_runs_agent_id", "agent_id"),
|
||||
Index("ix_runs_organization_id", "organization_id"),
|
||||
Index("ix_runs_conversation_id", "conversation_id"),
|
||||
)
|
||||
|
||||
# Generate run ID with run- prefix
|
||||
@@ -50,6 +51,11 @@ class Run(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
# Agent relationship - A run belongs to one agent
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), nullable=False, doc="The agent that owns this run.")
|
||||
|
||||
# Conversation relationship - Optional, a run may be associated with a conversation
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(
|
||||
String, ForeignKey("conversations.id", ondelete="SET NULL"), nullable=True, doc="The conversation this run belongs to."
|
||||
)
|
||||
|
||||
# Callback related columns
|
||||
callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after run completion.")
|
||||
callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.")
|
||||
|
||||
28
letta/schemas/conversation.py
Normal file
28
letta/schemas/conversation.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
|
||||
class Conversation(OrmMetadataBase):
|
||||
"""Represents a conversation on an agent for concurrent messaging."""
|
||||
|
||||
__id_prefix__ = "conv"
|
||||
|
||||
id: str = Field(..., description="The unique identifier of the conversation.")
|
||||
agent_id: str = Field(..., description="The ID of the agent this conversation belongs to.")
|
||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||
in_context_message_ids: List[str] = Field(default_factory=list, description="The IDs of in-context messages for the conversation.")
|
||||
|
||||
|
||||
class CreateConversation(BaseModel):
|
||||
"""Request model for creating a new conversation."""
|
||||
|
||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||
|
||||
|
||||
class UpdateConversation(BaseModel):
|
||||
"""Request model for updating a conversation."""
|
||||
|
||||
summary: Optional[str] = Field(None, description="A summary of the conversation.")
|
||||
@@ -26,6 +26,7 @@ class PrimitiveType(str, Enum):
|
||||
SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix
|
||||
STEP = "step"
|
||||
IDENTITY = "identity"
|
||||
CONVERSATION = "conv"
|
||||
|
||||
# Infrastructure types
|
||||
MCP_SERVER = "mcp_server"
|
||||
|
||||
@@ -27,6 +27,9 @@ class Run(RunBase):
|
||||
# Agent relationship
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent associated with the run.")
|
||||
|
||||
# Conversation relationship
|
||||
conversation_id: Optional[str] = Field(None, description="The unique identifier of the conversation associated with the run.")
|
||||
|
||||
# Template fields
|
||||
base_template_id: Optional[str] = Field(None, description="The base template ID that the run belongs to.")
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from letta.server.rest_api.routers.v1.anthropic import router as anthropic_route
|
||||
from letta.server.rest_api.routers.v1.archives import router as archives_router
|
||||
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
|
||||
from letta.server.rest_api.routers.v1.chat_completions import router as chat_completions_router, router as openai_chat_completions_router
|
||||
from letta.server.rest_api.routers.v1.conversations import router as conversations_router
|
||||
from letta.server.rest_api.routers.v1.embeddings import router as embeddings_router
|
||||
from letta.server.rest_api.routers.v1.folders import router as folders_router
|
||||
from letta.server.rest_api.routers.v1.groups import router as groups_router
|
||||
@@ -36,6 +37,7 @@ ROUTERS = [
|
||||
sources_router,
|
||||
folders_router,
|
||||
agents_router,
|
||||
conversations_router,
|
||||
chat_completions_router,
|
||||
groups_router,
|
||||
identities_router,
|
||||
|
||||
273
letta/server/rest_api/routers/v1/conversations.py
Normal file
273
letta/server/rest_api/routers/v1/conversations.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from pydantic import Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import LettaExpiredError, LettaInvalidArgumentError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.conversation import Conversation, CreateConversation
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import LettaStreamingRequest, RetrieveStreamRequest
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||
from letta.server.rest_api.streaming_response import (
|
||||
StreamingResponseWithStatusCode,
|
||||
add_keepalive_to_stream,
|
||||
cancellation_aware_stream_wrapper,
|
||||
)
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.services.streaming_service import StreamingService
|
||||
from letta.settings import settings
|
||||
from letta.validators import ConversationId
|
||||
|
||||
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
||||
|
||||
# Instantiate manager
|
||||
conversation_manager = ConversationManager()
|
||||
|
||||
|
||||
@router.post("/", response_model=Conversation, operation_id="create_conversation")
|
||||
async def create_conversation(
|
||||
agent_id: str = Query(..., description="The agent ID to create a conversation for"),
|
||||
conversation_create: CreateConversation = Body(default_factory=CreateConversation),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""Create a new conversation for an agent."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.create_conversation(
|
||||
agent_id=agent_id,
|
||||
conversation_create=conversation_create,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Conversation], operation_id="list_conversations")
|
||||
async def list_conversations(
|
||||
agent_id: str = Query(..., description="The agent ID to list conversations for"),
|
||||
limit: int = Query(50, description="Maximum number of conversations to return"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination (conversation ID)"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""List all conversations for an agent."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.list_conversations(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
limit=limit,
|
||||
after=after,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation")
|
||||
async def retrieve_conversation(
|
||||
conversation_id: ConversationId,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
"""Retrieve a specific conversation."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
ConversationMessagesResponse = Annotated[
|
||||
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
|
||||
]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{conversation_id}/messages",
|
||||
response_model=ConversationMessagesResponse,
|
||||
operation_id="list_conversation_messages",
|
||||
)
|
||||
async def list_conversation_messages(
|
||||
conversation_id: ConversationId,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
before: Optional[str] = Query(
|
||||
None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the conversation"
|
||||
),
|
||||
after: Optional[str] = Query(
|
||||
None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the conversation"
|
||||
),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
):
|
||||
"""
|
||||
List all messages in a conversation.
|
||||
|
||||
Returns LettaMessage objects (UserMessage, AssistantMessage, etc.) for all
|
||||
messages in the conversation, ordered by position (oldest first),
|
||||
with support for cursor-based pagination.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{conversation_id}/messages",
|
||||
response_model=LettaStreamingResponse,
|
||||
operation_id="send_conversation_message",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def send_conversation_message(
|
||||
conversation_id: ConversationId,
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Send a message to a conversation and get a streaming response.
|
||||
|
||||
This endpoint sends a message to an existing conversation and streams
|
||||
the agent's response back.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Get the conversation to find the agent_id
|
||||
conversation = await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Force streaming mode for this endpoint
|
||||
request.streaming = True
|
||||
|
||||
# Use streaming service
|
||||
streaming_service = StreamingService(server)
|
||||
run, result = await streaming_service.create_agent_stream(
|
||||
agent_id=conversation.agent_id,
|
||||
actor=actor,
|
||||
request=request,
|
||||
run_type="send_conversation_message",
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{conversation_id}/stream",
|
||||
response_model=None,
|
||||
operation_id="retrieve_conversation_stream",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {
|
||||
"description": "Server-Sent Events stream",
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/SystemMessage"},
|
||||
{"$ref": "#/components/schemas/UserMessage"},
|
||||
{"$ref": "#/components/schemas/ReasoningMessage"},
|
||||
{"$ref": "#/components/schemas/HiddenReasoningMessage"},
|
||||
{"$ref": "#/components/schemas/ToolCallMessage"},
|
||||
{"$ref": "#/components/schemas/ToolReturnMessage"},
|
||||
{"$ref": "#/components/schemas/AssistantMessage"},
|
||||
{"$ref": "#/components/schemas/ApprovalRequestMessage"},
|
||||
{"$ref": "#/components/schemas/ApprovalResponseMessage"},
|
||||
{"$ref": "#/components/schemas/LettaPing"},
|
||||
{"$ref": "#/components/schemas/LettaErrorMessage"},
|
||||
{"$ref": "#/components/schemas/LettaStopReason"},
|
||||
{"$ref": "#/components/schemas/LettaUsageStatistics"},
|
||||
]
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def retrieve_conversation_stream(
|
||||
conversation_id: ConversationId,
|
||||
request: RetrieveStreamRequest = Body(None),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Resume the stream for the most recent active run in a conversation.
|
||||
|
||||
This endpoint allows you to reconnect to an active background stream
|
||||
for a conversation, enabling recovery from network interruptions.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
runs_manager = RunManager()
|
||||
|
||||
# Find the most recent active run for this conversation
|
||||
active_runs = await runs_manager.list_runs(
|
||||
actor=actor,
|
||||
conversation_id=conversation_id,
|
||||
statuses=[RunStatus.created, RunStatus.running],
|
||||
limit=1,
|
||||
ascending=False,
|
||||
)
|
||||
|
||||
if not active_runs:
|
||||
raise LettaInvalidArgumentError("No active runs found for this conversation.")
|
||||
|
||||
run = active_runs[0]
|
||||
|
||||
if not run.background:
|
||||
raise LettaInvalidArgumentError("Run was not created in background mode, so it cannot be retrieved.")
|
||||
|
||||
if run.created_at < get_utc_time() - timedelta(hours=3):
|
||||
raise LettaExpiredError("Run was created more than 3 hours ago, and is now expired.")
|
||||
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Background streaming requires Redis to be running. "
|
||||
"Please ensure Redis is properly configured. "
|
||||
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
|
||||
),
|
||||
)
|
||||
|
||||
stream = redis_sse_stream_generator(
|
||||
redis_client=redis_client,
|
||||
run_id=run.id,
|
||||
starting_after=request.starting_after if request else None,
|
||||
poll_interval=request.poll_interval if request else None,
|
||||
batch_size=request.batch_size if request else None,
|
||||
)
|
||||
|
||||
if settings.enable_cancellation_aware_streaming:
|
||||
stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=stream,
|
||||
run_manager=server.run_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if request and request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
||||
|
||||
return StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
357
letta/services/conversation_manager.py
Normal file
357
letta/services/conversation_manager.py
Normal file
@@ -0,0 +1,357 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from letta.orm.conversation import Conversation as ConversationModel
|
||||
from letta.orm.conversation_messages import ConversationMessage as ConversationMessageModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.conversation import Conversation as PydanticConversation, CreateConversation, UpdateConversation
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""Manager class to handle business logic related to Conversations."""
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_conversation(
|
||||
self,
|
||||
agent_id: str,
|
||||
conversation_create: CreateConversation,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticConversation:
|
||||
"""Create a new conversation for an agent."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversation = ConversationModel(
|
||||
agent_id=agent_id,
|
||||
summary=conversation_create.summary,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
await conversation.create_async(session, actor=actor)
|
||||
return conversation.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_conversation_by_id(
|
||||
self,
|
||||
conversation_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticConversation:
|
||||
"""Retrieve a conversation by its ID, including in-context message IDs."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversation = await ConversationModel.read_async(
|
||||
db_session=session,
|
||||
identifier=conversation_id,
|
||||
actor=actor,
|
||||
check_is_deleted=True,
|
||||
)
|
||||
|
||||
# Get the in-context message IDs for this conversation
|
||||
message_ids = await self.get_message_ids_for_conversation(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Build the pydantic model with in_context_message_ids
|
||||
pydantic_conversation = conversation.to_pydantic()
|
||||
pydantic_conversation.in_context_message_ids = message_ids
|
||||
return pydantic_conversation
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_conversations(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: int = 50,
|
||||
after: Optional[str] = None,
|
||||
) -> List[PydanticConversation]:
|
||||
"""List conversations for an agent with cursor-based pagination."""
|
||||
async with db_registry.async_session() as session:
|
||||
conversations = await ConversationModel.list_async(
|
||||
db_session=session,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
after=after,
|
||||
ascending=False,
|
||||
)
|
||||
return [conv.to_pydantic() for conv in conversations]
|
||||
|
||||
@enforce_types
@trace_method
async def update_conversation(
    self,
    conversation_id: str,
    conversation_update: UpdateConversation,
    actor: PydanticUser,
) -> PydanticConversation:
    """Apply a partial update to a conversation and return the updated model.

    Args:
        conversation_id: ID of the conversation to update.
        conversation_update: Fields to change; None-valued fields are skipped.
        actor: User performing the update; scopes access to their organization.

    Returns:
        The updated conversation as a Pydantic model.

    Raises:
        NoResultFound: If no matching (non-soft-deleted) conversation exists.
    """
    async with db_registry.async_session() as session:
        # check_is_deleted=True keeps this consistent with
        # get_conversation_by_id: soft-deleted conversations cannot be updated.
        conversation = await ConversationModel.read_async(
            db_session=session,
            identifier=conversation_id,
            actor=actor,
            check_is_deleted=True,
        )

        # Copy only the explicitly-set fields onto the ORM model.
        # NOTE: exclude_none=True means a field can never be cleared back to
        # None through this method.
        update_data = conversation_update.model_dump(exclude_none=True)
        for key, value in update_data.items():
            setattr(conversation, key, value)

        # Persist the change and return the refreshed row.
        updated_conversation = await conversation.update_async(
            db_session=session,
            actor=actor,
        )
        return updated_conversation.to_pydantic()
|
||||
|
||||
@enforce_types
@trace_method
async def delete_conversation(
    self,
    conversation_id: str,
    actor: PydanticUser,
) -> None:
    """Soft-delete a conversation by flagging it rather than removing the row.

    Args:
        conversation_id: ID of the conversation to delete.
        actor: User performing the deletion; scopes access to their org.
    """
    async with db_registry.async_session() as session:
        target = await ConversationModel.read_async(
            db_session=session,
            identifier=conversation_id,
            actor=actor,
        )
        # Mark as deleted instead of issuing a hard DELETE so history is kept.
        target.is_deleted = True
        await target.update_async(db_session=session, actor=actor)
|
||||
|
||||
# ==================== Message Management Methods ====================
|
||||
|
||||
@enforce_types
@trace_method
async def get_message_ids_for_conversation(
    self,
    conversation_id: str,
    actor: PydanticUser,
) -> List[str]:
    """
    Get ordered message IDs for a conversation.

    Returns message IDs ordered by position in the conversation.
    Only returns messages that are currently in_context.

    Args:
        conversation_id: The conversation whose message IDs to fetch.
        actor: The user performing the lookup; scopes results to their org.
    """
    async with db_registry.async_session() as session:
        query = (
            select(ConversationMessageModel.message_id)
            .where(
                ConversationMessageModel.conversation_id == conversation_id,
                ConversationMessageModel.organization_id == actor.organization_id,
                # SQLAlchemy needs `==` (not `is`) to build SQL boolean filters.
                ConversationMessageModel.in_context == True,
                ConversationMessageModel.is_deleted == False,
            )
            .order_by(ConversationMessageModel.position)
        )
        result = await session.execute(query)
        # Flatten the single-column result into a plain list of IDs.
        return list(result.scalars().all())
|
||||
|
||||
@enforce_types
@trace_method
async def get_messages_for_conversation(
    self,
    conversation_id: str,
    actor: PydanticUser,
) -> List[PydanticMessage]:
    """
    Get ordered Message objects for a conversation.

    Returns full Message objects ordered by position in the conversation.
    Only returns messages that are currently in_context.

    Args:
        conversation_id: The conversation whose messages to fetch.
        actor: The user performing the lookup; scopes results to their org.
    """
    async with db_registry.async_session() as session:
        query = (
            select(MessageModel)
            # Join through the tracking table so filtering and ordering can
            # use the conversation-scoped position and in_context flags.
            .join(
                ConversationMessageModel,
                MessageModel.id == ConversationMessageModel.message_id,
            )
            .where(
                ConversationMessageModel.conversation_id == conversation_id,
                ConversationMessageModel.organization_id == actor.organization_id,
                # SQLAlchemy needs `==` (not `is`) to build SQL boolean filters.
                ConversationMessageModel.in_context == True,
                ConversationMessageModel.is_deleted == False,
            )
            .order_by(ConversationMessageModel.position)
        )
        result = await session.execute(query)
        return [msg.to_pydantic() for msg in result.scalars().all()]
|
||||
|
||||
@enforce_types
@trace_method
async def add_messages_to_conversation(
    self,
    conversation_id: str,
    agent_id: str,
    message_ids: List[str],
    actor: PydanticUser,
    starting_position: Optional[int] = None,
) -> None:
    """
    Add messages to a conversation's tracking table.

    Creates ConversationMessage entries with auto-incrementing positions.

    Args:
        conversation_id: The conversation to add messages to
        agent_id: The agent ID
        message_ids: List of message IDs to add
        actor: The user performing the action
        starting_position: Optional starting position (defaults to next available)
    """
    # Nothing to do for an empty batch; avoid opening a session at all.
    if not message_ids:
        return

    async with db_registry.async_session() as session:
        # Get starting position if not provided
        if starting_position is None:
            # COALESCE(MAX(position), -1) yields -1 for an empty conversation,
            # so the first message lands at position 0.
            query = select(func.coalesce(func.max(ConversationMessageModel.position), -1)).where(
                ConversationMessageModel.conversation_id == conversation_id,
                ConversationMessageModel.organization_id == actor.organization_id,
            )
            result = await session.execute(query)
            max_position = result.scalar()
            # Defensive: scalar() is typed Optional even though COALESCE should
            # never yield NULL here. An explicit None check (not `or`) keeps
            # position=0 handled correctly.
            if max_position is None:
                max_position = -1
            starting_position = max_position + 1
            # NOTE(review): two concurrent writers could read the same MAX and
            # assign overlapping positions — confirm callers serialize appends.

        # Bulk-register the tracking rows with sequential positions.
        session.add_all(
            ConversationMessageModel(
                conversation_id=conversation_id,
                agent_id=agent_id,
                message_id=message_id,
                position=starting_position + offset,
                in_context=True,
                organization_id=actor.organization_id,
            )
            for offset, message_id in enumerate(message_ids)
        )

        await session.commit()
|
||||
|
||||
@enforce_types
@trace_method
async def update_in_context_messages(
    self,
    conversation_id: str,
    in_context_message_ids: List[str],
    actor: PydanticUser,
) -> None:
    """
    Update which messages are in context for a conversation.

    Sets in_context=True for messages in the list, False for others.

    Args:
        conversation_id: The conversation to update
        in_context_message_ids: List of message IDs that should be in context
        actor: The user performing the action
    """
    # Membership set so each row's flag is computed in O(1).
    wanted_in_context = set(in_context_message_ids)

    async with db_registry.async_session() as session:
        # Load every live tracking row for this conversation within the org.
        query = select(ConversationMessageModel).where(
            ConversationMessageModel.conversation_id == conversation_id,
            ConversationMessageModel.organization_id == actor.organization_id,
            ConversationMessageModel.is_deleted == False,
        )
        rows = (await session.execute(query)).scalars().all()

        # Flip each row's flag according to membership in the requested set.
        for row in rows:
            row.in_context = row.message_id in wanted_in_context

        await session.commit()
|
||||
|
||||
@enforce_types
@trace_method
async def list_conversation_messages(
    self,
    conversation_id: str,
    actor: PydanticUser,
    limit: Optional[int] = 100,
    before: Optional[str] = None,
    after: Optional[str] = None,
) -> List[LettaMessage]:
    """
    List all messages in a conversation with pagination support.

    Unlike get_messages_for_conversation, this returns ALL messages
    (not just in_context) and supports cursor-based pagination.
    Messages are always ordered by position (oldest first).

    Args:
        conversation_id: The conversation to list messages for
        actor: The user performing the action
        limit: Maximum number of messages to return
        before: Return messages before this message ID
        after: Return messages after this message ID

    Returns:
        List of LettaMessage objects
    """
    async with db_registry.async_session() as session:

        async def _cursor_position(cursor_message_id: str) -> Optional[int]:
            """Resolve a cursor message ID to its position in this conversation."""
            cursor_query = select(ConversationMessageModel.position).where(
                ConversationMessageModel.conversation_id == conversation_id,
                # Scope to the actor's org so a cursor cannot reference
                # another organization's rows (matches the main query filter).
                ConversationMessageModel.organization_id == actor.organization_id,
                ConversationMessageModel.message_id == cursor_message_id,
            )
            cursor_result = await session.execute(cursor_query)
            return cursor_result.scalar_one_or_none()

        # Base query joining Message with its conversation tracking row.
        query = (
            select(MessageModel)
            .join(
                ConversationMessageModel,
                MessageModel.id == ConversationMessageModel.message_id,
            )
            .where(
                ConversationMessageModel.conversation_id == conversation_id,
                ConversationMessageModel.organization_id == actor.organization_id,
                # SQLAlchemy needs `==` (not `is`) to build SQL boolean filters.
                ConversationMessageModel.is_deleted == False,
            )
        )

        # Cursor-based pagination; unknown cursors are silently ignored so a
        # stale cursor degrades to an unfiltered page rather than an error.
        if before:
            position = await _cursor_position(before)
            if position is not None:
                query = query.where(ConversationMessageModel.position < position)

        if after:
            position = await _cursor_position(after)
            if position is not None:
                query = query.where(ConversationMessageModel.position > position)

        # Order by position (oldest first)
        query = query.order_by(ConversationMessageModel.position.asc())

        # Apply limit
        if limit is not None:
            query = query.limit(limit)

        result = await session.execute(query)
        messages = [msg.to_pydantic() for msg in result.scalars().all()]

        # Convert persisted Message objects to client-facing LettaMessages.
        return PydanticMessage.to_letta_messages_from_list(messages, reverse=False, text_is_assistant_message=True)
|
||||
@@ -151,6 +151,7 @@ class RunManager:
|
||||
step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
|
||||
tools_used: Optional[List[str]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
order_by: Literal["created_at", "duration"] = "created_at",
|
||||
duration_percentile: Optional[int] = None,
|
||||
duration_filter: Optional[dict] = None,
|
||||
@@ -190,6 +191,10 @@ class RunManager:
|
||||
if background is not None:
|
||||
query = query.filter(RunModel.background == background)
|
||||
|
||||
# Filter by conversation_id
|
||||
if conversation_id is not None:
|
||||
query = query.filter(RunModel.conversation_id == conversation_id)
|
||||
|
||||
# Filter by template_family (base_template_id)
|
||||
if template_family:
|
||||
query = query.filter(RunModel.base_template_id == template_family)
|
||||
|
||||
@@ -72,6 +72,7 @@ class StreamingService:
|
||||
actor: User,
|
||||
request: LettaStreamingRequest,
|
||||
run_type: str = "streaming",
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
|
||||
"""
|
||||
Create a streaming response for an agent.
|
||||
@@ -81,6 +82,7 @@ class StreamingService:
|
||||
actor: The user making the request
|
||||
request: The LettaStreamingRequest containing all request parameters
|
||||
run_type: Type of run for tracking
|
||||
conversation_id: Optional conversation ID for conversation-scoped messaging
|
||||
|
||||
Returns:
|
||||
Tuple of (run object or None, streaming response)
|
||||
@@ -104,7 +106,7 @@ class StreamingService:
|
||||
run = None
|
||||
run_update_metadata = None
|
||||
if settings.track_agent_run:
|
||||
run = await self._create_run(agent_id, request, run_type, actor)
|
||||
run = await self._create_run(agent_id, request, run_type, actor, conversation_id=conversation_id)
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
|
||||
|
||||
try:
|
||||
@@ -123,6 +125,7 @@ class StreamingService:
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
actor=actor,
|
||||
conversation_id=conversation_id,
|
||||
client_tools=request.client_tools,
|
||||
)
|
||||
|
||||
@@ -288,6 +291,7 @@ class StreamingService:
|
||||
request_start_timestamp_ns: int,
|
||||
include_return_message_types: Optional[list[MessageType]],
|
||||
actor: User,
|
||||
conversation_id: Optional[str] = None,
|
||||
client_tools: Optional[list[ClientToolSchema]] = None,
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
@@ -315,6 +319,7 @@ class StreamingService:
|
||||
use_assistant_message=use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=include_return_message_types,
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
)
|
||||
|
||||
@@ -477,11 +482,14 @@ class StreamingService:
|
||||
]
|
||||
return base_compatible or google_letta_v1
|
||||
|
||||
async def _create_run(self, agent_id: str, request: LettaStreamingRequest, run_type: str, actor: User) -> PydanticRun:
|
||||
async def _create_run(
|
||||
self, agent_id: str, request: LettaStreamingRequest, run_type: str, actor: User, conversation_id: Optional[str] = None
|
||||
) -> PydanticRun:
|
||||
"""Create a run for tracking execution."""
|
||||
run = await self.runs_manager.create_run(
|
||||
pydantic_run=PydanticRun(
|
||||
agent_id=agent_id,
|
||||
conversation_id=conversation_id,
|
||||
background=request.background or False,
|
||||
metadata={
|
||||
"run_type": run_type,
|
||||
|
||||
@@ -62,6 +62,7 @@ ProviderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER.value]()]
|
||||
SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.value]()]
|
||||
StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()]
|
||||
IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()]
|
||||
ConversationId = Annotated[str, PATH_VALIDATORS[PrimitiveType.CONVERSATION.value]()]
|
||||
|
||||
# Infrastructure types
|
||||
McpServerId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_SERVER.value]()]
|
||||
|
||||
271
tests/integration_test_conversations_sdk.py
Normal file
271
tests/integration_test_conversations_sdk.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Integration tests for the Conversations API using the SDK.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from letta_client import Letta
|
||||
|
||||
|
||||
@pytest.fixture
def client(server_url: str) -> Letta:
    """Create a Letta client pointed at the test server."""
    return Letta(base_url=server_url)
|
||||
|
||||
|
||||
@pytest.fixture
def agent(client: Letta):
    """Create a test agent with a unique name; deleted again after the test."""
    agent_state = client.agents.create(
        # Unique suffix so parallel/repeated runs don't collide on name.
        name=f"test_conversations_{uuid.uuid4().hex[:8]}",
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
        memory_blocks=[
            {"label": "human", "value": "Test user"},
            {"label": "persona", "value": "You are a helpful assistant."},
        ],
    )
    yield agent_state
    # Cleanup
    client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
class TestConversationsSDK:
    """Test conversations using the SDK client.

    These are live integration tests: message-sending calls return streams
    that must be consumed (wrapped in ``list(...)``) to drive the request
    to completion before asserting on persisted state.
    """

    def test_create_conversation(self, client: Letta, agent):
        """Test creating a conversation for an agent."""
        conversation = client.conversations.create(agent_id=agent.id)

        assert conversation.id is not None
        assert conversation.id.startswith("conv-")
        assert conversation.agent_id == agent.id

    def test_list_conversations(self, client: Letta, agent):
        """Test listing conversations for an agent."""
        # Create multiple conversations
        conv1 = client.conversations.create(agent_id=agent.id)
        conv2 = client.conversations.create(agent_id=agent.id)

        # List conversations
        conversations = client.conversations.list(agent_id=agent.id)

        assert len(conversations) >= 2
        conv_ids = [c.id for c in conversations]
        assert conv1.id in conv_ids
        assert conv2.id in conv_ids

    def test_retrieve_conversation(self, client: Letta, agent):
        """Test retrieving a specific conversation."""
        # Create a conversation
        created = client.conversations.create(agent_id=agent.id)

        # Retrieve it (should have empty in_context_message_ids initially)
        retrieved = client.conversations.retrieve(conversation_id=created.id)

        assert retrieved.id == created.id
        assert retrieved.agent_id == created.agent_id
        assert retrieved.in_context_message_ids == []

        # Send a message to the conversation
        list(
            client.conversations.messages.create(
                conversation_id=created.id,
                messages=[{"role": "user", "content": "Hello!"}],
            )
        )

        # Retrieve again and check in_context_message_ids is populated
        retrieved_with_messages = client.conversations.retrieve(conversation_id=created.id)

        # System message + user + assistant messages should be in the conversation
        assert len(retrieved_with_messages.in_context_message_ids) >= 3  # system + user + assistant
        # All IDs should be strings starting with "message-"
        for msg_id in retrieved_with_messages.in_context_message_ids:
            assert isinstance(msg_id, str)
            assert msg_id.startswith("message-")

        # Verify message ordering by listing messages
        messages = client.conversations.messages.list(conversation_id=created.id)
        assert len(messages) >= 3  # system + user + assistant
        # First message should be system message (shared across conversations)
        assert messages[0].message_type == "system_message", f"First message should be system_message, got {messages[0].message_type}"
        # Second message should be user message
        assert messages[1].message_type == "user_message", f"Second message should be user_message, got {messages[1].message_type}"

    def test_send_message_to_conversation(self, client: Letta, agent):
        """Test sending a message to a conversation."""
        # Create a conversation
        conversation = client.conversations.create(agent_id=agent.id)

        # Send a message (returns a stream)
        stream = client.conversations.messages.create(
            conversation_id=conversation.id,
            messages=[{"role": "user", "content": "Hello, how are you?"}],
        )

        # Consume the stream to get messages
        messages = list(stream)

        # Check response contains messages
        assert len(messages) > 0
        # Should have at least an assistant message
        message_types = [m.message_type for m in messages if hasattr(m, "message_type")]
        assert "assistant_message" in message_types

    def test_list_conversation_messages(self, client: Letta, agent):
        """Test listing messages from a conversation."""
        # Create a conversation
        conversation = client.conversations.create(agent_id=agent.id)

        # Send a message to create some history (consume the stream)
        stream = client.conversations.messages.create(
            conversation_id=conversation.id,
            messages=[{"role": "user", "content": "Say 'test response' back to me."}],
        )
        list(stream)  # Consume stream

        # List messages
        messages = client.conversations.messages.list(conversation_id=conversation.id)

        assert len(messages) >= 2  # At least user + assistant
        message_types = [m.message_type for m in messages]
        assert "user_message" in message_types
        assert "assistant_message" in message_types

        # Send another message and check that old and new messages are both listed
        first_message_count = len(messages)
        stream = client.conversations.messages.create(
            conversation_id=conversation.id,
            messages=[{"role": "user", "content": "This is a follow-up message."}],
        )
        list(stream)  # Consume stream

        # List messages again
        updated_messages = client.conversations.messages.list(conversation_id=conversation.id)

        # Should have more messages now (at least 2 more: user + assistant)
        assert len(updated_messages) >= first_message_count + 2

    def test_conversation_isolation(self, client: Letta, agent):
        """Test that conversations are isolated from each other."""
        # Create two conversations
        conv1 = client.conversations.create(agent_id=agent.id)
        conv2 = client.conversations.create(agent_id=agent.id)

        # Send different messages to each (consume streams)
        list(
            client.conversations.messages.create(
                conversation_id=conv1.id,
                messages=[{"role": "user", "content": "Remember the word: APPLE"}],
            )
        )
        list(
            client.conversations.messages.create(
                conversation_id=conv2.id,
                messages=[{"role": "user", "content": "Remember the word: BANANA"}],
            )
        )

        # List messages from each conversation
        conv1_messages = client.conversations.messages.list(conversation_id=conv1.id)
        conv2_messages = client.conversations.messages.list(conversation_id=conv2.id)

        # Check messages are separate
        conv1_content = " ".join([m.content for m in conv1_messages if hasattr(m, "content") and m.content])
        conv2_content = " ".join([m.content for m in conv2_messages if hasattr(m, "content") and m.content])

        assert "APPLE" in conv1_content
        assert "BANANA" in conv2_content
        # Each conversation should only have its own word
        assert "BANANA" not in conv1_content or "APPLE" not in conv2_content

        # Ask what word was remembered and make sure it's different for each conversation
        conv1_recall = list(
            client.conversations.messages.create(
                conversation_id=conv1.id,
                messages=[{"role": "user", "content": "What word did I ask you to remember? Reply with just the word."}],
            )
        )
        conv2_recall = list(
            client.conversations.messages.create(
                conversation_id=conv2.id,
                messages=[{"role": "user", "content": "What word did I ask you to remember? Reply with just the word."}],
            )
        )

        # Get the assistant responses
        conv1_response = " ".join([m.content for m in conv1_recall if hasattr(m, "message_type") and m.message_type == "assistant_message"])
        conv2_response = " ".join([m.content for m in conv2_recall if hasattr(m, "message_type") and m.message_type == "assistant_message"])

        assert "APPLE" in conv1_response.upper(), f"Conv1 should remember APPLE, got: {conv1_response}"
        assert "BANANA" in conv2_response.upper(), f"Conv2 should remember BANANA, got: {conv2_response}"

        # Each conversation has its own system message (created on first message)
        conv1_system_id = conv1_messages[0].id
        conv2_system_id = conv2_messages[0].id
        assert conv1_system_id != conv2_system_id, "System messages should have different IDs for different conversations"

    def test_conversation_messages_pagination(self, client: Letta, agent):
        """Test pagination when listing conversation messages."""
        # Create a conversation
        conversation = client.conversations.create(agent_id=agent.id)

        # Send multiple messages to create history (consume streams)
        for i in range(3):
            list(
                client.conversations.messages.create(
                    conversation_id=conversation.id,
                    messages=[{"role": "user", "content": f"Message number {i}"}],
                )
            )

        # List with limit
        messages = client.conversations.messages.list(
            conversation_id=conversation.id,
            limit=2,
        )

        # Should respect the limit
        assert len(messages) <= 2

    def test_retrieve_conversation_stream_no_active_run(self, client: Letta, agent):
        """Test that retrieve_conversation_stream returns error when no active run exists."""
        from letta_client import BadRequestError

        # Create a conversation
        conversation = client.conversations.create(agent_id=agent.id)

        # Try to retrieve stream when no run exists (should fail)
        with pytest.raises(BadRequestError) as exc_info:
            # Use the SDK's stream method
            stream = client.conversations.messages.stream(conversation_id=conversation.id)
            list(stream)  # Consume the stream to trigger the error

        # Should return 400 because no active run exists
        assert "No active runs found" in str(exc_info.value)

    def test_retrieve_conversation_stream_after_completed_run(self, client: Letta, agent):
        """Test that retrieve_conversation_stream returns error when run is completed."""
        from letta_client import BadRequestError

        # Create a conversation
        conversation = client.conversations.create(agent_id=agent.id)

        # Send a message (this creates a run that completes)
        list(
            client.conversations.messages.create(
                conversation_id=conversation.id,
                messages=[{"role": "user", "content": "Hello"}],
            )
        )

        # Try to retrieve stream after the run has completed (should fail)
        with pytest.raises(BadRequestError) as exc_info:
            # Use the SDK's stream method
            stream = client.conversations.messages.stream(conversation_id=conversation.id)
            list(stream)  # Consume the stream to trigger the error

        # Should return 400 because no active run exists (run is completed)
        assert "No active runs found" in str(exc_info.value)
|
||||
@@ -917,6 +917,92 @@ async def test_tool_call(
|
||||
assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_config",
    TESTED_MODEL_CONFIGS,
    ids=[handle for handle, _ in TESTED_MODEL_CONFIGS],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_conversation_streaming_raw_http(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    server_url: str,
    agent_state: AgentState,
    model_config: Tuple[str, dict],
) -> None:
    """
    Test conversation-based streaming functionality using raw HTTP requests.

    This test verifies that:
    1. A conversation can be created for an agent
    2. Messages can be sent to the conversation via streaming
    3. The streaming response contains the expected message types
    4. Messages are properly persisted in the conversation

    Uses raw HTTP requests instead of SDK until SDK is regenerated with conversations support.
    """
    import httpx

    # Point the agent at the model under test before exercising conversations.
    model_handle, model_settings = model_config
    agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings)

    async with httpx.AsyncClient(base_url=server_url, timeout=60.0) as http_client:
        # Create a conversation for the agent
        create_response = await http_client.post(
            "/v1/conversations/",
            params={"agent_id": agent_state.id},
            json={},
        )
        assert create_response.status_code == 200, f"Failed to create conversation: {create_response.text}"
        conversation = create_response.json()
        assert conversation["id"] is not None
        assert conversation["agent_id"] == agent_state.id

        # Send a message to the conversation using streaming
        stream_response = await http_client.post(
            f"/v1/conversations/{conversation['id']}/messages",
            json={
                "messages": [{"role": "user", "content": f"Reply with the message '{USER_MESSAGE_RESPONSE}'."}],
                "stream_tokens": True,
            },
        )
        assert stream_response.status_code == 200, f"Failed to send message: {stream_response.text}"

        # Parse SSE response and accumulate messages
        messages = await accumulate_chunks(stream_response.text)
        print("MESSAGES:", messages)

        # Verify the response contains expected message types
        assert_greeting_response(messages, model_handle, model_settings, streaming=True, token_streaming=True)

        # Verify the conversation can be retrieved
        retrieve_response = await http_client.get(f"/v1/conversations/{conversation['id']}")
        assert retrieve_response.status_code == 200, f"Failed to retrieve conversation: {retrieve_response.text}"
        retrieved_conversation = retrieve_response.json()
        assert retrieved_conversation["id"] == conversation["id"]
        print("RETRIEVED CONVERSATION:", retrieved_conversation)

        # Verify conversations can be listed for the agent
        list_response = await http_client.get("/v1/conversations/", params={"agent_id": agent_state.id})
        assert list_response.status_code == 200, f"Failed to list conversations: {list_response.text}"
        conversations_list = list_response.json()
        assert any(c["id"] == conversation["id"] for c in conversations_list)

        # Verify messages can be listed from the conversation
        messages_response = await http_client.get(f"/v1/conversations/{conversation['id']}/messages")
        assert messages_response.status_code == 200, f"Failed to list conversation messages: {messages_response.text}"
        conversation_messages = messages_response.json()
        print("CONVERSATION MESSAGES:", conversation_messages)

        # Verify we have at least the user message and assistant message
        assert len(conversation_messages) >= 2, f"Expected at least 2 messages, got {len(conversation_messages)}"

        # Check message types are present
        message_types = [msg.get("message_type") for msg in conversation_messages]
        assert "user_message" in message_types, f"Expected user_message in {message_types}"
        assert "assistant_message" in message_types, f"Expected assistant_message in {message_types}"
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_handle,provider_type",
|
||||
[
|
||||
|
||||
506
tests/managers/test_conversation_manager.py
Normal file
506
tests/managers/test_conversation_manager.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Tests for ConversationManager.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.conversation import CreateConversation, UpdateConversation
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
|
||||
# ======================================================================================================================
|
||||
# ConversationManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def conversation_manager():
    """Create a ConversationManager instance for each test."""
    return ConversationManager()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """Creating a conversation populates its ID, agent linkage, and summary."""
    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(summary="Test conversation"),
        actor=default_user,
    )

    # The manager must assign a prefixed ID and echo back the inputs.
    assert created.id is not None
    assert created.agent_id == sarah_agent.id
    assert created.summary == "Test conversation"
    assert created.id.startswith("conv-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_conversation_no_summary(conversation_manager, server: SyncServer, sarah_agent, default_user):
    """A conversation created without a summary leaves the field as None."""
    created = await conversation_manager.create_conversation(
        agent_id=sarah_agent.id,
        conversation_create=CreateConversation(),
        actor=default_user,
    )

    assert created.id is not None
    assert created.agent_id == sarah_agent.id
    assert created.summary is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_by_id(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test retrieving a conversation by ID."""
|
||||
# Create a conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Retrieve it
|
||||
retrieved = await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert retrieved.id == created.id
|
||||
assert retrieved.agent_id == created.agent_id
|
||||
assert retrieved.summary == created.summary
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_not_found(conversation_manager, server: SyncServer, default_user):
|
||||
"""Test retrieving a non-existent conversation raises error."""
|
||||
with pytest.raises(NoResultFound):
|
||||
await conversation_manager.get_conversation_by_id(
|
||||
conversation_id="conv-nonexistent",
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing conversations for an agent."""
|
||||
# Create multiple conversations
|
||||
for i in range(3):
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary=f"Conversation {i}"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List them
|
||||
conversations = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(conversations) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_with_limit(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing conversations with a limit."""
|
||||
# Create multiple conversations
|
||||
for i in range(5):
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary=f"Conversation {i}"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List with limit
|
||||
conversations = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(conversations) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test updating a conversation."""
|
||||
# Create a conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Original"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Update it
|
||||
updated = await conversation_manager.update_conversation(
|
||||
conversation_id=created.id,
|
||||
conversation_update=UpdateConversation(summary="Updated summary"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert updated.id == created.id
|
||||
assert updated.summary == "Updated summary"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test soft deleting a conversation."""
|
||||
# Create a conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="To delete"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Delete it
|
||||
await conversation_manager.delete_conversation(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify it's no longer accessible
|
||||
with pytest.raises(NoResultFound):
|
||||
await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_isolation_by_agent(conversation_manager, server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
"""Test that conversations are isolated by agent."""
|
||||
# Create conversation for sarah_agent
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Sarah's conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create conversation for charles_agent
|
||||
await conversation_manager.create_conversation(
|
||||
agent_id=charles_agent.id,
|
||||
conversation_create=CreateConversation(summary="Charles's conversation"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List for sarah_agent
|
||||
sarah_convos = await conversation_manager.list_conversations(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert len(sarah_convos) == 1
|
||||
assert sarah_convos[0].summary == "Sarah's conversation"
|
||||
|
||||
# List for charles_agent
|
||||
charles_convos = await conversation_manager.list_conversations(
|
||||
agent_id=charles_agent.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert len(charles_convos) == 1
|
||||
assert charles_convos[0].summary == "Charles's conversation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_isolation_by_organization(
|
||||
conversation_manager, server: SyncServer, sarah_agent, default_user, other_user_different_org
|
||||
):
|
||||
"""Test that conversations are isolated by organization."""
|
||||
# Create conversation
|
||||
created = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Other org user should not be able to access it
|
||||
with pytest.raises(NoResultFound):
|
||||
await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=created.id,
|
||||
actor=other_user_different_org,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Conversation Message Management Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_messages_to_conversation(
|
||||
conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
|
||||
):
|
||||
"""Test adding messages to a conversation."""
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add the message to the conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[hello_world_message_fixture.id],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify message is in conversation
|
||||
message_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(message_ids) == 1
|
||||
assert message_ids[0] == hello_world_message_fixture.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_for_conversation(
|
||||
conversation_manager, server: SyncServer, sarah_agent, default_user, hello_world_message_fixture
|
||||
):
|
||||
"""Test getting full message objects from a conversation."""
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add the message
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[hello_world_message_fixture.id],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get full messages
|
||||
messages = await conversation_manager.get_messages_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].id == hello_world_message_fixture.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_ordering_in_conversation(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that messages maintain their order in a conversation."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create multiple messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages in order
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify order is maintained
|
||||
retrieved_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert retrieved_ids == [m.id for m in messages]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_in_context_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test updating which messages are in context."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add all messages
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Update to only keep first and last in context
|
||||
await conversation_manager.update_in_context_messages(
|
||||
conversation_id=conversation.id,
|
||||
in_context_message_ids=[messages[0].id, messages[2].id],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify only the selected messages are in context
|
||||
in_context_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(in_context_ids) == 2
|
||||
assert messages[0].id in in_context_ids
|
||||
assert messages[2].id in in_context_ids
|
||||
assert messages[1].id not in in_context_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_conversation_message_ids(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test getting message IDs from an empty conversation."""
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Empty"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Get message IDs (should be empty)
|
||||
message_ids = await conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert message_ids == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing messages from a conversation as LettaMessages."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create messages with different roles
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text="Hello!")],
|
||||
),
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="assistant",
|
||||
content=[TextContent(text="Hi there!")],
|
||||
),
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List conversation messages (returns LettaMessages)
|
||||
letta_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(letta_messages) == 2
|
||||
# Check message types
|
||||
message_types = [m.message_type for m in letta_messages]
|
||||
assert "user_message" in message_types
|
||||
assert "assistant_message" in message_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversation_messages_pagination(conversation_manager, server: SyncServer, sarah_agent, default_user):
|
||||
"""Test pagination when listing conversation messages."""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
|
||||
# Create a conversation
|
||||
conversation = await conversation_manager.create_conversation(
|
||||
agent_id=sarah_agent.id,
|
||||
conversation_create=CreateConversation(summary="Test"),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create multiple messages
|
||||
pydantic_messages = [
|
||||
PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
content=[TextContent(text=f"Message {i}")],
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_messages,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add messages to conversation
|
||||
await conversation_manager.add_messages_to_conversation(
|
||||
conversation_id=conversation.id,
|
||||
agent_id=sarah_agent.id,
|
||||
message_ids=[m.id for m in messages],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# List with limit
|
||||
letta_messages = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
limit=2,
|
||||
)
|
||||
assert len(letta_messages) == 2
|
||||
|
||||
# List with after cursor (get messages after the first one)
|
||||
letta_messages_after = await conversation_manager.list_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
actor=default_user,
|
||||
after=messages[0].id,
|
||||
)
|
||||
assert len(letta_messages_after) == 4 # Should get messages 1-4
|
||||
Reference in New Issue
Block a user