From e06d9cbb8fac238d90ca4d8a6968765dfeb2fb82 Mon Sep 17 00:00:00 2001
From: cthomas
Date: Mon, 8 Sep 2025 16:35:24 -0700
Subject: [PATCH] feat: add error handling parity to new agent loop (#4471)

---
 letta/agents/letta_agent_v2.py | 431 ++++++++++++++++++++++-----------
 1 file changed, 283 insertions(+), 148 deletions(-)

diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py
index 35617952..39385ab7 100644
--- a/letta/agents/letta_agent_v2.py
+++ b/letta/agents/letta_agent_v2.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import uuid
+from datetime import datetime
 from typing import AsyncGenerator, Tuple
 
 from opentelemetry.trace import Span
@@ -26,7 +27,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
 from letta.log import get_logger
 from letta.otel.tracing import log_event, trace_method, tracer
 from letta.prompts.prompt_generator import PromptGenerator
-from letta.schemas.agent import AgentState
+from letta.schemas.agent import AgentState, UpdateAgent
 from letta.schemas.enums import JobStatus, MessageRole, MessageStreamStatus, StepStatus
 from letta.schemas.letta_message import LettaMessage, MessageType
 from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
@@ -51,7 +52,7 @@ from letta.services.step_manager import StepManager
 from letta.services.summarizer.summarizer import Summarizer
 from letta.services.telemetry_manager import TelemetryManager
 from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
-from letta.settings import summarizer_settings
+from letta.settings import settings, summarizer_settings
 from letta.system import package_function_response
 from letta.types import JsonDict
 from letta.utils import log_telemetry, united_diff, validate_function_response
@@ -155,6 +156,8 @@ class LettaAgentV2(BaseAgentV2):
             LettaResponse: Complete response with all messages and metadata
         """
         self._initialize_state()
+        request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
+
         in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
             input_messages, self.agent_state, self.message_manager, self.actor
         )
@@ -190,6 +193,7 @@ class LettaAgentV2(BaseAgentV2):
 
         if self.stop_reason is None:
             self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
+        self._request_checkpoint_finish(request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns)
         return LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
 
     async def stream(
@@ -222,6 +226,10 @@ class LettaAgentV2(BaseAgentV2):
         Yields:
             str: JSON-formatted SSE data chunks for each completed step
         """
+        self._initialize_state()
+        request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
+        first_chunk = True
+
         if stream_tokens:
             llm_adapter = LettaLLMStreamAdapter(
                 llm_client=self.llm_client,
@@ -234,7 +242,6 @@ class LettaAgentV2(BaseAgentV2):
             )
 
         try:
-            self._initialize_state()
            in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
                 input_messages, self.agent_state, self.message_manager, self.actor
             )
@@ -250,7 +257,10 @@ class LettaAgentV2(BaseAgentV2):
                     request_start_timestamp_ns=request_start_timestamp_ns,
                 )
                 async for chunk in response:
+                    if first_chunk:
+                        request_span = self._request_checkpoint_ttft(request_span, request_start_timestamp_ns)
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                    first_chunk = False
 
                     if not self.should_continue:
                         break
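
Aside: the `_request_checkpoint_*` helpers this patch introduces (defined near the end of the diff) wrap one whole request in a single OpenTelemetry span: open it at request start, stamp a time-to-first-token event on the first streamed chunk, stamp the total duration and close it on the way out. A minimal self-contained sketch of that lifecycle, not part of the patch, assuming only the stock `opentelemetry-api` and that `ns_to_ms` is a plain nanoseconds-to-milliseconds division:

import time

from opentelemetry import trace

tracer = trace.get_tracer(__name__)


def stream_with_request_span(chunks):
    # Open the span before any model work begins (patch: _request_checkpoint_start).
    t0 = time.time_ns()
    span = tracer.start_span("time_to_first_token", start_time=t0)
    first = True
    for chunk in chunks:
        if first:
            # TTFT is stamped exactly once, on the first emitted chunk
            # (patch: _request_checkpoint_ttft).
            span.add_event("time_to_first_token_ms", {"ttft_ms": (time.time_ns() - t0) / 1e6})
            first = False
        yield chunk
    # Total wall-clock duration, then close (patch: _request_checkpoint_finish).
    span.add_event("letta_request_ms", {"duration_ms": (time.time_ns() - t0) / 1e6})
    span.end()


for _ in stream_with_request_span(["a", "b"]):
    pass
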
@@ -270,6 +280,7 @@ class LettaAgentV2(BaseAgentV2):
                 yield f"data: {self.stop_reason.model_dump_json()}\n\n"
                 raise
 
+        self._request_checkpoint_finish(request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns)
         for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
             yield f"data: {finish_chunk}\n\n"
 
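
Aside: every `yield` in `stream()` is one Server-Sent Events frame of the form `data: <json>\n\n`, including the terminal stop-reason and finish chunks. A minimal client-side parse of such a stream, for illustration only (the real Letta clients do more, and the final status marker may be a bare string rather than JSON):

import json


def parse_sse_frames(raw: str):
    # Frames are separated by a blank line; each carries one payload.
    for frame in raw.split("\n\n"):
        if not frame.startswith("data: "):
            continue
        payload = frame[len("data: "):]
        try:
            yield json.loads(payload)
        except json.JSONDecodeError:
            yield payload  # e.g. a bare terminal marker such as "[DONE]"
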
@@ -307,141 +318,226 @@ class LettaAgentV2(BaseAgentV2):
             LettaMessage or dict: Chunks for streaming mode, or request data for dry_run
         """
         step_progression = StepProgression.START
-        tool_call, reasoning_content, agent_step_span, first_chunk, logged_step, step_start_ns = None, None, None, None, None, None
-        valid_tools = await self._get_valid_tools(messages)  # remove messages input
-        approval_request, approval_response = await self._maybe_get_approval_messages(messages)
-        if approval_request and approval_response:
-            tool_call = approval_request.tool_calls[0]
-            reasoning_content = approval_request.content
-            step_id = approval_request.step_id
-            step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
-        else:
-            # Check for job cancellation at the start of each step
-            if run_id and await self._check_run_cancellation(run_id):
-                self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
-                self.logger.info(f"Agent execution cancelled for run {run_id}")
-                return
-
-            step_id = generate_step_id()
-            step_progression, step_metrics, agent_step_span = self._step_checkpoint_start(step_id=step_id)
-
-            # Create step early with PENDING status
-            logged_step = await self.step_manager.log_step_async(
-                actor=self.actor,
-                agent_id=self.agent_state.id,
-                provider_name=self.agent_state.llm_config.model_endpoint_type,
-                provider_category=self.agent_state.llm_config.provider_category or "base",
-                model=self.agent_state.llm_config.model,
-                model_endpoint=self.agent_state.llm_config.model_endpoint,
-                context_window_limit=self.agent_state.llm_config.context_window,
-                usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
-                provider_id=None,
-                job_id=run_id,
-                step_id=step_id,
-                project_id=self.agent_state.project_id,
-                status=StepStatus.PENDING,
-            )
-
-        messages = await self._refresh_messages(messages)
-        force_tool_call = valid_tools[0]["name"] if len(valid_tools) == 1 else None
-        request_data = self.llm_client.build_request_data(
-            messages=messages,
-            llm_config=self.agent_state.llm_config,
-            tools=valid_tools,
-            force_tool_call=force_tool_call,
-        )
-        if dry_run:
-            yield request_data
-            return
-
-        step_progression, step_metrics = self._step_checkpoint_llm_request_start(step_metrics, agent_step_span)
-
-        try:
-            invocation = llm_adapter.invoke_llm(
-                request_data=request_data,
-                messages=messages,
-                tools=valid_tools,
-                use_assistant_message=use_assistant_message,
-                step_id=step_id,
-                actor=self.actor,
-            )
-            async for chunk in invocation:
-                if llm_adapter.supports_token_streaming():
-                    if include_return_message_types is None or chunk.message_type in include_return_message_types:
-                        first_chunk = True
-                        yield chunk
-        except ValueError:
-            self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
-            raise
-
-        step_progression, step_metrics = self._step_checkpoint_llm_request_finish(
-            step_metrics, agent_step_span, llm_adapter.llm_request_finish_timestamp_ns
-        )
-
-        self._update_global_usage_stats(llm_adapter.usage)
-
-        # Handle the AI response with the extracted data
-        if tool_call is None and llm_adapter.tool_call is None:
-            self.stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
-            raise ValueError("No tool calls found in response, model must make a tool call")
-
-        persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
-            tool_call or llm_adapter.tool_call,
-            [tool["name"] for tool in valid_tools],
-            self.agent_state,
-            self.tool_rules_solver,
-            UsageStatistics(
-                completion_tokens=self.usage.completion_tokens,
-                prompt_tokens=self.usage.prompt_tokens,
-                total_tokens=self.usage.total_tokens,
-            ),
-            reasoning_content=reasoning_content or llm_adapter.reasoning_content,
-            pre_computed_assistant_message_id=llm_adapter.message_id,
-            step_id=step_id,
-            initial_messages=input_messages_to_persist,
-            agent_step_span=agent_step_span,
-            is_final_step=(remaining_turns == 0),
-            run_id=run_id,
-            step_metrics=step_metrics,
-            is_approval=approval_response.approve if approval_response is not None else False,
-            is_denial=(approval_response.approve == False) if approval_response is not None else False,
-            denial_reason=approval_response.denial_reason if approval_response is not None else None,
-        )
+        # TODO(@caren): clean this up
+        tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+        try:
+            valid_tools = await self._get_valid_tools(messages)  # remove messages input
+            approval_request, approval_response = await self._maybe_get_approval_messages(messages)
+            if approval_request and approval_response:
+                tool_call = approval_request.tool_calls[0]
+                reasoning_content = approval_request.content
+                step_id = approval_request.step_id
+                step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
+            else:
+                # Check for job cancellation at the start of each step
+                if run_id and await self._check_run_cancellation(run_id):
+                    self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
+                    self.logger.info(f"Agent execution cancelled for run {run_id}")
+                    return
+
+                step_id = generate_step_id()
+                step_progression, step_metrics, agent_step_span = self._step_checkpoint_start(step_id=step_id)
+
+                # Create step early with PENDING status
+                logged_step = await self.step_manager.log_step_async(
+                    actor=self.actor,
+                    agent_id=self.agent_state.id,
+                    provider_name=self.agent_state.llm_config.model_endpoint_type,
+                    provider_category=self.agent_state.llm_config.provider_category or "base",
+                    model=self.agent_state.llm_config.model,
+                    model_endpoint=self.agent_state.llm_config.model_endpoint,
+                    context_window_limit=self.agent_state.llm_config.context_window,
+                    usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
+                    provider_id=None,
+                    job_id=run_id,
+                    step_id=step_id,
+                    project_id=self.agent_state.project_id,
+                    status=StepStatus.PENDING,
+                )
+
+            messages = await self._refresh_messages(messages)
+            force_tool_call = valid_tools[0]["name"] if len(valid_tools) == 1 else None
+            request_data = self.llm_client.build_request_data(
+                messages=messages,
+                llm_config=self.agent_state.llm_config,
+                tools=valid_tools,
+                force_tool_call=force_tool_call,
+            )
+            if dry_run:
+                yield request_data
+                return
+
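
Aside: `force_tool_call` above is set only when exactly one tool survives the tool rules, which pins the model to that tool. A sketch of what that typically maps to in an OpenAI-style request body; this is illustrative only, since the real mapping lives in `llm_client.build_request_data` and varies by provider:

def tool_choice_for(valid_tools: list[dict]) -> dict | str:
    # One valid tool left: force it. Otherwise let the model choose.
    if len(valid_tools) == 1:
        return {"type": "function", "function": {"name": valid_tools[0]["name"]}}
    return "auto"
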
+            step_progression, step_metrics = self._step_checkpoint_llm_request_start(step_metrics, agent_step_span)
+
+            try:
+                invocation = llm_adapter.invoke_llm(
+                    request_data=request_data,
+                    messages=messages,
+                    tools=valid_tools,
+                    use_assistant_message=use_assistant_message,
+                    step_id=step_id,
+                    actor=self.actor,
+                )
+                async for chunk in invocation:
+                    if llm_adapter.supports_token_streaming():
+                        if include_return_message_types is None or chunk.message_type in include_return_message_types:
+                            first_chunk = True
+                            yield chunk
+            except ValueError:
+                self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
+                raise
+
+            step_progression, step_metrics = self._step_checkpoint_llm_request_finish(
+                step_metrics, agent_step_span, llm_adapter.llm_request_finish_timestamp_ns
+            )
+
+            self._update_global_usage_stats(llm_adapter.usage)
+
+            # Handle the AI response with the extracted data
+            if tool_call is None and llm_adapter.tool_call is None:
+                self.stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
+                raise ValueError("No tool calls found in response, model must make a tool call")
+
+            persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
+                tool_call or llm_adapter.tool_call,
+                [tool["name"] for tool in valid_tools],
+                self.agent_state,
+                self.tool_rules_solver,
+                UsageStatistics(
+                    completion_tokens=self.usage.completion_tokens,
+                    prompt_tokens=self.usage.prompt_tokens,
+                    total_tokens=self.usage.total_tokens,
+                ),
+                reasoning_content=reasoning_content or llm_adapter.reasoning_content,
+                pre_computed_assistant_message_id=llm_adapter.message_id,
+                step_id=step_id,
+                initial_messages=input_messages_to_persist,
+                agent_step_span=agent_step_span,
+                is_final_step=(remaining_turns == 0),
+                run_id=run_id,
+                step_metrics=step_metrics,
+                is_approval=approval_response.approve if approval_response is not None else False,
+                is_denial=(approval_response.approve == False) if approval_response is not None else False,
+                denial_reason=approval_response.denial_reason if approval_response is not None else None,
+            )
+
+            # Update step with actual usage now that we have it (if step was created)
+            if logged_step:
+                await self.step_manager.update_step_success_async(
+                    self.actor,
+                    step_id,
+                    UsageStatistics(
+                        completion_tokens=self.usage.completion_tokens,
+                        prompt_tokens=self.usage.prompt_tokens,
+                        total_tokens=self.usage.total_tokens,
+                    ),
+                    self.stop_reason,
+                )
+                step_progression = StepProgression.STEP_LOGGED
+
+            new_message_idx = len(input_messages_to_persist) if input_messages_to_persist else 0
+            self.response_messages.extend(persisted_messages[new_message_idx:])
+
+            if llm_adapter.supports_token_streaming():
+                tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
+                if not (use_assistant_message and tool_return.name == "send_message"):
+                    if include_return_message_types is None or tool_return.message_type in include_return_message_types:
+                        yield tool_return
+            else:
+                filter_user_messages = [m for m in persisted_messages[new_message_idx:] if m.role != "user"]
+                letta_messages = Message.to_letta_messages_from_list(
+                    filter_user_messages,
+                    use_assistant_message=use_assistant_message,
+                    reverse=False,
+                )
+                for message in letta_messages:
+                    if include_return_message_types is None or message.message_type in include_return_message_types:
+                        yield message
+
+            step_progression, step_metrics = self._step_checkpoint_finish(step_metrics, agent_step_span, run_id)
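
Aside: the `_step_checkpoint_*` calls bracketing the block above accumulate per-phase wall-clock timings on a `StepMetrics` row. A condensed sketch of that bookkeeping; the field names `step_start_ns` and `step_ns` appear in this patch, while `llm_request_ns` is an assumption here:

import time
from dataclasses import dataclass


@dataclass
class StepTimings:  # stand-in for StepMetrics
    step_start_ns: int = 0
    llm_request_ns: int = 0
    step_ns: int = 0


def timed_step(call_llm):
    m = StepTimings(step_start_ns=time.time_ns())   # _step_checkpoint_start
    t = time.time_ns()                              # _step_checkpoint_llm_request_start
    call_llm()
    m.llm_request_ns = time.time_ns() - t           # _step_checkpoint_llm_request_finish
    m.step_ns = time.time_ns() - m.step_start_ns    # _step_checkpoint_finish (or the failure path)
    return m
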
+        except Exception as e:
+            self.logger.error(f"Error during step processing: {e}")
+            self.job_update_metadata = {"error": str(e)}
+
+            # This indicates we failed after we decided to stop stepping, which points to a bug in our flow.
+            if not self.stop_reason:
+                self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
+            elif self.stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
+                self.logger.error("Error occurred during step processing, with valid stop reason: %s", self.stop_reason.stop_reason)
+            elif self.stop_reason.stop_reason not in (
+                StopReasonType.no_tool_call,
+                StopReasonType.invalid_tool_call,
+                StopReasonType.invalid_llm_response,
+            ):
+                self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", self.stop_reason.stop_reason)
+            raise e
+        finally:
+            self.logger.debug("Running cleanup for agent loop run: %s", run_id)
+            self.logger.info("Running final update. Step Progression: %s", step_progression)
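
Aside: the cleanup below leans on `StepProgression` being an ordered enum, so "how far did we get before failing" is a plain integer comparison. A sketch consistent with the checks in this file; the member names are the ones used here, but the authoritative ordering lives in letta's schemas:

from enum import IntEnum


class StepProgression(IntEnum):  # assumed ordering
    START = 0
    STREAM_RECEIVED = 1
    LOGGED_TRACE = 2
    STEP_LOGGED = 3
    FINISHED = 4


# e.g. "failed before the step row was finalized":
assert StepProgression.STREAM_RECEIVED < StepProgression.STEP_LOGGED
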
Step Progression: %s", step_progression) + try: + if step_progression == StepProgression.FINISHED and not self.should_continue: + if self.stop_reason is None: + self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + if logged_step and step_id: + await self.step_manager.update_step_stop_reason(self.actor, step_id, self.stop_reason.stop_reason) + return + if step_progression < StepProgression.STEP_LOGGED: + # Error occurred before step was fully logged + import traceback + + if logged_step: + await self.step_manager.update_step_error_async( + actor=self.actor, + step_id=step_id, # Use original step_id for telemetry + error_type=type(e).__name__ if "e" in locals() else "Unknown", + error_message=str(e) if "e" in locals() else "Unknown error", + error_traceback=traceback.format_exc(), + stop_reason=self.stop_reason, + ) + if step_progression <= StepProgression.STREAM_RECEIVED: + if first_chunk and settings.track_errored_messages and input_messages_to_persist: + for message in input_messages_to_persist: + message.is_err = True + message.step_id = step_id + await self.message_manager.create_many_messages_async( + input_messages_to_persist, + actor=self.actor, + project_id=self.agent_state.project_id, + ) + elif step_progression <= StepProgression.LOGGED_TRACE: + if self.stop_reason is None: + self.logger.error("Error in step after logging step") + self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + if logged_step: + await self.step_manager.update_step_stop_reason(self.actor, step_id, self.stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + + # Do tracking for failure cases. Can consolidate with success conditions later. + if settings.track_stop_reason: + await self._log_request(request_start_timestamp_ns, None, self.job_update_metadata, is_error=True, run_id=run_id) + + # Record partial step metrics on failure (capture whatever timing data we have) + if logged_step and step_metrics and step_progression < StepProgression.FINISHED: + # Calculate total step time up to the failure point + step_metrics.step_ns = get_utc_timestamp_ns() - step_metrics.step_start_ns + + await self._record_step_metrics( + step_id=step_id, + step_metrics=step_metrics, + run_id=run_id, + ) + except Exception as e: + self.logger.error(f"Error during post-completion step tracking: {e}") def _initialize_state(self): self.should_continue = True @@ -467,27 +563,6 @@ class LettaAgentV2(BaseAgentV2): self.logger.warning(f"Failed to check job cancellation status for job {run_id}: {e}") return False - async def _create_step_trackers(self, step_id: str, step_start_ns: int, run_id: str | None = None) -> Tuple[Span, Step, StepMetrics]: - span = tracer.start_span("agent_step", start_time=step_start_ns) - span.set_attributes({"step_id": step_id, "agent_id": self.agent_state.id}) - step = await self.step_manager.log_step_async( - actor=self.actor, - agent_id=self.agent_state.id, - provider_name=self.agent_state.llm_config.model_endpoint_type, - provider_category=self.agent_state.llm_config.provider_category or "base", - model=self.agent_state.llm_config.model, - model_endpoint=self.agent_state.llm_config.model_endpoint, - context_window_limit=self.agent_state.llm_config.context_window, - usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), - provider_id=None, - job_id=run_id, - step_id=step_id, - project_id=self.agent_state.project_id, - status=StepStatus.PENDING, - ) - metrics = StepMetrics(id=step_id) - return span, step, 
     async def _refresh_messages(self, in_context_messages: list[Message]):
         num_messages = await self.message_manager.size_async(
             agent_id=self.agent_state.id,
@@ -624,6 +699,29 @@ class LettaAgentV2(BaseAgentV2):
                 raise ValueError(f"Invalid JSON format in message: {text_content}")
         return None
 
+    def _request_checkpoint_start(self, request_start_timestamp_ns: int | None) -> Span | None:
+        if request_start_timestamp_ns is not None:
+            request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
+            request_span.set_attributes(
+                {f"llm_config.{k}": v for k, v in self.agent_state.llm_config.model_dump().items() if v is not None}
+            )
+            return request_span
+        return None
+
+    def _request_checkpoint_ttft(self, request_span: Span | None, request_start_timestamp_ns: int | None) -> Span | None:
+        if request_span:
+            ttft_ns = get_utc_timestamp_ns() - request_start_timestamp_ns
+            request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
+            return request_span
+        return None
+
+    def _request_checkpoint_finish(self, request_span: Span | None, request_start_timestamp_ns: int | None) -> None:
+        if request_span is not None:
+            duration_ns = get_utc_timestamp_ns() - request_start_timestamp_ns
+            request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)})
+            request_span.end()
+        return None
+
     def _step_checkpoint_start(self, step_id: str) -> Tuple[StepProgression, StepMetrics, Span]:
         step_start_ns = get_utc_timestamp_ns()
         step_metrics = StepMetrics(id=step_id, step_start_ns=step_start_ns)
@@ -1009,6 +1107,43 @@ class LettaAgentV2(BaseAgentV2):
         )
         return task
 
+    async def _log_request(
+        self,
+        request_start_timestamp_ns: int,
+        request_span: "Span | None",
+        job_update_metadata: dict | None,
+        is_error: bool,
+        run_id: str | None = None,
+    ):
+        if request_start_timestamp_ns:
+            now_ns, now = get_utc_timestamp_ns(), get_utc_time()
+            duration_ns = now_ns - request_start_timestamp_ns
+            if request_span:
+                request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)})
+            await self._update_agent_last_run_metrics(now, ns_to_ms(duration_ns))
+            if settings.track_agent_run and run_id:
+                await self.job_manager.record_response_duration(run_id, duration_ns, self.actor)
+                await self.job_manager.safe_update_job_status_async(
+                    job_id=run_id,
+                    new_status=JobStatus.failed if is_error else JobStatus.completed,
+                    actor=self.actor,
+                    metadata=job_update_metadata,
+                )
+        if request_span:
+            request_span.end()
+
+    async def _update_agent_last_run_metrics(self, completion_time: datetime, duration_ms: float) -> None:
+        if not settings.track_last_agent_run:
+            return
+        try:
+            await self.agent_manager.update_agent_async(
+                agent_id=self.agent_id,
+                agent_update=UpdateAgent(last_run_completion=completion_time, last_run_duration_ms=duration_ms),
+                actor=self.actor,
+            )
+        except Exception as e:
+            self.logger.error(f"Failed to update agent's last run metrics: {e}")
+
     def get_finish_chunks_for_stream(
         self,
         usage: LettaUsageStatistics,
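
Closing aside: `_log_request` and `_update_agent_last_run_metrics` both consume one pair of nanosecond timestamps, converted once via `ns_to_ms` for everything user-facing, and a run gets exactly one terminal job status. A toy version of that accounting (the names here are illustrative, not letta APIs):

def ns_to_ms(ns: int) -> float:
    return ns / 1_000_000


def close_out(request_start_ns: int, now_ns: int, is_error: bool) -> tuple[float, str]:
    duration_ms = ns_to_ms(now_ns - request_start_ns)
    # Mirrors the patch's decision: failed on the error path, completed otherwise.
    return duration_ms, ("failed" if is_error else "completed")


assert close_out(0, 2_500_000, False) == (2.5, "completed")
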