diff --git a/letta/errors.py b/letta/errors.py index 22fea4f6..5a2eb849 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -56,11 +56,13 @@ class PendingApprovalError(LettaError): class NoActiveRunsToCancelError(LettaError): """Error raised when attempting to cancel but there are no active runs to cancel.""" - def __init__(self, agent_id: Optional[str] = None): + def __init__(self, agent_id: Optional[str] = None, conversation_id: Optional[str] = None): message = "No active runs to cancel" if agent_id: message = f"No active runs to cancel for agent {agent_id}" - details = {"error_code": "NO_ACTIVE_RUNS_TO_CANCEL", "agent_id": agent_id} + if conversation_id: + message = f"No active runs to cancel for conversation {conversation_id}" + details = {"error_code": "NO_ACTIVE_RUNS_TO_CANCEL", "agent_id": agent_id, "conversation_id": conversation_id} super().__init__(message=message, code=ErrorCode.CONFLICT, details=details) @@ -165,9 +167,7 @@ class LettaImageFetchError(LettaError): def __init__(self, url: str, reason: str): details = {"url": url, "reason": reason} super().__init__( - message=f"Failed to fetch image from {url}: {reason}", - code=ErrorCode.INVALID_ARGUMENT, - details=details, + message=f"Failed to fetch image from {url}: {reason}", code=ErrorCode.INVALID_ARGUMENT, details=details, ) @@ -308,9 +308,7 @@ class ContextWindowExceededError(LettaError): def __init__(self, message: str, details: dict = {}): error_message = f"{message} ({details})" super().__init__( - message=error_message, - code=ErrorCode.CONTEXT_WINDOW_EXCEEDED, - details=details, + message=error_message, code=ErrorCode.CONTEXT_WINDOW_EXCEEDED, details=details, ) @@ -330,9 +328,7 @@ class RateLimitExceededError(LettaError): def __init__(self, message: str, max_retries: int): error_message = f"{message} ({max_retries})" super().__init__( - message=error_message, - code=ErrorCode.RATE_LIMIT_EXCEEDED, - details={"max_retries": max_retries}, + message=error_message, code=ErrorCode.RATE_LIMIT_EXCEEDED, details={"max_retries": max_retries}, ) @@ -387,8 +383,7 @@ class HandleNotFoundError(LettaError): def __init__(self, handle: str, available_handles: List[str]): super().__init__( - message=f"Handle {handle} not found, must be one of {available_handles}", - code=ErrorCode.NOT_FOUND, + message=f"Handle {handle} not found, must be one of {available_handles}", code=ErrorCode.NOT_FOUND, ) diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py index 58aad829..2102a803 100644 --- a/letta/server/rest_api/routers/v1/conversations.py +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -6,7 +6,8 @@ from pydantic import Field from starlette.responses import StreamingResponse from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client -from letta.errors import LettaExpiredError, LettaInvalidArgumentError +from letta.errors import LettaExpiredError, LettaInvalidArgumentError, NoActiveRunsToCancelError +from letta.log import get_logger from letta.helpers.datetime_helpers import get_utc_time from letta.schemas.conversation import Conversation, CreateConversation from letta.schemas.enums import RunStatus @@ -22,6 +23,7 @@ from letta.server.rest_api.streaming_response import ( ) from letta.server.server import SyncServer from letta.services.conversation_manager import ConversationManager +from letta.services.lettuce import LettuceClient from letta.services.run_manager import RunManager from letta.services.streaming_service import StreamingService from letta.settings import settings @@ -29,6 +31,8 @@ from letta.validators import ConversationId router = APIRouter(prefix="/conversations", tags=["conversations"]) +logger = get_logger(__name__) + # Instantiate manager conversation_manager = ConversationManager() @@ -42,11 +46,7 @@ async def create_conversation( ): """Create a new conversation for an agent.""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await conversation_manager.create_conversation( - agent_id=agent_id, - conversation_create=conversation_create, - actor=actor, - ) + return await conversation_manager.create_conversation(agent_id=agent_id, conversation_create=conversation_create, actor=actor,) @router.get("/", response_model=List[Conversation], operation_id="list_conversations") @@ -59,26 +59,16 @@ async def list_conversations( ): """List all conversations for an agent.""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await conversation_manager.list_conversations( - agent_id=agent_id, - actor=actor, - limit=limit, - after=after, - ) + return await conversation_manager.list_conversations(agent_id=agent_id, actor=actor, limit=limit, after=after,) @router.get("/{conversation_id}", response_model=Conversation, operation_id="retrieve_conversation") async def retrieve_conversation( - conversation_id: ConversationId, - server: SyncServer = Depends(get_letta_server), - headers: HeaderParams = Depends(get_headers), + conversation_id: ConversationId, server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ): """Retrieve a specific conversation.""" actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await conversation_manager.get_conversation_by_id( - conversation_id=conversation_id, - actor=actor, - ) + return await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,) ConversationMessagesResponse = Annotated[ @@ -87,9 +77,7 @@ ConversationMessagesResponse = Annotated[ @router.get( - "/{conversation_id}/messages", - response_model=ConversationMessagesResponse, - operation_id="list_conversation_messages", + "/{conversation_id}/messages", response_model=ConversationMessagesResponse, operation_id="list_conversation_messages", ) async def list_conversation_messages( conversation_id: ConversationId, @@ -135,12 +123,7 @@ async def list_conversation_messages( response_model=LettaStreamingResponse, operation_id="send_conversation_message", responses={ - 200: { - "description": "Successful response", - "content": { - "text/event-stream": {"description": "Server-Sent Events stream"}, - }, - } + 200: {"description": "Successful response", "content": {"text/event-stream": {"description": "Server-Sent Events stream"},},} }, ) async def send_conversation_message( @@ -158,10 +141,7 @@ async def send_conversation_message( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) # Get the conversation to find the agent_id - conversation = await conversation_manager.get_conversation_by_id( - conversation_id=conversation_id, - actor=actor, - ) + conversation = await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,) # Force streaming mode for this endpoint request.streaming = True @@ -169,11 +149,7 @@ async def send_conversation_message( # Use streaming service streaming_service = StreamingService(server) run, result = await streaming_service.create_agent_stream( - agent_id=conversation.agent_id, - actor=actor, - request=request, - run_type="send_conversation_message", - conversation_id=conversation_id, + agent_id=conversation.agent_id, actor=actor, request=request, run_type="send_conversation_message", conversation_id=conversation_id, ) return result @@ -228,11 +204,7 @@ async def retrieve_conversation_stream( # Find the most recent active run for this conversation active_runs = await runs_manager.list_runs( - actor=actor, - conversation_id=conversation_id, - statuses=[RunStatus.created, RunStatus.running], - limit=1, - ascending=False, + actor=actor, conversation_id=conversation_id, statuses=[RunStatus.created, RunStatus.running], limit=1, ascending=False, ) if not active_runs: @@ -267,17 +239,57 @@ async def retrieve_conversation_stream( ) if settings.enable_cancellation_aware_streaming: - stream = cancellation_aware_stream_wrapper( - stream_generator=stream, - run_manager=server.run_manager, - run_id=run.id, - actor=actor, - ) + stream = cancellation_aware_stream_wrapper(stream_generator=stream, run_manager=server.run_manager, run_id=run.id, actor=actor,) if request and request.include_pings and settings.enable_keepalive: stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id) - return StreamingResponseWithStatusCode( - stream, - media_type="text/event-stream", + return StreamingResponseWithStatusCode(stream, media_type="text/event-stream",) + + +@router.post("/{conversation_id}/messages/cancel", operation_id="cancel_conversation_message") +async def cancel_conversation_message( + conversation_id: ConversationId, server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), +) -> dict: + """ + Cancel runs associated with a conversation. + + Note: To cancel active runs, Redis is required. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + + if not settings.track_agent_run: + raise HTTPException(status_code=400, detail="Agent run tracking is disabled") + + # Verify conversation exists and get agent_id + conversation = await conversation_manager.get_conversation_by_id(conversation_id=conversation_id, actor=actor,) + + # Find active runs for this conversation + runs = await server.run_manager.list_runs( + actor=actor, statuses=[RunStatus.created, RunStatus.running], ascending=False, conversation_id=conversation_id, limit=100, ) + run_ids = [run.id for run in runs] + + if not run_ids: + raise NoActiveRunsToCancelError(conversation_id=conversation_id) + + results = {} + for run_id in run_ids: + try: + run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor) + if run.metadata and run.metadata.get("lettuce"): + try: + lettuce_client = await LettuceClient.create() + await lettuce_client.cancel(run_id) + except Exception as e: + logger.error(f"Failed to cancel Lettuce run {run_id}: {e}") + + await server.run_manager.cancel_run(actor=actor, agent_id=conversation.agent_id, run_id=run_id) + except Exception as e: + results[run_id] = "failed" + logger.error(f"Failed to cancel run {run_id}: {str(e)}") + continue + results[run_id] = "cancelled" + logger.info(f"Cancelled run {run_id}") + + return results