feat: add conversation cancellation endpoint (#8729)
This commit is contained in:
@@ -56,11 +56,13 @@ class PendingApprovalError(LettaError):
|
||||
class NoActiveRunsToCancelError(LettaError):
    """Error raised when attempting to cancel but there are no active runs to cancel."""

    def __init__(self, agent_id: Optional[str] = None, conversation_id: Optional[str] = None):
        """Build the error for an agent-level or conversation-level cancel request.

        Args:
            agent_id: Identifier of the agent whose runs were targeted, if any.
            conversation_id: Identifier of the conversation whose runs were
                targeted, if any. When both identifiers are supplied, the
                message names the conversation.
        """
        # The most specific identifier wins: conversation, then agent, then generic.
        if conversation_id:
            message = f"No active runs to cancel for conversation {conversation_id}"
        elif agent_id:
            message = f"No active runs to cancel for agent {agent_id}"
        else:
            message = "No active runs to cancel"

        # Both identifiers are always reported in the details, even when None.
        super().__init__(
            message=message,
            code=ErrorCode.CONFLICT,
            details={
                "error_code": "NO_ACTIVE_RUNS_TO_CANCEL",
                "agent_id": agent_id,
                "conversation_id": conversation_id,
            },
        )
|
||||
|
||||
|
||||
@@ -165,9 +167,7 @@ class LettaImageFetchError(LettaError):
|
||||
def __init__(self, url: str, reason: str):
    """Create the error for an image that could not be fetched.

    Args:
        url: Location the image was requested from.
        reason: Description of why the fetch failed.
    """
    failure_text = f"Failed to fetch image from {url}: {reason}"
    # The URL and reason are repeated in ``details`` as separate fields.
    super().__init__(
        message=failure_text,
        code=ErrorCode.INVALID_ARGUMENT,
        details={"url": url, "reason": reason},
    )
|
||||
|
||||
|
||||
@@ -308,9 +308,7 @@ class ContextWindowExceededError(LettaError):
|
||||
def __init__(self, message: str, details: Optional[dict] = None):
    """Create the error raised when the context window is exceeded.

    Args:
        message: Human-readable description of the failure.
        details: Optional extra information about the failure. Defaults to a
            fresh empty dict for every call.
    """
    # A ``{}`` default would be one dict object shared by every call (and
    # handed to the base class); build a new one per instance instead.
    if details is None:
        details = {}

    # The details are appended to the message, e.g. "too long ({})".
    error_message = f"{message} ({details})"
    super().__init__(
        message=error_message,
        code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
        details=details,
    )
|
||||
|
||||
|
||||
@@ -330,9 +328,7 @@ class RateLimitExceededError(LettaError):
|
||||
def __init__(self, message: str, max_retries: int):
    """Create the rate-limit error.

    Args:
        message: Human-readable description of the failure.
        max_retries: Retry count; appended to the message in parentheses
            and reported in the details under ``"max_retries"``.
    """
    retry_details = {"max_retries": max_retries}
    super().__init__(
        message=f"{message} ({max_retries})",
        code=ErrorCode.RATE_LIMIT_EXCEEDED,
        details=retry_details,
    )
|
||||
|
||||
|
||||
@@ -387,8 +383,7 @@ class HandleNotFoundError(LettaError):
|
||||
|
||||
def __init__(self, handle: str, available_handles: List[str]):
    """Create the error for a handle that is not among the available ones.

    Args:
        handle: The handle that was requested.
        available_handles: Handles that would have been accepted; listed in
            the error message.
    """
    not_found_text = f"Handle {handle} not found, must be one of {available_handles}"
    super().__init__(message=not_found_text, code=ErrorCode.NOT_FOUND)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ from pydantic import Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import LettaExpiredError, LettaInvalidArgumentError
|
||||
from letta.errors import LettaExpiredError, LettaInvalidArgumentError, NoActiveRunsToCancelError
|
||||
from letta.log import get_logger
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.conversation import Conversation, CreateConversation
|
||||
from letta.schemas.enums import RunStatus
|
||||
@@ -22,6 +23,7 @@ from letta.server.rest_api.streaming_response import (
|
||||
)
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
from letta.services.lettuce import LettuceClient
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.services.streaming_service import StreamingService
|
||||
from letta.settings import settings
|
||||
@@ -29,6 +31,8 @@ from letta.validators import ConversationId
|
||||
|
||||
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Instantiate manager
|
||||
conversation_manager = ConversationManager()
|
||||
|
||||
@@ -42,11 +46,7 @@ async def create_conversation(
|
||||
):
|
||||
"""Create a new conversation for an agent."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.create_conversation(
|
||||
agent_id=agent_id,
|
||||
conversation_create=conversation_create,
|
||||
actor=actor,
|
||||
)
|
||||
return await conversation_manager.create_conversation(agent_id=agent_id, conversation_create=conversation_create, actor=actor,)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Conversation], operation_id="list_conversations")
|
||||
@@ -59,26 +59,16 @@ async def list_conversations(
|
||||
):
|
||||
"""List all conversations for an agent."""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await conversation_manager.list_conversations(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
limit=limit,
|
||||
after=after,
|
||||
)
|
||||
return await conversation_manager.list_conversations(agent_id=agent_id, actor=actor, limit=limit, after=after,)
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation")
async def retrieve_conversation(
    conversation_id: ConversationId,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """Retrieve a specific conversation."""
    # Resolve the acting user from the actor id carried in the request headers.
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Look the conversation up on behalf of that actor and hand it back as-is.
    conversation = await conversation_manager.get_conversation_by_id(
        conversation_id=conversation_id,
        actor=requesting_actor,
    )
    return conversation
|
||||
|
||||
|
||||
ConversationMessagesResponse = Annotated[
|
||||
@@ -87,9 +77,7 @@ ConversationMessagesResponse = Annotated[
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{conversation_id}/messages",
|
||||
response_model=ConversationMessagesResponse,
|
||||
operation_id="list_conversation_messages",
|
||||
"/{conversation_id}/messages", response_model=ConversationMessagesResponse, operation_id="list_conversation_messages",
|
||||
)
|
||||
async def list_conversation_messages(
|
||||
conversation_id: ConversationId,
|
||||
@@ -135,12 +123,7 @@ async def list_conversation_messages(
|
||||
response_model=LettaStreamingResponse,
|
||||
operation_id="send_conversation_message",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
200: {"description": "Successful response", "content": {"text/event-stream": {"description": "Server-Sent Events stream"},},}
|
||||
},
|
||||
)
|
||||
async def send_conversation_message(
|
||||
@@ -158,10 +141,7 @@ async def send_conversation_message(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Get the conversation to find the agent_id
|
||||
conversation = await conversation_manager.get_conversation_by_id(
|
||||
conversation_id=conversation_id,
|
||||
actor=actor,
|
||||
)
|
||||
conversation = await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,)
|
||||
|
||||
# Force streaming mode for this endpoint
|
||||
request.streaming = True
|
||||
@@ -169,11 +149,7 @@ async def send_conversation_message(
|
||||
# Use streaming service
|
||||
streaming_service = StreamingService(server)
|
||||
run, result = await streaming_service.create_agent_stream(
|
||||
agent_id=conversation.agent_id,
|
||||
actor=actor,
|
||||
request=request,
|
||||
run_type="send_conversation_message",
|
||||
conversation_id=conversation_id,
|
||||
agent_id=conversation.agent_id, actor=actor, request=request, run_type="send_conversation_message", conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -228,11 +204,7 @@ async def retrieve_conversation_stream(
|
||||
|
||||
# Find the most recent active run for this conversation
|
||||
active_runs = await runs_manager.list_runs(
|
||||
actor=actor,
|
||||
conversation_id=conversation_id,
|
||||
statuses=[RunStatus.created, RunStatus.running],
|
||||
limit=1,
|
||||
ascending=False,
|
||||
actor=actor, conversation_id=conversation_id, statuses=[RunStatus.created, RunStatus.running], limit=1, ascending=False,
|
||||
)
|
||||
|
||||
if not active_runs:
|
||||
@@ -267,17 +239,57 @@ async def retrieve_conversation_stream(
|
||||
)
|
||||
|
||||
if settings.enable_cancellation_aware_streaming:
|
||||
stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=stream,
|
||||
run_manager=server.run_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
)
|
||||
stream = cancellation_aware_stream_wrapper(stream_generator=stream, run_manager=server.run_manager, run_id=run.id, actor=actor,)
|
||||
|
||||
if request and request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
||||
|
||||
return StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
return StreamingResponseWithStatusCode(stream, media_type="text/event-stream",)
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/messages/cancel", operation_id="cancel_conversation_message")
async def cancel_conversation_message(
    conversation_id: ConversationId,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> dict:
    """
    Cancel runs associated with a conversation.

    Note: To cancel active runs, Redis is required.
    """
    # Returns a mapping of run id -> "cancelled" or "failed" for every run in
    # the created/running state that was found for the conversation.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Cancellation needs run tracking; reject the request when it is disabled.
    if not settings.track_agent_run:
        raise HTTPException(status_code=400, detail="Agent run tracking is disabled")

    # Fetching the conversation both verifies it exists and supplies the agent id.
    conversation = await conversation_manager.get_conversation_by_id(
        conversation_id=conversation_id,
        actor=actor,
    )

    # At most 100 created/running runs of this conversation are considered.
    pending_runs = await server.run_manager.list_runs(
        actor=actor,
        statuses=[RunStatus.created, RunStatus.running],
        ascending=False,
        conversation_id=conversation_id,
        limit=100,
    )
    pending_run_ids = [pending.id for pending in pending_runs]
    if not pending_run_ids:
        raise NoActiveRunsToCancelError(conversation_id=conversation_id)

    outcome_by_run = {}
    for run_id in pending_run_ids:
        try:
            run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)

            # Runs flagged with "lettuce" metadata are also cancelled through the
            # Lettuce client. This is best effort: a failure is logged and the
            # regular cancellation below still runs.
            if run.metadata and run.metadata.get("lettuce"):
                try:
                    lettuce_client = await LettuceClient.create()
                    await lettuce_client.cancel(run_id)
                except Exception as lettuce_error:
                    logger.error(f"Failed to cancel Lettuce run {run_id}: {lettuce_error}")

            await server.run_manager.cancel_run(actor=actor, agent_id=conversation.agent_id, run_id=run_id)
        except Exception as cancel_error:
            # One failing run must not stop the remaining runs from being cancelled.
            outcome_by_run[run_id] = "failed"
            logger.error(f"Failed to cancel run {run_id}: {str(cancel_error)}")
        else:
            outcome_by_run[run_id] = "cancelled"
            logger.info(f"Cancelled run {run_id}")

    return outcome_by_run
|
||||
|
||||
Reference in New Issue
Block a user