From f9bb757a985f925eff2bc4350c665ef7a5104fbc Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:31:16 -0700 Subject: [PATCH] feat: support for agent loop job cancelation (#2837) --- letta/agents/letta_agent.py | 132 +++-- letta/constants.py | 1 + letta/groups/sleeptime_multi_agent_v2.py | 35 +- letta/interface.py | 2 +- .../anthropic_streaming_interface.py | 23 +- ...ai_chat_completions_streaming_interface.py | 15 +- .../interfaces/openai_streaming_interface.py | 28 +- letta/schemas/enums.py | 4 + letta/schemas/letta_stop_reason.py | 18 + letta/server/rest_api/routers/v1/agents.py | 489 +++++++++++------- letta/server/rest_api/routers/v1/jobs.py | 50 +- letta/server/rest_api/streaming_response.py | 88 ++++ letta/services/job_manager.py | 46 +- letta/utils.py | 83 ++- tests/integration_test_send_message.py | 204 ++++++++ tests/test_managers.py | 2 +- tests/test_provider_trace.py | 1 + 17 files changed, 940 insertions(+), 281 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index d13546ca..750a42b1 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1,8 +1,9 @@ import asyncio import json import uuid +from collections.abc import AsyncGenerator from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union +from typing import Optional from openai import AsyncStream from openai.types.chat import ChatCompletionChunk @@ -34,7 +35,7 @@ from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry from letta.otel.tracing import log_event, trace_method, tracer from letta.schemas.agent import AgentState, UpdateAgent -from letta.schemas.enums import MessageRole, ProviderType +from letta.schemas.enums import JobStatus, MessageRole, ProviderType from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, 
ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse @@ -69,7 +70,6 @@ DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary" class LettaAgent(BaseAgent): - def __init__( self, agent_id: str, @@ -81,6 +81,7 @@ class LettaAgent(BaseAgent): actor: User, step_manager: StepManager = NoopStepManager(), telemetry_manager: TelemetryManager = NoopTelemetryManager(), + current_run_id: str | None = None, summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL, message_buffer_limit: int = summarizer_settings.message_buffer_limit, message_buffer_min: int = summarizer_settings.message_buffer_min, @@ -96,7 +97,9 @@ class LettaAgent(BaseAgent): self.passage_manager = passage_manager self.step_manager = step_manager self.telemetry_manager = telemetry_manager - self.response_messages: List[Message] = [] + self.job_manager = job_manager + self.current_run_id = current_run_id + self.response_messages: list[Message] = [] self.last_function_response = None @@ -128,16 +131,35 @@ class LettaAgent(BaseAgent): message_buffer_min=message_buffer_min, ) + async def _check_run_cancellation(self) -> bool: + """ + Check if the current run associated with this agent execution has been cancelled. 
+ + Returns: + True if the run is cancelled, False otherwise (or if no run is associated) + """ + if not self.job_manager or not self.current_run_id: + return False + + try: + job = await self.job_manager.get_job_by_id_async(job_id=self.current_run_id, actor=self.actor) + return job.status == JobStatus.cancelled + except Exception as e: + # Log the error but don't fail the execution + logger.warning(f"Failed to check job cancellation status for job {self.current_run_id}: {e}") + return False + @trace_method async def step( self, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, - run_id: Optional[str] = None, + run_id: str | None = None, use_assistant_message: bool = True, - request_start_timestamp_ns: Optional[int] = None, - include_return_message_types: Optional[List[MessageType]] = None, + request_start_timestamp_ns: int | None = None, + include_return_message_types: list[MessageType] | None = None, ) -> LettaResponse: + # TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions agent_state = await self.agent_manager.get_agent_by_id_async( agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor ) @@ -159,11 +181,11 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream_no_tokens( self, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, - request_start_timestamp_ns: Optional[int] = None, - include_return_message_types: Optional[List[MessageType]] = None, + request_start_timestamp_ns: int | None = None, + include_return_message_types: list[MessageType] | None = None, ): agent_state = await self.agent_manager.get_agent_by_id_async( agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor @@ -186,6 +208,13 @@ class LettaAgent(BaseAgent): 
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): + # Check for job cancellation at the start of each step + if await self._check_run_cancellation(): + stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) + logger.info(f"Agent execution cancelled for run {self.current_run_id}") + yield f"data: {stop_reason.model_dump_json()}\n\n" + break + step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) @@ -317,11 +346,11 @@ class LettaAgent(BaseAgent): async def _step( self, agent_state: AgentState, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, - run_id: Optional[str] = None, - request_start_timestamp_ns: Optional[int] = None, - ) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]: + run_id: str | None = None, + request_start_timestamp_ns: int | None = None, + ) -> tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics]: """ Carries out an invocation of the agent loop. In each step, the agent 1. 
Rebuilds its memory @@ -347,6 +376,12 @@ class LettaAgent(BaseAgent): stop_reason = None usage = LettaUsageStatistics() for i in range(max_steps): + # Check for job cancellation at the start of each step + if await self._check_run_cancellation(): + stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) + logger.info(f"Agent execution cancelled for run {self.current_run_id}") + break + step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) @@ -471,11 +506,11 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream( self, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, - request_start_timestamp_ns: Optional[int] = None, - include_return_message_types: Optional[List[MessageType]] = None, + request_start_timestamp_ns: int | None = None, + include_return_message_types: list[MessageType] | None = None, ) -> AsyncGenerator[str, None]: """ Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens. 
@@ -507,6 +542,13 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): + # Check for job cancellation at the start of each step + if await self._check_run_cancellation(): + stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) + logger.info(f"Agent execution cancelled for run {self.current_run_id}") + yield f"data: {stop_reason.model_dump_json()}\n\n" + break + step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) @@ -547,7 +589,9 @@ class LettaAgent(BaseAgent): raise ValueError(f"Streaming not supported for {agent_state.llm_config}") async for chunk in interface.process( - stream, ttft_span=request_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns + stream, + ttft_span=request_span, + provider_request_start_timestamp_ns=provider_request_start_timestamp_ns, ): # Measure time to first token if first_chunk and request_span is not None: @@ -690,13 +734,13 @@ class LettaAgent(BaseAgent): # noinspection PyInconsistentReturns async def _build_and_request_from_llm( self, - current_in_context_messages: List[Message], - new_in_context_messages: List[Message], + current_in_context_messages: list[Message], + new_in_context_messages: list[Message], agent_state: AgentState, llm_client: LLMClientBase, tool_rules_solver: ToolRulesSolver, agent_step_span: "Span", - ) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]] | None: + ) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None: for attempt in range(self.max_summarization_retries + 1): try: log_event("agent.stream_no_tokens.messages.refreshed") @@ -742,12 +786,12 @@ class LettaAgent(BaseAgent): first_chunk: bool, ttft_span: "Span", request_start_timestamp_ns: int, - current_in_context_messages: List[Message], - new_in_context_messages: 
List[Message], + current_in_context_messages: list[Message], + new_in_context_messages: list[Message], agent_state: AgentState, llm_client: LLMClientBase, tool_rules_solver: ToolRulesSolver, - ) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str], int] | None: + ) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None: for attempt in range(self.max_summarization_retries + 1): try: log_event("agent.stream_no_tokens.messages.refreshed") @@ -799,11 +843,11 @@ class LettaAgent(BaseAgent): self, e: Exception, llm_client: LLMClientBase, - in_context_messages: List[Message], - new_letta_messages: List[Message], + in_context_messages: list[Message], + new_letta_messages: list[Message], llm_config: LLMConfig, force: bool, - ) -> List[Message]: + ) -> list[Message]: if isinstance(e, ContextWindowExceededError): return await self._rebuild_context_window( in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, llm_config=llm_config, force=force @@ -814,12 +858,12 @@ class LettaAgent(BaseAgent): @trace_method async def _rebuild_context_window( self, - in_context_messages: List[Message], - new_letta_messages: List[Message], + in_context_messages: list[Message], + new_letta_messages: list[Message], llm_config: LLMConfig, - total_tokens: Optional[int] = None, + total_tokens: int | None = None, force: bool = False, - ) -> List[Message]: + ) -> list[Message]: # If total tokens is reached, we truncate down # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc. 
if force or (total_tokens and total_tokens > llm_config.context_window): @@ -855,10 +899,10 @@ class LettaAgent(BaseAgent): async def _create_llm_request_data_async( self, llm_client: LLMClientBase, - in_context_messages: List[Message], + in_context_messages: list[Message], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, - ) -> Tuple[dict, List[str]]: + ) -> tuple[dict, list[str]]: self.num_messages, self.num_archival_memories = await asyncio.gather( ( self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) @@ -929,18 +973,18 @@ class LettaAgent(BaseAgent): async def _handle_ai_response( self, tool_call: ToolCall, - valid_tool_names: List[str], + valid_tool_names: list[str], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, usage: UsageStatistics, - reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, - pre_computed_assistant_message_id: Optional[str] = None, + reasoning_content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None, + pre_computed_assistant_message_id: str | None = None, step_id: str | None = None, - initial_messages: Optional[List[Message]] = None, + initial_messages: list[Message] | None = None, agent_step_span: Optional["Span"] = None, - is_final_step: Optional[bool] = None, - run_id: Optional[str] = None, - ) -> Tuple[List[Message], bool, Optional[LettaStopReason]]: + is_final_step: bool | None = None, + run_id: str | None = None, + ) -> tuple[list[Message], bool, LettaStopReason | None]: """ Handle the final AI response once streaming completes, execute / validate the tool call, decide whether we should keep stepping, and persist state. 
@@ -1016,7 +1060,7 @@ class LettaAgent(BaseAgent): context_window_limit=agent_state.llm_config.context_window, usage=usage, provider_id=None, - job_id=run_id, + job_id=run_id if run_id else self.current_run_id, step_id=step_id, project_id=agent_state.project_id, ) @@ -1155,7 +1199,7 @@ class LettaAgent(BaseAgent): name="tool_execution_completed", attributes={ "tool_name": target_tool.name, - "duration_ms": ns_to_ms((end_time - start_time)), + "duration_ms": ns_to_ms(end_time - start_time), "success": tool_execution_result.success_flag, "tool_type": target_tool.tool_type, "tool_id": target_tool.id, @@ -1165,7 +1209,7 @@ class LettaAgent(BaseAgent): return tool_execution_result @trace_method - def _load_last_function_response(self, in_context_messages: List[Message]): + def _load_last_function_response(self, in_context_messages: list[Message]): """Load the last function response from message history""" for msg in reversed(in_context_messages): if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): diff --git a/letta/constants.py b/letta/constants.py index 35232f6e..c83eefe4 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -358,6 +358,7 @@ REDIS_INCLUDE = "include" REDIS_EXCLUDE = "exclude" REDIS_SET_DEFAULT_VAL = "None" REDIS_DEFAULT_CACHE_PREFIX = "letta_cache" +REDIS_RUN_ID_PREFIX = "agent:send_message:run_id" # TODO: This is temporary, eventually use token-based eviction MAX_FILES_OPEN = 5 diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index a314ac2d..275fe3bf 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -1,6 +1,6 @@ import asyncio +from collections.abc import AsyncGenerator from datetime import datetime, timezone -from typing import AsyncGenerator, List, Optional from letta.agents.base_agent import BaseAgent from letta.agents.letta_agent import LettaAgent @@ -39,7 +39,8 @@ class 
SleeptimeMultiAgentV2(BaseAgent): actor: User, step_manager: StepManager = NoopStepManager(), telemetry_manager: TelemetryManager = NoopTelemetryManager(), - group: Optional[Group] = None, + group: Group | None = None, + current_run_id: str | None = None, ): super().__init__( agent_id=agent_id, @@ -54,6 +55,7 @@ class SleeptimeMultiAgentV2(BaseAgent): self.job_manager = job_manager self.step_manager = step_manager self.telemetry_manager = telemetry_manager + self.current_run_id = current_run_id # Group settings assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}" self.group = group @@ -61,12 +63,12 @@ class SleeptimeMultiAgentV2(BaseAgent): @trace_method async def step( self, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, - run_id: Optional[str] = None, + run_id: str | None = None, use_assistant_message: bool = True, - request_start_timestamp_ns: Optional[int] = None, - include_return_message_types: Optional[List[MessageType]] = None, + request_start_timestamp_ns: int | None = None, + include_return_message_types: list[MessageType] | None = None, ) -> LettaResponse: run_ids = [] @@ -89,6 +91,7 @@ class SleeptimeMultiAgentV2(BaseAgent): actor=self.actor, step_manager=self.step_manager, telemetry_manager=self.telemetry_manager, + current_run_id=self.current_run_id, ) # Perform foreground agent step response = await foreground_agent.step( @@ -125,7 +128,7 @@ class SleeptimeMultiAgentV2(BaseAgent): except Exception as e: # Individual task failures - print(f"Agent processing failed: {str(e)}") + print(f"Agent processing failed: {e!s}") raise e response.usage.run_ids = run_ids @@ -134,11 +137,11 @@ class SleeptimeMultiAgentV2(BaseAgent): @trace_method async def step_stream_no_tokens( self, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = 
True, - request_start_timestamp_ns: Optional[int] = None, - include_return_message_types: Optional[List[MessageType]] = None, + request_start_timestamp_ns: int | None = None, + include_return_message_types: list[MessageType] | None = None, ): response = await self.step( input_messages=input_messages, @@ -157,11 +160,11 @@ class SleeptimeMultiAgentV2(BaseAgent): @trace_method async def step_stream( self, - input_messages: List[MessageCreate], + input_messages: list[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, - request_start_timestamp_ns: Optional[int] = None, - include_return_message_types: Optional[List[MessageType]] = None, + request_start_timestamp_ns: int | None = None, + include_return_message_types: list[MessageType] | None = None, ) -> AsyncGenerator[str, None]: # Prepare new messages new_messages = [] @@ -182,6 +185,7 @@ class SleeptimeMultiAgentV2(BaseAgent): actor=self.actor, step_manager=self.step_manager, telemetry_manager=self.telemetry_manager, + current_run_id=self.current_run_id, ) # Perform foreground agent step async for chunk in foreground_agent.step_stream( @@ -218,7 +222,7 @@ class SleeptimeMultiAgentV2(BaseAgent): async def _issue_background_task( self, sleeptime_agent_id: str, - response_messages: List[Message], + response_messages: list[Message], last_processed_message_id: str, use_assistant_message: bool = True, ) -> str: @@ -248,7 +252,7 @@ class SleeptimeMultiAgentV2(BaseAgent): self, foreground_agent_id: str, sleeptime_agent_id: str, - response_messages: List[Message], + response_messages: list[Message], last_processed_message_id: str, run_id: str, use_assistant_message: bool = True, @@ -296,6 +300,7 @@ class SleeptimeMultiAgentV2(BaseAgent): actor=self.actor, step_manager=self.step_manager, telemetry_manager=self.telemetry_manager, + current_run_id=self.current_run_id, message_buffer_limit=20, # TODO: Make this configurable message_buffer_min=8, # TODO: Make this configurable 
enable_summarization=False, # TODO: Make this configurable diff --git a/letta/interface.py b/letta/interface.py index 281274de..ca9eea12 100644 --- a/letta/interface.py +++ b/letta/interface.py @@ -81,7 +81,7 @@ class CLIInterface(AgentInterface): @staticmethod def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): # ANSI escape code for italic is '\x1B[3m' - fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}" + fstr = f"\x1b[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 1bd63bbf..9e1328fa 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -1,7 +1,9 @@ +import asyncio import json +from collections.abc import AsyncGenerator from datetime import datetime, timezone from enum import Enum -from typing import AsyncGenerator, List, Optional, Union +from typing import Optional from anthropic import AsyncStream from anthropic.types.beta import ( @@ -131,14 +133,16 @@ class AnthropicStreamingInterface: self, stream: AsyncStream[BetaRawMessageStreamEvent], ttft_span: Optional["Span"] = None, - provider_request_start_timestamp_ns: Optional[int] = None, - ) -> AsyncGenerator[LettaMessage, None]: + provider_request_start_timestamp_ns: int | None = None, + ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: prev_message_type = None message_index = 0 first_chunk = True try: async with stream: async for event in stream: + # TODO (cliandy): reconsider in stream cancellations + # await cancellation_token.check_and_raise_if_cancelled() if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None: now = get_utc_timestamp_ns() ttft_ns = now - provider_request_start_timestamp_ns @@ -384,18 
+388,21 @@ class AnthropicStreamingInterface: self.tool_call_buffer = [] self.anthropic_mode = None + except asyncio.CancelledError as e: + logger.info("Cancelled stream %s", e) + yield LettaStopReason(stop_reason=StopReasonType.cancelled) + raise except Exception as e: logger.error("Error processing stream: %s", e) - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - yield stop_reason + yield LettaStopReason(stop_reason=StopReasonType.error) raise finally: logger.info("AnthropicStreamingInterface: Stream processing complete.") - def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]: + def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]: def _process_group( - group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str - ) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]: + group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str + ) -> TextContent | ReasoningContent | RedactedReasoningContent: if group_type == "reasoning": reasoning_text = "".join(chunk.reasoning for chunk in group).strip() is_native = any(chunk.source == "reasoner_model" for chunk in group) diff --git a/letta/interfaces/openai_chat_completions_streaming_interface.py b/letta/interfaces/openai_chat_completions_streaming_interface.py index a58ee554..b0a06d39 100644 --- a/letta/interfaces/openai_chat_completions_streaming_interface.py +++ b/letta/interfaces/openai_chat_completions_streaming_interface.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncGenerator, Dict, List, Optional +from collections.abc import AsyncGenerator +from typing import Any from openai import AsyncStream from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta @@ -19,14 +20,14 @@ class OpenAIChatCompletionsStreamingInterface: self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() 
self.stream_pre_execution_message: bool = stream_pre_execution_message - self.current_parsed_json_result: Dict[str, Any] = {} - self.content_buffer: List[str] = [] + self.current_parsed_json_result: dict[str, Any] = {} + self.content_buffer: list[str] = [] self.tool_call_happened: bool = False self.finish_reason_stop: bool = False - self.tool_call_name: Optional[str] = None + self.tool_call_name: str | None = None self.tool_call_args_str: str = "" - self.tool_call_id: Optional[str] = None + self.tool_call_id: str | None = None async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]: """ @@ -35,6 +36,8 @@ class OpenAIChatCompletionsStreamingInterface: """ async with stream: async for chunk in stream: + # TODO (cliandy): reconsider in stream cancellations + # await cancellation_token.check_and_raise_if_cancelled() if chunk.choices: choice = chunk.choices[0] delta = choice.delta @@ -103,7 +106,7 @@ class OpenAIChatCompletionsStreamingInterface: ) ) - def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool: + def _handle_finish_reason(self, finish_reason: str | None) -> bool: """Handles the finish reason and determines if streaming should stop.""" if finish_reason == "tool_calls": self.tool_call_happened = True diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 7d5de339..f03e81f5 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -1,5 +1,7 @@ +import asyncio +from collections.abc import AsyncGenerator from datetime import datetime, timezone -from typing import AsyncGenerator, List, Optional +from typing import Optional from openai import AsyncStream from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -55,12 +57,12 @@ class OpenAIStreamingInterface: self.input_tokens = 0 self.output_tokens = 0 - self.content_buffer: List[str] = [] - self.tool_call_name: Optional[str] = 
None - self.tool_call_id: Optional[str] = None + self.content_buffer: list[str] = [] + self.tool_call_name: str | None = None + self.tool_call_id: str | None = None self.reasoning_messages = [] - def get_reasoning_content(self) -> List[TextContent]: + def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]: content = "".join(self.reasoning_messages).strip() # Right now we assume that all models omit reasoning content for OAI, @@ -87,8 +89,8 @@ class OpenAIStreamingInterface: self, stream: AsyncStream[ChatCompletionChunk], ttft_span: Optional["Span"] = None, - provider_request_start_timestamp_ns: Optional[int] = None, - ) -> AsyncGenerator[LettaMessage, None]: + provider_request_start_timestamp_ns: int | None = None, + ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: """ Iterates over the OpenAI stream, yielding SSE events. It also collects tokens and detects if a tool call is triggered. @@ -99,6 +101,8 @@ class OpenAIStreamingInterface: prev_message_type = None message_index = 0 async for chunk in stream: + # TODO (cliandy): reconsider in stream cancellations + # await cancellation_token.check_and_raise_if_cancelled() if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None: now = get_utc_timestamp_ns() ttft_ns = now - provider_request_start_timestamp_ns @@ -224,8 +228,7 @@ class OpenAIStreamingInterface: # If there was nothing in the name buffer, we can proceed to # output the arguments chunk as a ToolCallMessage else: - - # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." + # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." 
if self.use_assistant_message and ( self.last_flushed_function_name is not None and self.last_flushed_function_name == self.assistant_message_tool_name @@ -349,10 +352,13 @@ class OpenAIStreamingInterface: prev_message_type = tool_call_msg.message_type yield tool_call_msg self.function_id_buffer = None + except asyncio.CancelledError as e: + logger.info("Cancelled stream %s", e) + yield LettaStopReason(stop_reason=StopReasonType.cancelled) + raise except Exception as e: logger.error("Error processing stream: %s", e) - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - yield stop_reason + yield LettaStopReason(stop_reason=StopReasonType.error) raise finally: logger.info("OpenAIStreamingInterface: Stream processing complete.") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index d4c4714e..f4cd35f8 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -54,6 +54,10 @@ class JobStatus(str, Enum): cancelled = "cancelled" expired = "expired" + @property + def is_terminal(self): + return self in (JobStatus.completed, JobStatus.failed, JobStatus.cancelled, JobStatus.expired) + class AgentStepStatus(str, Enum): """ diff --git a/letta/schemas/letta_stop_reason.py b/letta/schemas/letta_stop_reason.py index 66761222..ab37ef19 100644 --- a/letta/schemas/letta_stop_reason.py +++ b/letta/schemas/letta_stop_reason.py @@ -3,6 +3,8 @@ from typing import Literal from pydantic import BaseModel, Field +from letta.schemas.enums import JobStatus + class StopReasonType(str, Enum): end_turn = "end_turn" @@ -11,6 +13,22 @@ class StopReasonType(str, Enum): max_steps = "max_steps" no_tool_call = "no_tool_call" tool_rule = "tool_rule" + cancelled = "cancelled" + + @property + def run_status(self) -> JobStatus: + if self in ( + StopReasonType.end_turn, + StopReasonType.max_steps, + StopReasonType.tool_rule, + ): + return JobStatus.completed + elif self in (StopReasonType.error, StopReasonType.invalid_tool_call, StopReasonType.no_tool_call): + 
return JobStatus.failed + elif self == StopReasonType.cancelled: + return JobStatus.cancelled + else: + raise ValueError("Unknown StopReasonType") class LettaStopReason(BaseModel): diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index d281dc4b..deabe63e 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -2,7 +2,7 @@ import asyncio import json import traceback from datetime import datetime, timezone -from typing import Annotated, Any, List, Optional +from typing import Annotated, Any from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, Query, Request, UploadFile, status from fastapi.responses import JSONResponse @@ -13,7 +13,8 @@ from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent -from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT +from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT, REDIS_RUN_ID_PREFIX +from letta.data_sources.redis_client import get_redis_client from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.log import get_logger @@ -49,26 +50,26 @@ router = APIRouter(prefix="/agents", tags=["agents"]) logger = get_logger(__name__) -@router.get("/", response_model=List[AgentState], operation_id="list_agents") +@router.get("/", response_model=list[AgentState], operation_id="list_agents") async def list_agents( - name: Optional[str] = Query(None, description="Name of the agent"), - tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"), + name: str | None = Query(None, description="Name of the agent"), + tags: list[str] | None = Query(None, description="List of 
tags to filter agents by"), match_all_tags: bool = Query( False, description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed-in tags.", ), server: SyncServer = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), - before: Optional[str] = Query(None, description="Cursor for pagination"), - after: Optional[str] = Query(None, description="Cursor for pagination"), - limit: Optional[int] = Query(50, description="Limit for pagination"), - query_text: Optional[str] = Query(None, description="Search agents by name"), - project_id: Optional[str] = Query(None, description="Search agents by project ID"), - template_id: Optional[str] = Query(None, description="Search agents by template ID"), - base_template_id: Optional[str] = Query(None, description="Search agents by base template ID"), - identity_id: Optional[str] = Query(None, description="Search agents by identity ID"), - identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), - include_relationships: Optional[List[str]] = Query( + actor_id: str | None = Header(None, alias="user_id"), + before: str | None = Query(None, description="Cursor for pagination"), + after: str | None = Query(None, description="Cursor for pagination"), + limit: int | None = Query(50, description="Limit for pagination"), + query_text: str | None = Query(None, description="Search agents by name"), + project_id: str | None = Query(None, description="Search agents by project ID"), + template_id: str | None = Query(None, description="Search agents by template ID"), + base_template_id: str | None = Query(None, description="Search agents by base template ID"), + identity_id: str | None = Query(None, description="Search agents by identity ID"), + identifier_keys: list[str] | None = Query(None, description="Search agents by identifier keys"), + include_relationships: list[str] | None = Query( None, description=( 
"Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. " @@ -80,7 +81,7 @@ async def list_agents( False, description="Whether to sort agents oldest to newest (True) or newest to oldest (False, default)", ), - sort_by: Optional[str] = Query( + sort_by: str | None = Query( "created_at", description="Field to sort by. Options: 'created_at' (default), 'last_run_completion'", ), @@ -119,7 +120,7 @@ async def list_agents( @router.get("/count", response_model=int, operation_id="count_agents") async def count_agents( server: SyncServer = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Get the count of all agents associated with a given user. @@ -139,10 +140,10 @@ class IndentedORJSONResponse(Response): def export_agent_serialized( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), # do not remove, used to autogeneration of spec # TODO: Think of a better way to export AgentSchema - spec: Optional[AgentSchema] = None, + spec: AgentSchema | None = None, ) -> JSONResponse: """ Export the serialized JSON representation of an agent, formatted with indentation. @@ -160,13 +161,13 @@ def export_agent_serialized( def import_agent_serialized( file: UploadFile = File(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), append_copy_suffix: bool = Query(True, description='If set to True, appends "_copy" to the end of the agent name.'), override_existing_tools: bool = Query( True, description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. 
Note that Letta core tools can never be updated externally.", ), - project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."), + project_id: str | None = Query(None, description="The project ID to associate the uploaded agent with."), strip_messages: bool = Query( False, description="If set to True, strips all messages from the agent before importing.", @@ -198,24 +199,24 @@ def import_agent_serialized( raise HTTPException(status_code=400, detail="Corrupted agent file format.") except ValidationError as e: - raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}") + raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}") except IntegrityError as e: - raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}") + raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}") except OperationalError as e: - raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {str(e)}") + raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}") except Exception as e: traceback.print_exc() - raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {str(e)}") + raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}") @router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window") async def retrieve_agent_context_window( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the context window of a specific agent. 
@@ -234,15 +235,15 @@ class CreateAgentRequest(CreateAgent): """ # Override the user_id field to exclude it from the request body validation - actor_id: Optional[str] = Field(None, exclude=True) + actor_id: str | None = Field(None, exclude=True) @router.post("/", response_model=AgentState, operation_id="create_agent") async def create_agent( agent: CreateAgentRequest = Body(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present - x_project: Optional[str] = Header( + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + x_project: str | None = Header( None, alias="X-Project", description="The project slug to associate with the agent (cloud only)." ), # Only handled by next js middleware ): @@ -262,18 +263,18 @@ async def modify_agent( agent_id: str, update_agent: UpdateAgent = Body(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Update an existing agent""" actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor) -@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="list_agent_tools") +@router.get("/{agent_id}/tools", response_model=list[Tool], operation_id="list_agent_tools") def list_agent_tools( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): 
"""Get tools from an existing agent""" actor = server.user_manager.get_user_or_default(user_id=actor_id) @@ -285,7 +286,7 @@ async def attach_tool( agent_id: str, tool_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Attach a tool to an agent. @@ -299,7 +300,7 @@ async def detach_tool( agent_id: str, tool_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Detach a tool from an agent. @@ -313,7 +314,7 @@ async def attach_source( agent_id: str, source_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Attach a source to an agent. @@ -341,7 +342,7 @@ async def detach_source( agent_id: str, source_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Detach a source from an agent. @@ -386,7 +387,7 @@ async def close_all_open_files( @router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent") async def retrieve_agent( agent_id: str, - include_relationships: Optional[List[str]] = Query( + include_relationships: list[str] | None = Query( None, description=( "Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. " @@ -395,7 +396,7 @@ async def retrieve_agent( ), ), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get the state of the agent. 
@@ -412,7 +413,7 @@ async def retrieve_agent( async def delete_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete an agent. @@ -425,11 +426,11 @@ async def delete_agent( raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.") -@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="list_agent_sources") +@router.get("/{agent_id}/sources", response_model=list[Source], operation_id="list_agent_sources") async def list_agent_sources( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get the sources associated with an agent. @@ -443,7 +444,7 @@ async def list_agent_sources( async def retrieve_agent_memory( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memory state of a specific agent. @@ -459,7 +460,7 @@ async def retrieve_block( agent_id: str, block_label: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve a core memory block from an agent. 
@@ -472,11 +473,11 @@ async def retrieve_block( raise HTTPException(status_code=404, detail=str(e)) -@router.get("/{agent_id}/core-memory/blocks", response_model=List[Block], operation_id="list_core_memory_blocks") +@router.get("/{agent_id}/core-memory/blocks", response_model=list[Block], operation_id="list_core_memory_blocks") async def list_blocks( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the core memory blocks of a specific agent. @@ -495,7 +496,7 @@ async def modify_block( block_label: str, block_update: BlockUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Updates a core memory block of an agent. @@ -517,7 +518,7 @@ async def attach_block( agent_id: str, block_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Attach a core memoryblock to an agent. @@ -531,7 +532,7 @@ async def detach_block( agent_id: str, block_id: str, server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Detach a core memory block from an agent. 
@@ -540,18 +541,18 @@ async def detach_block( return await server.agent_manager.detach_block_async(agent_id=agent_id, block_id=block_id, actor=actor) -@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages") +@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages") async def list_passages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."), - before: Optional[str] = Query(None, description="Unique ID of the memory to end the query range at."), - limit: Optional[int] = Query(None, description="How many results to include in the response."), - search: Optional[str] = Query(None, description="Search passages by text"), - ascending: Optional[bool] = Query( + after: str | None = Query(None, description="Unique ID of the memory to start the query range at."), + before: str | None = Query(None, description="Unique ID of the memory to end the query range at."), + limit: int | None = Query(None, description="How many results to include in the response."), + search: str | None = Query(None, description="Search passages by text"), + ascending: bool | None = Query( True, description="Whether to sort passages oldest to newest (True, default) or newest to oldest (False)" ), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memories in an agent's archival memory store (paginated query). 
@@ -569,12 +570,12 @@ async def list_passages( ) -@router.post("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="create_passage") +@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage") async def create_passage( agent_id: str, request: CreateArchivalMemory = Body(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Insert a memory into an agent's archival memory store. @@ -584,13 +585,13 @@ async def create_passage( return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor) -@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=List[Passage], operation_id="modify_passage") +@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=list[Passage], operation_id="modify_passage") def modify_passage( agent_id: str, memory_id: str, passage: PassageUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Modify a memory in the agent's archival memory store. @@ -607,7 +608,7 @@ async def delete_passage( memory_id: str, # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a memory from an agent's archival memory store. 
@@ -619,7 +620,7 @@ async def delete_passage( AgentMessagesResponse = Annotated[ - List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}}) + list[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}}) ] @@ -627,14 +628,14 @@ AgentMessagesResponse = Annotated[ async def list_messages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."), - before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), + after: str | None = Query(None, description="Message after which to retrieve the returned messages."), + before: str | None = Query(None, description="Message before which to retrieve the returned messages."), limit: int = Query(10, description="Maximum number of messages to retrieve."), - group_id: Optional[str] = Query(None, description="Group ID to filter messages by."), + group_id: str | None = Query(None, description="Group ID to filter messages by."), use_assistant_message: bool = Query(True, description="Whether to use assistant messages"), assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."), assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve message history for an agent. 
@@ -662,7 +663,7 @@ def modify_message( message_id: str, request: LettaMessageUpdateUnion = Body(...), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the details of a message associated with an agent. @@ -672,6 +673,7 @@ def modify_message( return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor) +# noinspection PyInconsistentReturns @router.post( "/{agent_id}/messages", response_model=LettaResponse, @@ -682,7 +684,7 @@ async def send_message( request_obj: Request, # FastAPI Request server: SyncServer = Depends(get_letta_server), request: LettaRequest = Body(...), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Process a user message and return the agent's response. 
@@ -697,55 +699,95 @@ async def send_message( agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] - if agent_eligible and model_compatible: - if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: - agent_loop = SleeptimeMultiAgentV2( - agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - group_manager=server.group_manager, - job_manager=server.job_manager, - actor=actor, - group=agent.multi_agent_group, + # Create a new run for execution tracking + job_status = JobStatus.created + run = await server.job_manager.create_job_async( + pydantic_job=Run( + user_id=actor.id, + status=job_status, + metadata={ + "job_type": "send_message", + "agent_id": agent_id, + }, + request_config=LettaRequestConfig( + use_assistant_message=request.use_assistant_message, + assistant_message_tool_name=request.assistant_message_tool_name, + assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, + ), + ), + actor=actor, + ) + job_update_metadata = None + # TODO (cliandy): clean this up + redis_client = await get_redis_client() + await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id) + + try: + if agent_eligible and model_compatible: + if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: + agent_loop = SleeptimeMultiAgentV2( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + group_manager=server.group_manager, + job_manager=server.job_manager, + actor=actor, + 
group=agent.multi_agent_group, + current_run_id=run.id, + ) + else: + agent_loop = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + job_manager=server.job_manager, + passage_manager=server.passage_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + current_run_id=run.id, + ) + + result = await agent_loop.step( + request.messages, + max_steps=request.max_steps, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, ) else: - agent_loop = LettaAgent( + result = await server.send_message_to_agent( agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - job_manager=server.job_manager, - passage_manager=server.passage_manager, actor=actor, - step_manager=server.step_manager, - telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + input_messages=request.messages, + stream_steps=False, + stream_tokens=False, + # Support for AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_tool_name=request.assistant_message_tool_name, + assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, ) - - result = await agent_loop.step( - request.messages, - max_steps=request.max_steps, - use_assistant_message=request.use_assistant_message, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ) - else: - result = await server.send_message_to_agent( - agent_id=agent_id, + job_status = result.stop_reason.stop_reason.run_status + return 
result + except Exception as e: + job_update_metadata = {"error": str(e)} + job_status = JobStatus.failed + raise + finally: + await server.job_manager.safe_update_job_status_async( + job_id=run.id, + new_status=job_status, actor=actor, - input_messages=request.messages, - stream_steps=False, - stream_tokens=False, - # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - include_return_message_types=request.include_return_message_types, + metadata=job_update_metadata, ) - return result +# noinspection PyInconsistentReturns @router.post( "/{agent_id}/messages/stream", response_model=None, @@ -764,7 +806,7 @@ async def send_message_streaming( request_obj: Request, # FastAPI Request server: SyncServer = Depends(get_letta_server), request: LettaStreamingRequest = Body(...), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ) -> StreamingResponse | LettaResponse: """ Process a user message and return the agent's response. 
@@ -780,88 +822,160 @@ async def send_message_streaming( agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] - not_letta_endpoint = LETTA_MODEL_ENDPOINT != agent.llm_config.model_endpoint + not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT - if agent_eligible and model_compatible: - if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: - agent_loop = SleeptimeMultiAgentV2( - agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - group_manager=server.group_manager, - job_manager=server.job_manager, - actor=actor, - step_manager=server.step_manager, - telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), - group=agent.multi_agent_group, - ) - else: - agent_loop = LettaAgent( - agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - job_manager=server.job_manager, - passage_manager=server.passage_manager, - actor=actor, - step_manager=server.step_manager, - telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), - ) - from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode + # Create a new job for execution tracking + job_status = JobStatus.created + run = await server.job_manager.create_job_async( + pydantic_job=Run( + user_id=actor.id, + status=job_status, + metadata={ + "job_type": "send_message_streaming", + "agent_id": agent_id, + }, + request_config=LettaRequestConfig( + 
use_assistant_message=request.use_assistant_message, + assistant_message_tool_name=request.assistant_message_tool_name, + assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + include_return_message_types=request.include_return_message_types, + ), + ), + actor=actor, + ) - if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint: - result = StreamingResponseWithStatusCode( - agent_loop.step_stream( - input_messages=request.messages, - max_steps=request.max_steps, - use_assistant_message=request.use_assistant_message, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ), - media_type="text/event-stream", - ) + job_update_metadata = None + # TODO (cliandy): clean this up + redis_client = await get_redis_client() + await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id) + + try: + if agent_eligible and model_compatible: + if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: + agent_loop = SleeptimeMultiAgentV2( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + group_manager=server.group_manager, + job_manager=server.job_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + group=agent.multi_agent_group, + current_run_id=run.id, + ) + else: + agent_loop = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + job_manager=server.job_manager, + passage_manager=server.passage_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + current_run_id=run.id, + ) + from 
letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode + + if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint: + result = StreamingResponseWithStatusCode( + agent_loop.step_stream( + input_messages=request.messages, + max_steps=request.max_steps, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, + ), + media_type="text/event-stream", + ) + else: + result = StreamingResponseWithStatusCode( + agent_loop.step_stream_no_tokens( + request.messages, + max_steps=request.max_steps, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, + ), + media_type="text/event-stream", + ) else: - result = StreamingResponseWithStatusCode( - agent_loop.step_stream_no_tokens( - request.messages, - max_steps=request.max_steps, - use_assistant_message=request.use_assistant_message, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ), - media_type="text/event-stream", + result = await server.send_message_to_agent( + agent_id=agent_id, + actor=actor, + input_messages=request.messages, + stream_steps=True, + stream_tokens=request.stream_tokens, + # Support for AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_tool_name=request.assistant_message_tool_name, + assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, ) - else: - result = await server.send_message_to_agent( - agent_id=agent_id, + job_status = JobStatus.running + return result + except Exception as e: + job_update_metadata = {"error": str(e)} + job_status = 
JobStatus.failed + raise + finally: + await server.job_manager.safe_update_job_status_async( + job_id=run.id, + new_status=job_status, actor=actor, - input_messages=request.messages, - stream_steps=True, - stream_tokens=request.stream_tokens, - # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, + metadata=job_update_metadata, ) - return result + +@router.post("/{agent_id}/messages/cancel", operation_id="cancel_agent_run") +async def cancel_agent_run( + agent_id: str, + run_ids: list[str] | None = None, + server: SyncServer = Depends(get_letta_server), + actor_id: str | None = Header(None, alias="user_id"), +) -> dict: + """ + Cancel runs associated with an agent. If run_ids are passed in, cancel those in particular. + + Note to cancel active runs associated with an agent, redis is required. 
+ """ + + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + if not run_ids: + redis_client = await get_redis_client() + run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}") + if run_id is None: + logger.warning("Cannot find run associated with agent to cancel.") + return {} + run_ids = [run_id] + + results = {} + for run_id in run_ids: + success = await server.job_manager.safe_update_job_status_async( + job_id=run_id, + new_status=JobStatus.cancelled, + actor=actor, + ) + results[run_id] = "cancelled" if success else "failed" + return results -async def process_message_background( - job_id: str, +async def _process_message_background( + run_id: str, server: SyncServer, actor: User, agent_id: str, - messages: List[MessageCreate], + messages: list[MessageCreate], use_assistant_message: bool, assistant_message_tool_name: str, assistant_message_tool_kwarg: str, max_steps: int = DEFAULT_MAX_STEPS, - include_return_message_types: Optional[List[MessageType]] = None, + include_return_message_types: list[MessageType] | None = None, ) -> None: """Background task to process the message and update job status.""" request_start_timestamp_ns = get_utc_timestamp_ns() @@ -905,7 +1019,7 @@ async def process_message_background( result = await agent_loop.step( messages, max_steps=max_steps, - run_id=job_id, + run_id=run_id, use_assistant_message=use_assistant_message, request_start_timestamp_ns=request_start_timestamp_ns, include_return_message_types=include_return_message_types, @@ -917,7 +1031,7 @@ async def process_message_background( input_messages=messages, stream_steps=False, stream_tokens=False, - metadata={"job_id": job_id}, + metadata={"job_id": run_id}, # Support for AssistantMessage use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, @@ -930,7 +1044,7 @@ async def process_message_background( completed_at=datetime.now(timezone.utc), metadata={"result": 
result.model_dump(mode="json")}, ) - await server.job_manager.update_job_by_id_async(job_id=job_id, job_update=job_update, actor=actor) + await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor) except Exception as e: # Update job status to failed @@ -951,11 +1065,14 @@ async def send_message_async( agent_id: str, server: SyncServer = Depends(get_letta_server), request: LettaAsyncRequest = Body(...), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Asynchronously process a user message and return a run object. The actual processing happens in the background, and the status can be checked using the run ID. + + This is "asynchronous" in the sense that it's a background job and explicitly must be fetched by the run ID. + This is more like `send_message_job` """ MetricRegistry().user_message_counter.add(1, get_ctx_attributes()) actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) @@ -980,8 +1097,8 @@ async def send_message_async( # Create asyncio task for background processing asyncio.create_task( - process_message_background( - job_id=run.id, + _process_message_background( + run_id=run.id, server=server, actor=actor, agent_id=agent_id, @@ -1002,7 +1119,7 @@ async def reset_messages( agent_id: str, add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Resets the messages for an agent""" actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) @@ -1011,12 +1128,12 @@ async def reset_messages( ) -@router.get("/{agent_id}/groups", 
response_model=List[Group], operation_id="list_agent_groups") +@router.get("/{agent_id}/groups", response_model=list[Group], operation_id="list_agent_groups") async def list_agent_groups( agent_id: str, - manager_type: Optional[str] = Query(None, description="Manager type to filter groups by"), + manager_type: str | None = Query(None, description="Manager type to filter groups by"), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Lists the groups for an agent""" actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) @@ -1030,7 +1147,7 @@ async def summarize_agent_conversation( request_obj: Request, # FastAPI Request max_message_length: int = Query(..., description="Maximum number of messages to retain after summarization."), server: SyncServer = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), + actor_id: str | None = Header(None, alias="user_id"), ): """ Summarize an agent's conversation history to a target message length. 
diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 4c5595fa..90c108d9 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -15,10 +15,15 @@ router = APIRouter(prefix="/jobs", tags=["jobs"]) async def list_jobs( server: "SyncServer" = Depends(get_letta_server), source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), + before: Optional[str] = Query(None, description="Cursor for pagination"), + after: Optional[str] = Query(None, description="Cursor for pagination"), + limit: Optional[int] = Query(50, description="Limit for pagination"), + ascending: bool = Query(True, description="Whether to sort jobs oldest to newest (True, default) or newest to oldest (False)"), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all jobs. + TODO (cliandy): implementation for pagination """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) @@ -26,6 +31,10 @@ async def list_jobs( return await server.job_manager.list_jobs_async( actor=actor, source_id=source_id, + before=before, + after=after, + limit=limit, + ascending=ascending, ) @@ -34,12 +43,24 @@ async def list_active_jobs( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), + before: Optional[str] = Query(None, description="Cursor for pagination"), + after: Optional[str] = Query(None, description="Cursor for pagination"), + limit: Optional[int] = Query(50, description="Limit for pagination"), + ascending: bool = Query(True, description="Whether to sort jobs oldest to newest (True, default) or newest to oldest (False)"), ): """ List all active jobs. 
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running], source_id=source_id) + return await server.job_manager.list_jobs_async( + actor=actor, + statuses=[JobStatus.created, JobStatus.running], + source_id=source_id, + before=before, + after=after, + limit=limit, + ascending=ascending, + ) @router.get("/{job_id}", response_model=Job, operation_id="retrieve_job") @@ -59,6 +80,33 @@ async def retrieve_job( raise HTTPException(status_code=404, detail="Job not found") +@router.patch("/{job_id}/cancel", response_model=Job, operation_id="cancel_job") +async def cancel_job( + job_id: str, + actor_id: Optional[str] = Header(None, alias="user_id"), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Cancel a job by its job_id. + + This endpoint marks a job as cancelled, which will cause any associated + agent execution to terminate as soon as possible. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + try: + # First check if the job exists and is in a cancellable state + existing_job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor) + + if existing_job.status.is_terminal: + return False + + return await server.job_manager.safe_update_job_status_async(job_id=job_id, new_status=JobStatus.cancelled, actor=actor) + + except NoResultFound: + raise HTTPException(status_code=404, detail="Job not found") + + @router.delete("/{job_id}", response_model=Job, operation_id="delete_job") async def delete_job( job_id: str, diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 13d57e87..ac22c469 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -2,6 +2,7 @@ # stremaing HTTP trailers, as we cannot set codes after the initial response. 
# TODO (cliandy) wrap this and handle types
async def cancellation_aware_stream_wrapper(
    stream_generator: AsyncIterator[str | bytes],
    job_manager: JobManager,
    job_id: str,
    actor: User,
    cancellation_check_interval: float = 0.5,
) -> AsyncIterator[str | bytes]:
    """
    Wraps a stream generator to provide real-time job cancellation checking.

    This wrapper periodically checks for job cancellation while streaming and
    can interrupt the stream at any point, not just at step boundaries.

    Args:
        stream_generator: The original stream generator to wrap
        job_manager: Job manager instance for checking job status
        job_id: ID of the job to monitor for cancellation
        actor: User/actor making the request
        cancellation_check_interval: How often to check for cancellation (seconds)

    Yields:
        Stream chunks from the original generator until cancelled

    Raises:
        asyncio.CancelledError: If the job is cancelled during streaming
    """
    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; get_event_loop() is deprecated in this context.
    last_cancellation_check = asyncio.get_running_loop().time()

    try:
        async for chunk in stream_generator:
            # Check for cancellation periodically (not on every chunk for performance)
            current_time = asyncio.get_running_loop().time()
            if current_time - last_cancellation_check >= cancellation_check_interval:
                last_cancellation_check = current_time

                # Guard ONLY the status lookup: a failing check must not kill the
                # stream, but a detected cancellation must not be swallowed either.
                # (The original broad except also wrapped the cancellation-event
                # yield, so errors while emitting it were logged as a failed check
                # and the cancelled stream kept flowing.)
                job_status = None
                try:
                    job = await job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
                    job_status = job.status
                except Exception as e:
                    # Log warning but don't fail the stream if cancellation check fails
                    logger.warning(f"Failed to check job cancellation for job {job_id}: {e}")

                if job_status == JobStatus.cancelled:
                    logger.info(f"Stream cancelled for job {job_id}, interrupting stream")
                    # Send cancellation event to client before interrupting the stream
                    cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
                    yield f"data: {json.dumps(cancellation_event)}\n\n"
                    # Raise CancelledError to interrupt the stream
                    raise asyncio.CancelledError(f"Job {job_id} was cancelled")

            yield chunk

    except asyncio.CancelledError:
        # Re-raise CancelledError to ensure proper cleanup
        logger.info(f"Stream for job {job_id} was cancelled and cleaned up")
        raise
    except Exception as e:
        logger.error(f"Error in cancellation-aware stream wrapper for job {job_id}: {e}")
        raise
from typing import List, Literal, Optional, Union @@ -125,6 +125,46 @@ class JobManager: return job.to_pydantic() + @enforce_types + @trace_method + async def safe_update_job_status_async( + self, job_id: str, new_status: JobStatus, actor: PydanticUser, metadata: Optional[dict] = None + ) -> bool: + """ + Safely update job status with state transition guards. + Created -> Pending -> Running --> + + Returns: + True if update was successful, False if update was skipped due to invalid transition + """ + try: + # Get current job state + current_job = await self.get_job_by_id_async(job_id=job_id, actor=actor) + + current_status = current_job.status + if not any( + ( + new_status.is_terminal and not current_status.is_terminal, + current_status == JobStatus.created and new_status != JobStatus.created, + current_status == JobStatus.pending and new_status == JobStatus.running, + ) + ): + logger.warning(f"Invalid job status transition from {current_job.status} to {new_status} for job {job_id}") + return False + + job_update_builder = partial(JobUpdate, status=new_status) + if metadata: + job_update_builder = partial(job_update_builder, metadata=metadata) + if new_status.is_terminal: + job_update_builder = partial(job_update_builder, completed_at=get_utc_time()) + + await self.update_job_by_id_async(job_id=job_id, job_update=job_update_builder(), actor=actor) + return True + + except Exception as e: + logger.error(f"Failed to safely update job status for job {job_id}: {e}") + return False + @enforce_types @trace_method def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: @@ -656,7 +696,7 @@ class JobManager: job.callback_status_code = resp.status_code except Exception as e: - error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}" + error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}" logger.error(error_message) # Record the failed attempt job.callback_sent_at = 
get_utc_time().replace(tzinfo=None) @@ -686,7 +726,7 @@ class JobManager: job.callback_sent_at = get_utc_time().replace(tzinfo=None) job.callback_status_code = resp.status_code except Exception as e: - error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}" + error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}" logger.error(error_message) # Record the failed attempt job.callback_sent_at = get_utc_time().replace(tzinfo=None) diff --git a/letta/utils.py b/letta/utils.py index c054c977..4ddf0a61 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -12,11 +12,12 @@ import re import subprocess import sys import uuid +from collections.abc import Coroutine from contextlib import contextmanager from datetime import datetime, timezone from functools import wraps from logging import Logger -from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints +from typing import Any, Coroutine, Union, _GenericAlias, get_args, get_origin, get_type_hints from urllib.parse import urljoin, urlparse import demjson3 as demjson @@ -519,7 +520,7 @@ def enforce_types(func): arg_names = inspect.getfullargspec(func).args # Pair each argument with its corresponding type hint - args_with_hints = dict(zip(arg_names[1:], args[1:])) # Skipping 'self' + args_with_hints = dict(zip(arg_names[1:], args[1:], strict=False)) # Skipping 'self' # Function to check if a value matches a given type hint def matches_type(value, hint): @@ -557,7 +558,7 @@ def enforce_types(func): return wrapper -def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False): +def annotate_message_json_list_with_tool_calls(messages: list[dict], allow_tool_roles: bool = False): """Add in missing tool_call_id fields to a list of messages using function call style Walk through the list forwards: @@ -946,7 +947,7 @@ def get_human_text(name: str, enforce_limit=True): for file_path 
class CancellationSignal:
    """
    A signal that can be checked for cancellation during streaming operations.

    This provides a lightweight way to check if an operation should be cancelled
    without having to pass job managers and other dependencies through every method.
    """

    def __init__(self, job_manager=None, job_id=None, actor=None):
        # Imported lazily; presumably to avoid import cycles at module load — TODO confirm.
        from letta.log import get_logger
        from letta.schemas.user import User
        from letta.services.job_manager import JobManager

        self.job_manager: JobManager | None = job_manager
        self.job_id: str | None = job_id
        self.actor: User | None = actor
        # Sticky local flag: set by cancel() or once a cancelled job status is observed.
        self._is_cancelled = False
        self.logger = get_logger(__name__)

    async def is_cancelled(self) -> bool:
        """
        Check if the operation has been cancelled.

        Returns:
            True if cancelled, False otherwise
        """
        if self._is_cancelled:
            return True

        if not self.job_manager or not self.job_id or not self.actor:
            return False

        # Deferred past the early returns: the original imported JobStatus
        # unconditionally on every call, making even the fast local-flag paths
        # pay the import-machinery cost (and depend on the letta package).
        from letta.schemas.enums import JobStatus

        try:
            job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
            self._is_cancelled = job.status == JobStatus.cancelled
            return self._is_cancelled
        except Exception as e:
            self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
            return False

    def cancel(self):
        """Mark this signal as cancelled locally (for testing or direct cancellation)."""
        self._is_cancelled = True

    async def check_and_raise_if_cancelled(self):
        """
        Check for cancellation and raise CancelledError if cancelled.

        Raises:
            asyncio.CancelledError: If the operation has been cancelled
        """
        if await self.is_cancelled():
            self.logger.info(f"Operation cancelled for job {self.job_id}")
            raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")


class NullCancellationSignal(CancellationSignal):
    """A null cancellation signal that is never cancelled."""

    def __init__(self):
        super().__init__()

    async def is_cancelled(self) -> bool:
        return False

    async def check_and_raise_if_cancelled(self):
        pass
status""" + start = time.time() + while True: + run = client.runs.retrieve(run_id) + if run.status == target_status: + return run + if time.time() - start > timeout: + raise TimeoutError(f"Run {run_id} did not reach status '{target_status}' within {timeout} seconds (last status: {run.status})") + time.sleep(interval) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_job_creation_for_send_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Test that send_message endpoint creates a job and the job completes successfully. + """ + previous_runs = client.runs.list(agent_ids=[agent_state.id]) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + # Send a simple message and verify a job was created + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + + # The response should be successful + assert response.messages is not None + assert len(response.messages) > 0 + + runs = client.runs.list(agent_ids=[agent_state.id]) + new_runs = set(r.id for r in runs) - set(r.id for r in previous_runs) + assert len(new_runs) == 1 + + for run in runs: + if run.id == list(new_runs)[0]: + assert run.status == "completed" + + +# TODO (cliandy): MERGE BACK IN POST +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# def test_async_job_cancellation( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# llm_config: LLMConfig, +# ) -> None: +# """ +# Test that an async job can be cancelled and the cancellation is reflected in the job status. 
+# """ +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# # client.runs.cancel +# # Start an async job +# run = client.agents.messages.create_async( +# agent_id=agent_state.id, +# messages=USER_MESSAGE_FORCE_REPLY, +# ) +# +# # Verify the job was created +# assert run.id is not None +# assert run.status in ["created", "running"] +# +# # Cancel the job quickly (before it potentially completes) +# cancelled_run = client.jobs.cancel(run.id) +# +# # Verify the job was cancelled +# assert cancelled_run.status == "cancelled" +# +# # Wait a bit and verify it stays cancelled (no invalid state transitions) +# time.sleep(1) +# final_run = client.runs.retrieve(run.id) +# assert final_run.status == "cancelled" +# +# # Verify the job metadata indicates cancellation +# if final_run.metadata: +# assert final_run.metadata.get("cancelled") is True or "stop_reason" in final_run.metadata +# +# +# def test_job_cancellation_endpoint_validation( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# ) -> None: +# """ +# Test job cancellation endpoint validation (trying to cancel completed/failed jobs). +# """ +# # Test cancelling a non-existent job +# with pytest.raises(ApiError) as exc_info: +# client.jobs.cancel("non-existent-job-id") +# assert exc_info.value.status_code == 404 +# +# +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# def test_completed_job_cannot_be_cancelled( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# llm_config: LLMConfig, +# ) -> None: +# """ +# Test that completed jobs cannot be cancelled. 
+# """ +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# # Start an async job and wait for it to complete +# run = client.agents.messages.create_async( +# agent_id=agent_state.id, +# messages=USER_MESSAGE_FORCE_REPLY, +# ) +# +# # Wait for completion +# completed_run = wait_for_run_completion(client, run.id) +# assert completed_run.status == "completed" +# +# # Try to cancel the completed job - should fail +# with pytest.raises(ApiError) as exc_info: +# client.jobs.cancel(run.id) +# assert exc_info.value.status_code == 400 +# assert "Cannot cancel job with status 'completed'" in str(exc_info.value) +# +# +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# def test_streaming_job_independence_from_client_disconnect( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# llm_config: LLMConfig, +# ) -> None: +# """ +# Test that streaming jobs are independent of client connection state. +# This verifies that jobs continue even if the client "disconnects" (simulated by not consuming the stream). 
+# """ +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# # Create a streaming request +# import threading +# +# import httpx +# +# # Get the base URL and create a raw HTTP request to simulate partial consumption +# base_url = client._client_wrapper._base_url +# +# def start_stream_and_abandon(): +# """Start a streaming request but abandon it (simulating client disconnect)""" +# try: +# response = httpx.post( +# f"{base_url}/agents/{agent_state.id}/messages/stream", +# json={"messages": [{"role": "user", "text": "Hello, how are you?"}], "stream_tokens": False}, +# headers={"user_id": "test-user"}, +# timeout=30.0, +# ) +# +# # Read just a few chunks then "disconnect" by not reading the rest +# chunk_count = 0 +# for chunk in response.iter_lines(): +# chunk_count += 1 +# if chunk_count > 3: # Read a few chunks then stop +# break +# # Connection is now "abandoned" but the job should continue +# +# except Exception: +# pass # Ignore connection errors +# +# # Start the stream in a separate thread to simulate abandonment +# thread = threading.Thread(target=start_stream_and_abandon) +# thread.start() +# thread.join(timeout=5.0) # Wait up to 5 seconds for the "disconnect" +# +# # The important thing is that this test validates our architecture: +# # 1. Jobs are created before streaming starts (verified by our other tests) +# # 2. Jobs track execution independent of client connection (handled by our wrapper) +# # 3. 
Only explicit cancellation terminates jobs (tested by other tests) +# +# # This test primarily validates that the implementation doesn't break under simulated disconnection +# assert True # If we get here without errors, the architecture is sound diff --git a/tests/test_managers.py b/tests/test_managers.py index 0503fa94..630f8c8b 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4551,7 +4551,7 @@ async def test_create_and_upsert_identity(server: SyncServer, default_user, even actor=default_user, ) - identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))] + identity_create.properties = [IdentityProperty(key="age", value=29, type=IdentityPropertyType.number)] identity = await server.identity_manager.upsert_identity_async( identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user diff --git a/tests/test_provider_trace.py b/tests/test_provider_trace.py index 574c4a1c..1580ce8c 100644 --- a/tests/test_provider_trace.py +++ b/tests/test_provider_trace.py @@ -17,6 +17,7 @@ from letta.schemas.message import MessageCreate from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.step_manager import StepManager