diff --git a/letta/errors.py b/letta/errors.py
index 3528596a..30ca0ad2 100644
--- a/letta/errors.py
+++ b/letta/errors.py
@@ -377,3 +377,10 @@ class AgentExportProcessingError(AgentFileExportError):
 
 class AgentFileImportError(Exception):
     """Exception raised during agent file import operations"""
+
+
+class RunCancelError(LettaError):
+    """Error raised when a run cannot be cancelled."""
+
+    def __init__(self, message: str):
+        super().__init__(message=message)
diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py
index 91435638..17829366 100644
--- a/letta/llm_api/anthropic_client.py
+++ b/letta/llm_api/anthropic_client.py
@@ -168,7 +168,12 @@ class AnthropicClient(LLMClientBase):
         if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
             betas.append("structured-outputs-2025-11-13")
 
-        return await client.beta.messages.create(**request_data, betas=betas)
+        # log failed requests
+        try:
+            return await client.beta.messages.create(**request_data, betas=betas)
+        except Exception as e:
+            logger.error(f"Error streaming Anthropic request: {e} with request data: {json.dumps(request_data)}")
+            raise e
 
     @trace_method
     async def send_llm_batch_request_async(
diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py
index 42d042b5..554b2606 100644
--- a/letta/llm_api/google_vertex_client.py
+++ b/letta/llm_api/google_vertex_client.py
@@ -140,11 +140,16 @@ class GoogleVertexClient(LLMClientBase):
     @trace_method
     async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
         client = self._get_client()
-        response = await client.aio.models.generate_content_stream(
-            model=llm_config.model,
-            contents=request_data["contents"],
-            config=request_data["config"],
-        )
+
+        try:
+            response = await client.aio.models.generate_content_stream(
+                model=llm_config.model,
+                contents=request_data["contents"],
+                config=request_data["config"],
+            )
+        except Exception as e:
+            logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}")
+            raise e
         # Direct yield - keeps response alive in generator's local scope throughout iteration
         # This is required because the SDK's connection lifecycle is tied to the response object
         async for chunk in response:
diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py
index d4b989f9..77e80109 100644
--- a/letta/llm_api/openai_client.py
+++ b/letta/llm_api/openai_client.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import time
 from typing import Any, List, Optional
@@ -762,17 +763,25 @@ class OpenAIClient(LLMClientBase):
         # Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
         if "input" in request_data and "messages" not in request_data:
-            response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
-                **request_data,
-                stream=True,
-                # stream_options={"include_usage": True},
-            )
+            try:
+                response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
+                    **request_data,
+                    stream=True,
+                    # stream_options={"include_usage": True},
+                )
+            except Exception as e:
+                logger.error(f"Error streaming OpenAI Responses request: {e} with request data: {json.dumps(request_data)}")
+                raise e
         else:
-            response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
-                **request_data,
-                stream=True,
-                stream_options={"include_usage": True},
-            )
+            try:
+                response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+                    **request_data,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                )
+            except Exception as e:
+                logger.error(f"Error streaming OpenAI Chat Completions request: {e} with request data: {json.dumps(request_data)}")
+                raise e
 
         return response_stream
 
     @trace_method
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 08917d10..786bc604 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -25,6 +25,7 @@ from letta.errors import (
     AgentFileImportError,
     AgentNotFoundForExportError,
     PendingApprovalError,
+    RunCancelError,
 )
 from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
 from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
@@ -1682,6 +1683,7 @@ async def cancel_message(
         run_ids = [run_id]
 
     results = {}
+    failed_to_cancel = []
     for run_id in run_ids:
         run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
         if run.metadata.get("lettuce"):
@@ -1691,8 +1693,14 @@ async def cancel_message(
             run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
         except Exception as e:
             results[run_id] = "failed"
+            logger.error(f"Failed to cancel run {run_id}: {str(e)}")
+            failed_to_cancel.append(run_id)
             continue
         results[run_id] = "cancelled"
+        logger.info(f"Cancelled run {run_id}")
+
+    if failed_to_cancel:
+        raise RunCancelError(f"Failed to cancel runs: {failed_to_cancel}")
 
     return results
 
diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py
index 1ede2ea6..6140d24c 100644
--- a/letta/services/run_manager.py
+++ b/letta/services/run_manager.py
@@ -686,17 +686,28 @@ class RunManager:
         # Combine approval response and tool messages
         new_messages = approval_response_messages + [tool_message]
 
-        # Insert the approval response and tool messages into the database
-        persisted_messages = await self.message_manager.create_many_messages_async(
-            pydantic_msgs=new_messages,
-            actor=actor,
-            run_id=run_id,
-        )
-        logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
+        # Checkpoint the new messages
+        from letta.agents.agent_loop import AgentLoop
 
-        # Update the agent's message_ids to include the new messages (approval + tool message)
-        agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
-        await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
+        agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
+        new_in_context_messages = current_in_context_messages + new_messages
+        await agent_loop._checkpoint_messages(
+            run_id=run_id,
+            step_id=approval_request_message.step_id,
+            new_messages=new_messages,
+            in_context_messages=new_in_context_messages,
+        )
+
+        # persisted_messages = await self.message_manager.create_many_messages_async(
+        #     pydantic_msgs=new_messages,
+        #     actor=actor,
+        #     run_id=run_id,
+        # )
+        # logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
+
+        ## Update the agent's message_ids to include the new messages (approval + tool message)
+        # agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
+        # await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
 
         logger.debug(
             f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
diff --git a/letta/utils.py b/letta/utils.py
index 7525ca84..62548030 100644
--- a/letta/utils.py
+++ b/letta/utils.py
@@ -1261,77 +1261,6 @@ def safe_create_file_processing_task(coro, file_metadata, server, actor, logger:
     return task
 
 
-class CancellationSignal:
-    """
-    A signal that can be checked for cancellation during streaming operations.
-
-    This provides a lightweight way to check if an operation should be cancelled
-    without having to pass job managers and other dependencies through every method.
-    """
-
-    def __init__(self, job_manager=None, job_id=None, actor=None):
-        from letta.log import get_logger
-        from letta.schemas.user import User
-        from letta.services.job_manager import JobManager
-
-        self.job_manager: JobManager | None = job_manager
-        self.job_id: str | None = job_id
-        self.actor: User | None = actor
-        self._is_cancelled = False
-        self.logger = get_logger(__name__)
-
-    async def is_cancelled(self) -> bool:
-        """
-        Check if the operation has been cancelled.
-
-        Returns:
-            True if cancelled, False otherwise
-        """
-        from letta.schemas.enums import JobStatus
-
-        if self._is_cancelled:
-            return True
-
-        if not self.job_manager or not self.job_id or not self.actor:
-            return False
-
-        try:
-            job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
-            self._is_cancelled = job.status == JobStatus.cancelled
-            return self._is_cancelled
-        except Exception as e:
-            self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
-            return False
-
-    def cancel(self):
-        """Mark this signal as cancelled locally (for testing or direct cancellation)."""
-        self._is_cancelled = True
-
-    async def check_and_raise_if_cancelled(self):
-        """
-        Check for cancellation and raise CancelledError if cancelled.
-
-        Raises:
-            asyncio.CancelledError: If the operation has been cancelled
-        """
-        if await self.is_cancelled():
-            self.logger.info(f"Operation cancelled for job {self.job_id}")
-            raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
-
-
-class NullCancellationSignal(CancellationSignal):
-    """A null cancellation signal that is never cancelled."""
-
-    def __init__(self):
-        super().__init__()
-
-    async def is_cancelled(self) -> bool:
-        return False
-
-    async def check_and_raise_if_cancelled(self):
-        pass
-
-
 async def get_latest_alembic_revision() -> str:
     """Get the current alembic revision ID from the alembic_version table."""
     from letta.server.db import db_registry
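
Each of the three provider clients above applies the same wrap, log, and re-raise pattern around its streaming SDK call. Below is a minimal standalone sketch of that pattern, not the actual Letta implementation: `stream_with_logging` and `client.create_stream` are hypothetical stand-ins for the provider-specific calls, and it assumes `request_data` is JSON-serializable, as the diff assumes when calling `json.dumps`.

```python
import json
import logging

logger = logging.getLogger(__name__)


async def stream_with_logging(client, request_data: dict):
    """Hypothetical helper: call a provider's streaming endpoint and log the
    full request payload if the call fails before any chunks arrive."""
    try:
        # Stand-in for the provider call (responses.create, chat.completions.create,
        # generate_content_stream, or beta.messages.create in the diff above).
        return await client.create_stream(**request_data)
    except Exception as e:
        # Assumes request_data contains only JSON-serializable values.
        logger.error(f"Error streaming request: {e} with request data: {json.dumps(request_data)}")
        # A bare raise preserves the original traceback for the caller.
        raise
```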