feat: add conversation cancellation endpoint (#8729)
This commit is contained in:
@@ -56,11 +56,13 @@ class PendingApprovalError(LettaError):
|
|||||||
class NoActiveRunsToCancelError(LettaError):
|
class NoActiveRunsToCancelError(LettaError):
|
||||||
"""Error raised when attempting to cancel but there are no active runs to cancel."""
|
"""Error raised when attempting to cancel but there are no active runs to cancel."""
|
||||||
|
|
||||||
def __init__(self, agent_id: Optional[str] = None):
|
def __init__(self, agent_id: Optional[str] = None, conversation_id: Optional[str] = None):
|
||||||
message = "No active runs to cancel"
|
message = "No active runs to cancel"
|
||||||
if agent_id:
|
if agent_id:
|
||||||
message = f"No active runs to cancel for agent {agent_id}"
|
message = f"No active runs to cancel for agent {agent_id}"
|
||||||
details = {"error_code": "NO_ACTIVE_RUNS_TO_CANCEL", "agent_id": agent_id}
|
if conversation_id:
|
||||||
|
message = f"No active runs to cancel for conversation {conversation_id}"
|
||||||
|
details = {"error_code": "NO_ACTIVE_RUNS_TO_CANCEL", "agent_id": agent_id, "conversation_id": conversation_id}
|
||||||
super().__init__(message=message, code=ErrorCode.CONFLICT, details=details)
|
super().__init__(message=message, code=ErrorCode.CONFLICT, details=details)
|
||||||
|
|
||||||
|
|
||||||
@@ -165,9 +167,7 @@ class LettaImageFetchError(LettaError):
|
|||||||
def __init__(self, url: str, reason: str):
|
def __init__(self, url: str, reason: str):
|
||||||
details = {"url": url, "reason": reason}
|
details = {"url": url, "reason": reason}
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=f"Failed to fetch image from {url}: {reason}",
|
message=f"Failed to fetch image from {url}: {reason}", code=ErrorCode.INVALID_ARGUMENT, details=details,
|
||||||
code=ErrorCode.INVALID_ARGUMENT,
|
|
||||||
details=details,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -308,9 +308,7 @@ class ContextWindowExceededError(LettaError):
|
|||||||
def __init__(self, message: str, details: dict = {}):
|
def __init__(self, message: str, details: dict = {}):
|
||||||
error_message = f"{message} ({details})"
|
error_message = f"{message} ({details})"
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=error_message,
|
message=error_message, code=ErrorCode.CONTEXT_WINDOW_EXCEEDED, details=details,
|
||||||
code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
|
|
||||||
details=details,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -330,9 +328,7 @@ class RateLimitExceededError(LettaError):
|
|||||||
def __init__(self, message: str, max_retries: int):
|
def __init__(self, message: str, max_retries: int):
|
||||||
error_message = f"{message} ({max_retries})"
|
error_message = f"{message} ({max_retries})"
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=error_message,
|
message=error_message, code=ErrorCode.RATE_LIMIT_EXCEEDED, details={"max_retries": max_retries},
|
||||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
|
||||||
details={"max_retries": max_retries},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -387,8 +383,7 @@ class HandleNotFoundError(LettaError):
|
|||||||
|
|
||||||
def __init__(self, handle: str, available_handles: List[str]):
|
def __init__(self, handle: str, available_handles: List[str]):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=f"Handle {handle} not found, must be one of {available_handles}",
|
message=f"Handle {handle} not found, must be one of {available_handles}", code=ErrorCode.NOT_FOUND,
|
||||||
code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from pydantic import Field
|
|||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||||
from letta.errors import LettaExpiredError, LettaInvalidArgumentError
|
from letta.errors import LettaExpiredError, LettaInvalidArgumentError, NoActiveRunsToCancelError
|
||||||
|
from letta.log import get_logger
|
||||||
from letta.helpers.datetime_helpers import get_utc_time
|
from letta.helpers.datetime_helpers import get_utc_time
|
||||||
from letta.schemas.conversation import Conversation, CreateConversation
|
from letta.schemas.conversation import Conversation, CreateConversation
|
||||||
from letta.schemas.enums import RunStatus
|
from letta.schemas.enums import RunStatus
|
||||||
@@ -22,6 +23,7 @@ from letta.server.rest_api.streaming_response import (
|
|||||||
)
|
)
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
from letta.services.conversation_manager import ConversationManager
|
from letta.services.conversation_manager import ConversationManager
|
||||||
|
from letta.services.lettuce import LettuceClient
|
||||||
from letta.services.run_manager import RunManager
|
from letta.services.run_manager import RunManager
|
||||||
from letta.services.streaming_service import StreamingService
|
from letta.services.streaming_service import StreamingService
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
@@ -29,6 +31,8 @@ from letta.validators import ConversationId
|
|||||||
|
|
||||||
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Instantiate manager
|
# Instantiate manager
|
||||||
conversation_manager = ConversationManager()
|
conversation_manager = ConversationManager()
|
||||||
|
|
||||||
@@ -42,11 +46,7 @@ async def create_conversation(
|
|||||||
):
|
):
|
||||||
"""Create a new conversation for an agent."""
|
"""Create a new conversation for an agent."""
|
||||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||||
return await conversation_manager.create_conversation(
|
return await conversation_manager.create_conversation(agent_id=agent_id, conversation_create=conversation_create, actor=actor,)
|
||||||
agent_id=agent_id,
|
|
||||||
conversation_create=conversation_create,
|
|
||||||
actor=actor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[Conversation], operation_id="list_conversations")
|
@router.get("/", response_model=List[Conversation], operation_id="list_conversations")
|
||||||
@@ -59,26 +59,16 @@ async def list_conversations(
|
|||||||
):
|
):
|
||||||
"""List all conversations for an agent."""
|
"""List all conversations for an agent."""
|
||||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||||
return await conversation_manager.list_conversations(
|
return await conversation_manager.list_conversations(agent_id=agent_id, actor=actor, limit=limit, after=after,)
|
||||||
agent_id=agent_id,
|
|
||||||
actor=actor,
|
|
||||||
limit=limit,
|
|
||||||
after=after,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation")
|
@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation")
|
||||||
async def retrieve_conversation(
|
async def retrieve_conversation(
|
||||||
conversation_id: ConversationId,
|
conversation_id: ConversationId, server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers),
|
||||||
server: SyncServer = Depends(get_letta_server),
|
|
||||||
headers: HeaderParams = Depends(get_headers),
|
|
||||||
):
|
):
|
||||||
"""Retrieve a specific conversation."""
|
"""Retrieve a specific conversation."""
|
||||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||||
return await conversation_manager.get_conversation_by_id(
|
return await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,)
|
||||||
conversation_id=conversation_id,
|
|
||||||
actor=actor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
ConversationMessagesResponse = Annotated[
|
ConversationMessagesResponse = Annotated[
|
||||||
@@ -87,9 +77,7 @@ ConversationMessagesResponse = Annotated[
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/{conversation_id}/messages",
|
"/{conversation_id}/messages", response_model=ConversationMessagesResponse, operation_id="list_conversation_messages",
|
||||||
response_model=ConversationMessagesResponse,
|
|
||||||
operation_id="list_conversation_messages",
|
|
||||||
)
|
)
|
||||||
async def list_conversation_messages(
|
async def list_conversation_messages(
|
||||||
conversation_id: ConversationId,
|
conversation_id: ConversationId,
|
||||||
@@ -135,12 +123,7 @@ async def list_conversation_messages(
|
|||||||
response_model=LettaStreamingResponse,
|
response_model=LettaStreamingResponse,
|
||||||
operation_id="send_conversation_message",
|
operation_id="send_conversation_message",
|
||||||
responses={
|
responses={
|
||||||
200: {
|
200: {"description": "Successful response", "content": {"text/event-stream": {"description": "Server-Sent Events stream"},},}
|
||||||
"description": "Successful response",
|
|
||||||
"content": {
|
|
||||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def send_conversation_message(
|
async def send_conversation_message(
|
||||||
@@ -158,10 +141,7 @@ async def send_conversation_message(
|
|||||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||||
|
|
||||||
# Get the conversation to find the agent_id
|
# Get the conversation to find the agent_id
|
||||||
conversation = await conversation_manager.get_conversation_by_id(
|
conversation = await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,)
|
||||||
conversation_id=conversation_id,
|
|
||||||
actor=actor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Force streaming mode for this endpoint
|
# Force streaming mode for this endpoint
|
||||||
request.streaming = True
|
request.streaming = True
|
||||||
@@ -169,11 +149,7 @@ async def send_conversation_message(
|
|||||||
# Use streaming service
|
# Use streaming service
|
||||||
streaming_service = StreamingService(server)
|
streaming_service = StreamingService(server)
|
||||||
run, result = await streaming_service.create_agent_stream(
|
run, result = await streaming_service.create_agent_stream(
|
||||||
agent_id=conversation.agent_id,
|
agent_id=conversation.agent_id, actor=actor, request=request, run_type="send_conversation_message", conversation_id=conversation_id,
|
||||||
actor=actor,
|
|
||||||
request=request,
|
|
||||||
run_type="send_conversation_message",
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -228,11 +204,7 @@ async def retrieve_conversation_stream(
|
|||||||
|
|
||||||
# Find the most recent active run for this conversation
|
# Find the most recent active run for this conversation
|
||||||
active_runs = await runs_manager.list_runs(
|
active_runs = await runs_manager.list_runs(
|
||||||
actor=actor,
|
actor=actor, conversation_id=conversation_id, statuses=[RunStatus.created, RunStatus.running], limit=1, ascending=False,
|
||||||
conversation_id=conversation_id,
|
|
||||||
statuses=[RunStatus.created, RunStatus.running],
|
|
||||||
limit=1,
|
|
||||||
ascending=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not active_runs:
|
if not active_runs:
|
||||||
@@ -267,17 +239,57 @@ async def retrieve_conversation_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if settings.enable_cancellation_aware_streaming:
|
if settings.enable_cancellation_aware_streaming:
|
||||||
stream = cancellation_aware_stream_wrapper(
|
stream = cancellation_aware_stream_wrapper(stream_generator=stream, run_manager=server.run_manager, run_id=run.id, actor=actor,)
|
||||||
stream_generator=stream,
|
|
||||||
run_manager=server.run_manager,
|
|
||||||
run_id=run.id,
|
|
||||||
actor=actor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if request and request.include_pings and settings.enable_keepalive:
|
if request and request.include_pings and settings.enable_keepalive:
|
||||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
||||||
|
|
||||||
return StreamingResponseWithStatusCode(
|
return StreamingResponseWithStatusCode(stream, media_type="text/event-stream",)
|
||||||
stream,
|
|
||||||
media_type="text/event-stream",
|
|
||||||
|
@router.post("/{conversation_id}/messages/cancel", operation_id="cancel_conversation_message")
|
||||||
|
async def cancel_conversation_message(
|
||||||
|
conversation_id: ConversationId, server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Cancel runs associated with a conversation.
|
||||||
|
|
||||||
|
Note: To cancel active runs, Redis is required.
|
||||||
|
"""
|
||||||
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||||
|
|
||||||
|
if not settings.track_agent_run:
|
||||||
|
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
|
||||||
|
|
||||||
|
# Verify conversation exists and get agent_id
|
||||||
|
conversation = await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,)
|
||||||
|
|
||||||
|
# Find active runs for this conversation
|
||||||
|
runs = await server.run_manager.list_runs(
|
||||||
|
actor=actor, statuses=[RunStatus.created, RunStatus.running], ascending=False, conversation_id=conversation_id, limit=100,
|
||||||
)
|
)
|
||||||
|
run_ids = [run.id for run in runs]
|
||||||
|
|
||||||
|
if not run_ids:
|
||||||
|
raise NoActiveRunsToCancelError(conversation_id=conversation_id)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for run_id in run_ids:
|
||||||
|
try:
|
||||||
|
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||||
|
if run.metadata and run.metadata.get("lettuce"):
|
||||||
|
try:
|
||||||
|
lettuce_client = await LettuceClient.create()
|
||||||
|
await lettuce_client.cancel(run_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cancel Lettuce run {run_id}: {e}")
|
||||||
|
|
||||||
|
await server.run_manager.cancel_run(actor=actor, agent_id=conversation.agent_id, run_id=run_id)
|
||||||
|
except Exception as e:
|
||||||
|
results[run_id] = "failed"
|
||||||
|
logger.error(f"Failed to cancel run {run_id}: {str(e)}")
|
||||||
|
continue
|
||||||
|
results[run_id] = "cancelled"
|
||||||
|
logger.info(f"Cancelled run {run_id}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user