From 86b073d7261a012ceae8b480c72f6106d3a287bb Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 31 Aug 2025 11:26:21 -0700 Subject: [PATCH] feat: support approve tool call request (#4333) --- letta/agents/helpers.py | 12 +- letta/agents/letta_agent.py | 1568 +++++++++-------- letta/schemas/message.py | 4 +- .../schemas/openai/chat_completion_request.py | 2 +- letta/server/rest_api/utils.py | 54 +- tests/integration_test_human_in_the_loop.py | 21 + 6 files changed, 891 insertions(+), 770 deletions(-) diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index b4491bd4..9bfa9ad7 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -161,6 +161,10 @@ async def _prepare_in_context_messages_no_persist_async( f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' " f"but received '{input_messages[0].approval_request_id}'." ) + if input_messages[0].approve: + new_in_context_messages = [] + else: + raise NotImplementedError("Deny flow not yet supported") else: # User is trying to send a regular message if current_in_context_messages[-1].role == "approval": @@ -169,10 +173,10 @@ async def _prepare_in_context_messages_no_persist_async( "Please approve or deny the pending request before continuing." ) - # Create a new user message from the input but dont store it yet - new_in_context_messages = create_input_messages( - input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor - ) + # Create a new user message from the input but dont store it yet + new_in_context_messages = create_input_messages( + input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor + ) return current_in_context_messages, new_in_context_messages diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 90fbf67e..66302b5a 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -233,145 +233,31 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): - # Check for job cancellation at the start of each step - if await self._check_run_cancellation(): - stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) - logger.info(f"Agent execution cancelled for run {self.current_run_id}") - yield f"data: {stop_reason.model_dump_json()}\n\n" - break - - step_id = generate_step_id() - step_start = get_utc_timestamp_ns() - agent_step_span = tracer.start_span("agent_step", start_time=step_start) - agent_step_span.set_attributes({"step_id": step_id}) - - step_progression = StepProgression.START - should_continue = False - step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking - - # Create step early with PENDING status - logged_step = await self.step_manager.log_step_async( - actor=self.actor, - agent_id=agent_state.id, - provider_name=agent_state.llm_config.model_endpoint_type, - provider_category=agent_state.llm_config.provider_category or "base", - model=agent_state.llm_config.model, - model_endpoint=agent_state.llm_config.model_endpoint, - context_window_limit=agent_state.llm_config.context_window, - usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), - provider_id=None, - job_id=self.current_run_id if self.current_run_id else None, - step_id=step_id, - project_id=agent_state.project_id, - status=StepStatus.PENDING, - ) - # Only use step_id in messages if step was actually created - effective_step_id = step_id if logged_step else None - - try: - ( - request_data, - response_data, - current_in_context_messages, - new_in_context_messages, - valid_tool_names, - ) = await self._build_and_request_from_llm( - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, - agent_step_span, - step_metrics, - ) - in_context_messages = current_in_context_messages + new_in_context_messages - - step_progression = StepProgression.RESPONSE_RECEIVED - log_event("agent.stream_no_tokens.llm_response.received") # [3^] - - try: - response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - except ValueError as e: - stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value) - raise e - - # update usage - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens - MetricRegistry().message_output_tokens.record( - response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) - ) - - if not response.choices[0].message.tool_calls: - stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) - raise ValueError("No tool calls found in response, model must make a tool call") - tool_call = response.choices[0].message.tool_calls[0] - if response.choices[0].message.reasoning_content: - reasoning = [ - ReasoningContent( - reasoning=response.choices[0].message.reasoning_content, - is_native=True, - signature=response.choices[0].message.reasoning_content_signature, - ) - ] - elif response.choices[0].message.omitted_reasoning_content: - reasoning = [OmittedReasoningContent()] - elif response.choices[0].message.content: - reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons - else: - self.logger.info("No reasoning content found.") - reasoning = None - + if not new_in_context_messages and current_in_context_messages[-1].role == "approval": + approval_request_message = current_in_context_messages[-1] + step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( - tool_call, - valid_tool_names, + approval_request_message.tool_calls[0], + [], # TODO: update this agent_state, tool_rules_solver, - response.usage, - reasoning_content=reasoning, - step_id=effective_step_id, - initial_messages=initial_messages, - agent_step_span=agent_step_span, + usage, + reasoning_content=approval_request_message.content, + step_id=approval_request_message.step_id, + initial_messages=[], is_final_step=(i == max_steps - 1), step_metrics=step_metrics, + run_id=self.current_run_id, + is_approval=True, ) - step_progression = StepProgression.STEP_LOGGED - - # Update step with actual usage now that we have it (if step was created) - if logged_step: - await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason) - - # TODO (cliandy): handle message contexts with larger refactor and dedupe logic - new_message_idx = len(initial_messages) if initial_messages else 0 - self.response_messages.extend(persisted_messages[new_message_idx:]) - new_in_context_messages.extend(persisted_messages[new_message_idx:]) + new_message_idx = 0 + self.response_messages.extend(persisted_messages) + new_in_context_messages.extend(persisted_messages) initial_messages = None - log_event("agent.stream_no_tokens.llm_response.processed") # [4^] - - # log step time - now = get_utc_timestamp_ns() - step_ns = now - step_start - agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) - agent_step_span.end() - - # Log LLM Trace - if settings.track_provider_trace: - await self.telemetry_manager.create_provider_trace_async( - actor=self.actor, - provider_trace_create=ProviderTraceCreate( - request_json=request_data, - response_json=response_data, - step_id=step_id, # Use original step_id for telemetry - organization_id=self.actor.organization_id, - ), - ) - step_progression = StepProgression.LOGGED_TRACE # stream step # TODO: improve TTFT - filter_user_messages = [m for m in persisted_messages if m.role != "user"] + filter_user_messages = [m for m in persisted_messages if m.role != "user" and m.role != "approval"] letta_messages = Message.to_letta_messages_from_list( filter_user_messages, use_assistant_message=use_assistant_message, reverse=False ) @@ -379,108 +265,259 @@ class LettaAgent(BaseAgent): for message in letta_messages: if include_return_message_types is None or message.message_type in include_return_message_types: yield f"data: {message.model_dump_json()}\n\n" + else: + # Check for job cancellation at the start of each step + if await self._check_run_cancellation(): + stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) + logger.info(f"Agent execution cancelled for run {self.current_run_id}") + yield f"data: {stop_reason.model_dump_json()}\n\n" + break - MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) - step_progression = StepProgression.FINISHED + step_id = generate_step_id() + step_start = get_utc_timestamp_ns() + agent_step_span = tracer.start_span("agent_step", start_time=step_start) + agent_step_span.set_attributes({"step_id": step_id}) - # Record step metrics for successful completion - if logged_step and step_metrics: - # Set the step_ns that was already calculated - step_metrics.step_ns = step_ns - await self._record_step_metrics( - step_id=step_id, - agent_state=agent_state, - step_metrics=step_metrics, + step_progression = StepProgression.START + should_continue = False + step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking + + # Create step early with PENDING status + logged_step = await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), + provider_id=None, + job_id=self.current_run_id if self.current_run_id else None, + step_id=step_id, + project_id=agent_state.project_id, + status=StepStatus.PENDING, + ) + # Only use step_id in messages if step was actually created + effective_step_id = step_id if logged_step else None + + try: + ( + request_data, + response_data, + current_in_context_messages, + new_in_context_messages, + valid_tool_names, + ) = await self._build_and_request_from_llm( + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + agent_step_span, + step_metrics, + ) + in_context_messages = current_in_context_messages + new_in_context_messages + + step_progression = StepProgression.RESPONSE_RECEIVED + log_event("agent.stream_no_tokens.llm_response.received") # [3^] + + try: + response = llm_client.convert_response_to_chat_completion( + response_data, in_context_messages, agent_state.llm_config + ) + except ValueError as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value) + raise e + + # update usage + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + MetricRegistry().message_output_tokens.record( + response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) ) - except Exception as e: - # Handle any unexpected errors during step processing - self.logger.error(f"Error during step processing: {e}") - job_update_metadata = {"error": str(e)} - - # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. - if not stop_reason: - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): - self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) - elif stop_reason.stop_reason not in ( - StopReasonType.no_tool_call, - StopReasonType.invalid_tool_call, - StopReasonType.invalid_llm_response, - ): - self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason) - - # Send error stop reason to client and re-raise - yield f"data: {stop_reason.model_dump_json()}\n\n", 500 - raise - - # Update step if it needs to be updated - finally: - if step_progression == StepProgression.FINISHED and should_continue: - continue - - self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) - self.logger.info("Running final update. Step Progression: %s", step_progression) - try: - if step_progression == StepProgression.FINISHED and not should_continue: - # Successfully completed - update with final usage and stop reason - if stop_reason is None: - stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) - # Note: step already updated with success status after _handle_ai_response - if logged_step: - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) - break - - # Handle error cases - if step_progression < StepProgression.STEP_LOGGED: - # Error occurred before step was fully logged - import traceback - - if logged_step: - await self.step_manager.update_step_error_async( - actor=self.actor, - step_id=step_id, # Use original step_id for telemetry - error_type=type(e).__name__ if "e" in locals() else "Unknown", - error_message=str(e) if "e" in locals() else "Unknown error", - error_traceback=traceback.format_exc(), - stop_reason=stop_reason, + if not response.choices[0].message.tool_calls: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + raise ValueError("No tool calls found in response, model must make a tool call") + tool_call = response.choices[0].message.tool_calls[0] + if response.choices[0].message.reasoning_content: + reasoning = [ + ReasoningContent( + reasoning=response.choices[0].message.reasoning_content, + is_native=True, + signature=response.choices[0].message.reasoning_content_signature, ) - - if step_progression <= StepProgression.RESPONSE_RECEIVED: - # TODO (cliandy): persist response if we get it back - if settings.track_errored_messages and initial_messages: - for message in initial_messages: - message.is_err = True - message.step_id = effective_step_id - await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) - elif step_progression <= StepProgression.LOGGED_TRACE: - if stop_reason is None: - self.logger.error("Error in step after logging step") - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - if logged_step: - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + ] + elif response.choices[0].message.omitted_reasoning_content: + reasoning = [OmittedReasoningContent()] + elif response.choices[0].message.content: + reasoning = [ + TextContent(text=response.choices[0].message.content) + ] # reasoning placed into content for legacy reasons else: - self.logger.error("Invalid StepProgression value") + self.logger.info("No reasoning content found.") + reasoning = None - if settings.track_stop_reason: - await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True) + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( + tool_call, + valid_tool_names, + agent_state, + tool_rules_solver, + response.usage, + reasoning_content=reasoning, + step_id=effective_step_id, + initial_messages=initial_messages, + agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), + step_metrics=step_metrics, + ) + step_progression = StepProgression.STEP_LOGGED - # Record partial step metrics on failure (capture whatever timing data we have) - if logged_step and step_metrics and step_progression < StepProgression.FINISHED: - # Calculate total step time up to the failure point - step_metrics.step_ns = get_utc_timestamp_ns() - step_start + # Update step with actual usage now that we have it (if step was created) + if logged_step: + await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason) + + # TODO (cliandy): handle message contexts with larger refactor and dedupe logic + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) + initial_messages = None + log_event("agent.stream_no_tokens.llm_response.processed") # [4^] + + # log step time + now = get_utc_timestamp_ns() + step_ns = now - step_start + agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) + agent_step_span.end() + + # Log LLM Trace + if settings.track_provider_trace: + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, # Use original step_id for telemetry + organization_id=self.actor.organization_id, + ), + ) + step_progression = StepProgression.LOGGED_TRACE + + # stream step + # TODO: improve TTFT + filter_user_messages = [m for m in persisted_messages if m.role != "user"] + letta_messages = Message.to_letta_messages_from_list( + filter_user_messages, use_assistant_message=use_assistant_message, reverse=False + ) + + for message in letta_messages: + if include_return_message_types is None or message.message_type in include_return_message_types: + yield f"data: {message.model_dump_json()}\n\n" + + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + step_progression = StepProgression.FINISHED + + # Record step metrics for successful completion + if logged_step and step_metrics: + # Set the step_ns that was already calculated + step_metrics.step_ns = step_ns await self._record_step_metrics( step_id=step_id, agent_state=agent_state, step_metrics=step_metrics, - job_id=locals().get("run_id", self.current_run_id), ) except Exception as e: - self.logger.error("Failed to update step: %s", e) + # Handle any unexpected errors during step processing + self.logger.error(f"Error during step processing: {e}") + job_update_metadata = {"error": str(e)} - if not should_continue: - break + # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. + if not stop_reason: + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): + self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) + elif stop_reason.stop_reason not in ( + StopReasonType.no_tool_call, + StopReasonType.invalid_tool_call, + StopReasonType.invalid_llm_response, + ): + self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason) + + # Send error stop reason to client and re-raise + yield f"data: {stop_reason.model_dump_json()}\n\n", 500 + raise + + # Update step if it needs to be updated + finally: + if step_progression == StepProgression.FINISHED and should_continue: + continue + + self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) + self.logger.info("Running final update. Step Progression: %s", step_progression) + try: + if step_progression == StepProgression.FINISHED and not should_continue: + # Successfully completed - update with final usage and stop reason + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + # Note: step already updated with success status after _handle_ai_response + if logged_step: + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + break + + # Handle error cases + if step_progression < StepProgression.STEP_LOGGED: + # Error occurred before step was fully logged + import traceback + + if logged_step: + await self.step_manager.update_step_error_async( + actor=self.actor, + step_id=step_id, # Use original step_id for telemetry + error_type=type(e).__name__ if "e" in locals() else "Unknown", + error_message=str(e) if "e" in locals() else "Unknown error", + error_traceback=traceback.format_exc(), + stop_reason=stop_reason, + ) + + if step_progression <= StepProgression.RESPONSE_RECEIVED: + # TODO (cliandy): persist response if we get it back + if settings.track_errored_messages and initial_messages: + for message in initial_messages: + message.is_err = True + message.step_id = effective_step_id + await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + elif step_progression <= StepProgression.LOGGED_TRACE: + if stop_reason is None: + self.logger.error("Error in step after logging step") + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + if logged_step: + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + + if settings.track_stop_reason: + await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True) + + # Record partial step metrics on failure (capture whatever timing data we have) + if logged_step and step_metrics and step_progression < StepProgression.FINISHED: + # Calculate total step time up to the failure point + step_metrics.step_ns = get_utc_timestamp_ns() - step_start + await self._record_step_metrics( + step_id=step_id, + agent_state=agent_state, + step_metrics=step_metrics, + job_id=locals().get("run_id", self.current_run_id), + ) + + except Exception as e: + self.logger.error("Failed to update step: %s", e) + + if not should_continue: + break # Extend the in context message ids if not agent_state.message_buffer_autoclear: @@ -533,247 +570,273 @@ class LettaAgent(BaseAgent): job_update_metadata = None usage = LettaUsageStatistics() for i in range(max_steps): - # If dry run, build request data and return it without making LLM call - if dry_run: - request_data, valid_tool_names = await self._create_llm_request_data_async( - llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, - agent_state=agent_state, - tool_rules_solver=tool_rules_solver, - ) - return request_data - - # Check for job cancellation at the start of each step - if await self._check_run_cancellation(): - stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) - logger.info(f"Agent execution cancelled for run {self.current_run_id}") - break - - step_id = generate_step_id() - step_start = get_utc_timestamp_ns() - agent_step_span = tracer.start_span("agent_step", start_time=step_start) - agent_step_span.set_attributes({"step_id": step_id}) - - step_progression = StepProgression.START - should_continue = False - step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking - - # Create step early with PENDING status - logged_step = await self.step_manager.log_step_async( - actor=self.actor, - agent_id=agent_state.id, - provider_name=agent_state.llm_config.model_endpoint_type, - provider_category=agent_state.llm_config.provider_category or "base", - model=agent_state.llm_config.model, - model_endpoint=agent_state.llm_config.model_endpoint, - context_window_limit=agent_state.llm_config.context_window, - usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), - provider_id=None, - job_id=run_id if run_id else self.current_run_id, - step_id=step_id, - project_id=agent_state.project_id, - status=StepStatus.PENDING, - ) - # Only use step_id in messages if step was actually created - effective_step_id = step_id if logged_step else None - - try: - ( - request_data, - response_data, - current_in_context_messages, - new_in_context_messages, - valid_tool_names, - ) = await self._build_and_request_from_llm( - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, - agent_step_span, - step_metrics, - ) - in_context_messages = current_in_context_messages + new_in_context_messages - - step_progression = StepProgression.RESPONSE_RECEIVED - log_event("agent.step.llm_response.received") # [3^] - - try: - response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - except ValueError as e: - stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value) - raise e - - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens - usage.run_ids = [run_id] if run_id else None - MetricRegistry().message_output_tokens.record( - response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) - ) - - if not response.choices[0].message.tool_calls: - stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) - raise ValueError("No tool calls found in response, model must make a tool call") - tool_call = response.choices[0].message.tool_calls[0] - if response.choices[0].message.reasoning_content: - reasoning = [ - ReasoningContent( - reasoning=response.choices[0].message.reasoning_content, - is_native=True, - signature=response.choices[0].message.reasoning_content_signature, - ) - ] - elif response.choices[0].message.content: - reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons - elif response.choices[0].message.omitted_reasoning_content: - reasoning = [OmittedReasoningContent()] - else: - self.logger.info("No reasoning content found.") - reasoning = None - + if not new_in_context_messages and current_in_context_messages[-1].role == "approval": + approval_request_message = current_in_context_messages[-1] + step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( - tool_call, - valid_tool_names, + approval_request_message.tool_calls[0], + [], # TODO: update this agent_state, tool_rules_solver, - response.usage, - reasoning_content=reasoning, - step_id=effective_step_id, - initial_messages=initial_messages, - agent_step_span=agent_step_span, + usage, + reasoning_content=approval_request_message.content, + step_id=approval_request_message.step_id, + initial_messages=[], is_final_step=(i == max_steps - 1), - run_id=run_id, step_metrics=step_metrics, + run_id=run_id or self.current_run_id, + is_approval=True, ) - step_progression = StepProgression.STEP_LOGGED - - # Update step with actual usage now that we have it (if step was created) - if logged_step: - await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason) - - new_message_idx = len(initial_messages) if initial_messages else 0 - self.response_messages.extend(persisted_messages[new_message_idx:]) - new_in_context_messages.extend(persisted_messages[new_message_idx:]) - + new_message_idx = 0 + self.response_messages.extend(persisted_messages) + new_in_context_messages.extend(persisted_messages) initial_messages = None - log_event("agent.step.llm_response.processed") # [4^] - - # log step time - now = get_utc_timestamp_ns() - step_ns = now - step_start - agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) - agent_step_span.end() - - # Log LLM Trace - if settings.track_provider_trace: - await self.telemetry_manager.create_provider_trace_async( - actor=self.actor, - provider_trace_create=ProviderTraceCreate( - request_json=request_data, - response_json=response_data, - step_id=step_id, # Use original step_id for telemetry - organization_id=self.actor.organization_id, - ), - ) - step_progression = StepProgression.LOGGED_TRACE - - MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) - step_progression = StepProgression.FINISHED - - # Record step metrics for successful completion - if logged_step and step_metrics: - # Set the step_ns that was already calculated - step_metrics.step_ns = step_ns - await self._record_step_metrics( - step_id=step_id, + else: + # If dry run, build request data and return it without making LLM call + if dry_run: + request_data, valid_tool_names = await self._create_llm_request_data_async( + llm_client=llm_client, + in_context_messages=current_in_context_messages + new_in_context_messages, agent_state=agent_state, - step_metrics=step_metrics, - job_id=run_id if run_id else self.current_run_id, + tool_rules_solver=tool_rules_solver, + ) + return request_data + + # Check for job cancellation at the start of each step + if await self._check_run_cancellation(): + stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) + logger.info(f"Agent execution cancelled for run {self.current_run_id}") + break + + step_id = generate_step_id() + step_start = get_utc_timestamp_ns() + agent_step_span = tracer.start_span("agent_step", start_time=step_start) + agent_step_span.set_attributes({"step_id": step_id}) + + step_progression = StepProgression.START + should_continue = False + step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking + + # Create step early with PENDING status + logged_step = await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), + provider_id=None, + job_id=run_id if run_id else self.current_run_id, + step_id=step_id, + project_id=agent_state.project_id, + status=StepStatus.PENDING, + ) + # Only use step_id in messages if step was actually created + effective_step_id = step_id if logged_step else None + + try: + ( + request_data, + response_data, + current_in_context_messages, + new_in_context_messages, + valid_tool_names, + ) = await self._build_and_request_from_llm( + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + agent_step_span, + step_metrics, + ) + in_context_messages = current_in_context_messages + new_in_context_messages + + step_progression = StepProgression.RESPONSE_RECEIVED + log_event("agent.step.llm_response.received") # [3^] + + try: + response = llm_client.convert_response_to_chat_completion( + response_data, in_context_messages, agent_state.llm_config + ) + except ValueError as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value) + raise e + + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + usage.run_ids = [run_id] if run_id else None + MetricRegistry().message_output_tokens.record( + response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) ) - except Exception as e: - # Handle any unexpected errors during step processing - self.logger.error(f"Error during step processing: {e}") - job_update_metadata = {"error": str(e)} - - # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. - if not stop_reason: - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): - self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) - elif stop_reason.stop_reason not in ( - StopReasonType.no_tool_call, - StopReasonType.invalid_tool_call, - StopReasonType.invalid_llm_response, - ): - self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason) - raise - - # Update step if it needs to be updated - finally: - if step_progression == StepProgression.FINISHED and should_continue: - continue - - self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) - self.logger.info("Running final update. Step Progression: %s", step_progression) - try: - if step_progression == StepProgression.FINISHED and not should_continue: - # Successfully completed - update with final usage and stop reason - if stop_reason is None: - stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) - if logged_step: - await self.step_manager.update_step_success_async(self.actor, step_id, usage, stop_reason) - break - - # Handle error cases - if step_progression < StepProgression.STEP_LOGGED: - # Error occurred before step was fully logged - import traceback - - if logged_step: - await self.step_manager.update_step_error_async( - actor=self.actor, - step_id=step_id, # Use original step_id for telemetry - error_type=type(e).__name__ if "e" in locals() else "Unknown", - error_message=str(e) if "e" in locals() else "Unknown error", - error_traceback=traceback.format_exc(), - stop_reason=stop_reason, + if not response.choices[0].message.tool_calls: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + raise ValueError("No tool calls found in response, model must make a tool call") + tool_call = response.choices[0].message.tool_calls[0] + if response.choices[0].message.reasoning_content: + reasoning = [ + ReasoningContent( + reasoning=response.choices[0].message.reasoning_content, + is_native=True, + signature=response.choices[0].message.reasoning_content_signature, ) - - if step_progression <= StepProgression.RESPONSE_RECEIVED: - # TODO (cliandy): persist response if we get it back - if settings.track_errored_messages and initial_messages: - for message in initial_messages: - message.is_err = True - message.step_id = effective_step_id - await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) - elif step_progression <= StepProgression.LOGGED_TRACE: - if stop_reason is None: - self.logger.error("Error in step after logging step") - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - if logged_step: - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + ] + elif response.choices[0].message.content: + reasoning = [ + TextContent(text=response.choices[0].message.content) + ] # reasoning placed into content for legacy reasons + elif response.choices[0].message.omitted_reasoning_content: + reasoning = [OmittedReasoningContent()] else: - self.logger.error("Invalid StepProgression value") + self.logger.info("No reasoning content found.") + reasoning = None - if settings.track_stop_reason: - await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True) + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( + tool_call, + valid_tool_names, + agent_state, + tool_rules_solver, + response.usage, + reasoning_content=reasoning, + step_id=effective_step_id, + initial_messages=initial_messages, + agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), + run_id=run_id, + step_metrics=step_metrics, + ) + step_progression = StepProgression.STEP_LOGGED - # Record partial step metrics on failure (capture whatever timing data we have) - if logged_step and step_metrics and step_progression < StepProgression.FINISHED: - # Calculate total step time up to the failure point - step_metrics.step_ns = get_utc_timestamp_ns() - step_start + # Update step with actual usage now that we have it (if step was created) + if logged_step: + await self.step_manager.update_step_success_async(self.actor, step_id, response.usage, stop_reason) + + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) + + initial_messages = None + log_event("agent.step.llm_response.processed") # [4^] + + # log step time + now = get_utc_timestamp_ns() + step_ns = now - step_start + agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) + agent_step_span.end() + + # Log LLM Trace + if settings.track_provider_trace: + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, # Use original step_id for telemetry + organization_id=self.actor.organization_id, + ), + ) + step_progression = StepProgression.LOGGED_TRACE + + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + step_progression = StepProgression.FINISHED + + # Record step metrics for successful completion + if logged_step and step_metrics: + # Set the step_ns that was already calculated + step_metrics.step_ns = step_ns await self._record_step_metrics( step_id=step_id, agent_state=agent_state, step_metrics=step_metrics, - job_id=locals().get("run_id", self.current_run_id), + job_id=run_id if run_id else self.current_run_id, ) except Exception as e: - self.logger.error("Failed to update step: %s", e) + # Handle any unexpected errors during step processing + self.logger.error(f"Error during step processing: {e}") + job_update_metadata = {"error": str(e)} + + # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. + if not stop_reason: + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): + self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) + elif stop_reason.stop_reason not in ( + StopReasonType.no_tool_call, + StopReasonType.invalid_tool_call, + StopReasonType.invalid_llm_response, + ): + self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason) + raise + + # Update step if it needs to be updated + finally: + if step_progression == StepProgression.FINISHED and should_continue: + continue + + self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) + self.logger.info("Running final update. Step Progression: %s", step_progression) + try: + if step_progression == StepProgression.FINISHED and not should_continue: + # Successfully completed - update with final usage and stop reason + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + if logged_step: + await self.step_manager.update_step_success_async(self.actor, step_id, usage, stop_reason) + break + + # Handle error cases + if step_progression < StepProgression.STEP_LOGGED: + # Error occurred before step was fully logged + import traceback + + if logged_step: + await self.step_manager.update_step_error_async( + actor=self.actor, + step_id=step_id, # Use original step_id for telemetry + error_type=type(e).__name__ if "e" in locals() else "Unknown", + error_message=str(e) if "e" in locals() else "Unknown error", + error_traceback=traceback.format_exc(), + stop_reason=stop_reason, + ) + + if step_progression <= StepProgression.RESPONSE_RECEIVED: + # TODO (cliandy): persist response if we get it back + if settings.track_errored_messages and initial_messages: + for message in initial_messages: + message.is_err = True + message.step_id = effective_step_id + await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + elif step_progression <= StepProgression.LOGGED_TRACE: + if stop_reason is None: + self.logger.error("Error in step after logging step") + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + if logged_step: + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + + if settings.track_stop_reason: + await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True) + + # Record partial step metrics on failure (capture whatever timing data we have) + if logged_step and step_metrics and step_progression < StepProgression.FINISHED: + # Calculate total step time up to the failure point + step_metrics.step_ns = get_utc_timestamp_ns() - step_start + await self._record_step_metrics( + step_id=step_id, + agent_state=agent_state, + step_metrics=step_metrics, + job_id=locals().get("run_id", self.current_run_id), + ) + + except Exception as e: + self.logger.error("Failed to update step: %s", e) if not should_continue: break @@ -846,327 +909,265 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): - step_id = generate_step_id() - # Check for job cancellation at the start of each step - if await self._check_run_cancellation(): - stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) - logger.info(f"Agent execution cancelled for run {self.current_run_id}") - yield f"data: {stop_reason.model_dump_json()}\n\n" - break - - step_start = get_utc_timestamp_ns() - agent_step_span = tracer.start_span("agent_step", start_time=step_start) - agent_step_span.set_attributes({"step_id": step_id}) - - step_progression = StepProgression.START - should_continue = False - step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking - - # Create step early with PENDING status - logged_step = await self.step_manager.log_step_async( - actor=self.actor, - agent_id=agent_state.id, - provider_name=agent_state.llm_config.model_endpoint_type, - provider_category=agent_state.llm_config.provider_category or "base", - model=agent_state.llm_config.model, - model_endpoint=agent_state.llm_config.model_endpoint, - context_window_limit=agent_state.llm_config.context_window, - usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), - provider_id=None, - job_id=self.current_run_id if self.current_run_id else None, - step_id=step_id, - project_id=agent_state.project_id, - status=StepStatus.PENDING, - ) - # Only use step_id in messages if step was actually created - effective_step_id = step_id if logged_step else None - - try: - ( - request_data, - stream, - current_in_context_messages, - new_in_context_messages, - valid_tool_names, - provider_request_start_timestamp_ns, - ) = await self._build_and_request_from_llm_streaming( - first_chunk, - agent_step_span, - request_start_timestamp_ns, - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, - ) - - step_progression = StepProgression.STREAM_RECEIVED - log_event("agent.stream.llm_response.received") # [3^] - - # TODO: THIS IS INCREDIBLY UGLY - # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED - if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: - interface = AnthropicStreamingInterface( - use_assistant_message=use_assistant_message, - put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, - ) - elif agent_state.llm_config.model_endpoint_type == ProviderType.openai: - interface = OpenAIStreamingInterface( - use_assistant_message=use_assistant_message, - is_openai_proxy=agent_state.llm_config.provider_name == "lmstudio_openai", - messages=current_in_context_messages + new_in_context_messages, - tools=request_data.get("tools", []), - ) - else: - raise ValueError(f"Streaming not supported for {agent_state.llm_config}") - - async for chunk in interface.process( - stream, - ttft_span=request_span, - ): - # Measure TTFT (trace, metric, and db). This should be consolidated. - if first_chunk and request_span is not None: - now = get_utc_timestamp_ns() - ttft_ns = now - request_start_timestamp_ns - - request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)}) - metric_attributes = get_ctx_attributes() - metric_attributes["model.name"] = agent_state.llm_config.model - MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) - - if self.current_run_id and self.job_manager: - await self.job_manager.record_ttft(self.current_run_id, ttft_ns, self.actor) - - first_chunk = False - - if include_return_message_types is None or chunk.message_type in include_return_message_types: - # filter down returned data - yield f"data: {chunk.model_dump_json()}\n\n" - - stream_end_time_ns = get_utc_timestamp_ns() - - # Some providers that rely on the OpenAI client currently e.g. LMStudio don't get usage metrics back on the last streaming chunk, fall back to manual values - if isinstance(interface, OpenAIStreamingInterface) and not interface.input_tokens and not interface.output_tokens: - logger.warning( - f"No token usage metrics received from OpenAI streaming interface for {agent_state.llm_config.model}, falling back to estimated values. Input tokens: {interface.fallback_input_tokens}, Output tokens: {interface.fallback_output_tokens}" - ) - interface.input_tokens = interface.fallback_input_tokens - interface.output_tokens = interface.fallback_output_tokens - - usage.step_count += 1 - usage.completion_tokens += interface.output_tokens - usage.prompt_tokens += interface.input_tokens - usage.total_tokens += interface.input_tokens + interface.output_tokens - MetricRegistry().message_output_tokens.record( - usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) - ) - - # log LLM request time - llm_request_ns = stream_end_time_ns - provider_request_start_timestamp_ns - step_metrics.llm_request_ns = llm_request_ns - - llm_request_ms = ns_to_ms(llm_request_ns) - agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms}) - MetricRegistry().llm_execution_time_ms_histogram.record( - llm_request_ms, - dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}), - ) - - # Process resulting stream content - try: - tool_call = interface.get_tool_call_object() - except ValueError as e: - stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) - raise e - except Exception as e: - stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) - raise e - reasoning_content = interface.get_reasoning_content() + if not new_in_context_messages and current_in_context_messages[-1].role == "approval": + approval_request_message = current_in_context_messages[-1] + step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( - tool_call, - valid_tool_names, + approval_request_message.tool_calls[0], + [], # TODO: update this agent_state, tool_rules_solver, - UsageStatistics( - completion_tokens=usage.completion_tokens, - prompt_tokens=usage.prompt_tokens, - total_tokens=usage.total_tokens, - ), - reasoning_content=reasoning_content, - pre_computed_assistant_message_id=interface.letta_message_id, - step_id=effective_step_id, - initial_messages=initial_messages, - agent_step_span=agent_step_span, + usage, + reasoning_content=approval_request_message.content, + step_id=approval_request_message.step_id, + initial_messages=[], is_final_step=(i == max_steps - 1), step_metrics=step_metrics, + run_id=self.current_run_id, + is_approval=True, ) - step_progression = StepProgression.STEP_LOGGED - - # Update step with actual usage now that we have it (if step was created) - if logged_step: - await self.step_manager.update_step_success_async( - self.actor, - step_id, - UsageStatistics( - completion_tokens=usage.completion_tokens, - prompt_tokens=usage.prompt_tokens, - total_tokens=usage.total_tokens, - ), - stop_reason, - ) - - new_message_idx = len(initial_messages) if initial_messages else 0 - self.response_messages.extend(persisted_messages[new_message_idx:]) - new_in_context_messages.extend(persisted_messages[new_message_idx:]) - + new_message_idx = 0 + self.response_messages.extend(persisted_messages) + new_in_context_messages.extend(persisted_messages) initial_messages = None - # log total step time - now = get_utc_timestamp_ns() - step_ns = now - step_start - agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) - agent_step_span.end() - - # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream - # log_event("agent.stream.llm_response.processed") # [4^] - - # Log LLM Trace - # We are piecing together the streamed response here. - # Content here does not match the actual response schema as streams come in chunks. - if settings.track_provider_trace: - await self.telemetry_manager.create_provider_trace_async( - actor=self.actor, - provider_trace_create=ProviderTraceCreate( - request_json=request_data, - response_json={ - "content": { - "tool_call": tool_call.model_dump_json(), - "reasoning": [content.model_dump_json() for content in reasoning_content], - }, - "id": interface.message_id, - "model": interface.model, - "role": "assistant", - # "stop_reason": "", - # "stop_sequence": None, - "type": "message", - "usage": { - "input_tokens": usage.prompt_tokens, - "output_tokens": usage.completion_tokens, - }, - }, - step_id=step_id, # Use original step_id for telemetry - organization_id=self.actor.organization_id, - ), - ) - step_progression = StepProgression.LOGGED_TRACE - # yields tool response as this is handled from Letta and not the response from the LLM provider tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] if not (use_assistant_message and tool_return.name == "send_message"): # Apply message type filtering if specified if include_return_message_types is None or tool_return.message_type in include_return_message_types: yield f"data: {tool_return.model_dump_json()}\n\n" + else: + step_id = generate_step_id() + # Check for job cancellation at the start of each step + if await self._check_run_cancellation(): + stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) + logger.info(f"Agent execution cancelled for run {self.current_run_id}") + yield f"data: {stop_reason.model_dump_json()}\n\n" + break - # TODO (cliandy): consolidate and expand with trace - MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) - step_progression = StepProgression.FINISHED + step_start = get_utc_timestamp_ns() + agent_step_span = tracer.start_span("agent_step", start_time=step_start) + agent_step_span.set_attributes({"step_id": step_id}) - # Record step metrics for successful completion - if logged_step and step_metrics: - try: - # Set the step_ns that was already calculated - step_metrics.step_ns = step_ns + step_progression = StepProgression.START + should_continue = False + step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking - # Get context attributes for project and template IDs - ctx_attrs = get_ctx_attributes() + # Create step early with PENDING status + logged_step = await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), + provider_id=None, + job_id=self.current_run_id if self.current_run_id else None, + step_id=step_id, + project_id=agent_state.project_id, + status=StepStatus.PENDING, + ) + # Only use step_id in messages if step was actually created + effective_step_id = step_id if logged_step else None - await self._record_step_metrics( - step_id=step_id, - agent_state=agent_state, - step_metrics=step_metrics, - ctx_attrs=ctx_attrs, - job_id=self.current_run_id, - ) - except Exception as metrics_error: - self.logger.warning(f"Failed to record step metrics: {metrics_error}") - - except Exception as e: - # Handle any unexpected errors during step processing - self.logger.error(f"Error during step processing: {e}") - job_update_metadata = {"error": str(e)} - - # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. - if not stop_reason: - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): - self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) - elif stop_reason.stop_reason not in ( - StopReasonType.no_tool_call, - StopReasonType.invalid_tool_call, - StopReasonType.invalid_llm_response, - ): - self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason) - - # Send error stop reason to client and re-raise with expected response code - yield f"data: {stop_reason.model_dump_json()}\n\n", 500 - raise - - # Update step if it needs to be updated - finally: - if step_progression == StepProgression.FINISHED and should_continue: - continue - - self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) - self.logger.info("Running final update. Step Progression: %s", step_progression) try: - if step_progression == StepProgression.FINISHED and not should_continue: - # Successfully completed - update with final usage and stop reason - if stop_reason is None: - stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) - # Note: step already updated with success status after _handle_ai_response - if logged_step: - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) - break + ( + request_data, + stream, + current_in_context_messages, + new_in_context_messages, + valid_tool_names, + provider_request_start_timestamp_ns, + ) = await self._build_and_request_from_llm_streaming( + first_chunk, + agent_step_span, + request_start_timestamp_ns, + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + ) - # Handle error cases - if step_progression < StepProgression.STEP_LOGGED: - # Error occurred before step was fully logged - import traceback + step_progression = StepProgression.STREAM_RECEIVED + log_event("agent.stream.llm_response.received") # [3^] - if logged_step: - await self.step_manager.update_step_error_async( - actor=self.actor, - step_id=step_id, # Use original step_id for telemetry - error_type=type(e).__name__ if "e" in locals() else "Unknown", - error_message=str(e) if "e" in locals() else "Unknown error", - error_traceback=traceback.format_exc(), - stop_reason=stop_reason, - ) - - if step_progression <= StepProgression.STREAM_RECEIVED: - if first_chunk and settings.track_errored_messages and initial_messages: - for message in initial_messages: - message.is_err = True - message.step_id = effective_step_id - await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) - elif step_progression <= StepProgression.LOGGED_TRACE: - if stop_reason is None: - self.logger.error("Error in step after logging step") - stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) - if logged_step: - await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + # TODO: THIS IS INCREDIBLY UGLY + # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED + if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: + interface = AnthropicStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) + elif agent_state.llm_config.model_endpoint_type == ProviderType.openai: + interface = OpenAIStreamingInterface( + use_assistant_message=use_assistant_message, + is_openai_proxy=agent_state.llm_config.provider_name == "lmstudio_openai", + messages=current_in_context_messages + new_in_context_messages, + tools=request_data.get("tools", []), + ) else: - self.logger.error("Invalid StepProgression value") + raise ValueError(f"Streaming not supported for {agent_state.llm_config}") - # Do tracking for failure cases. Can consolidate with success conditions later. - if settings.track_stop_reason: - await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True) + async for chunk in interface.process( + stream, + ttft_span=request_span, + ): + # Measure TTFT (trace, metric, and db). This should be consolidated. + if first_chunk and request_span is not None: + now = get_utc_timestamp_ns() + ttft_ns = now - request_start_timestamp_ns - # Record partial step metrics on failure (capture whatever timing data we have) - if logged_step and step_metrics and step_progression < StepProgression.FINISHED: + request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)}) + metric_attributes = get_ctx_attributes() + metric_attributes["model.name"] = agent_state.llm_config.model + MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) + + if self.current_run_id and self.job_manager: + await self.job_manager.record_ttft(self.current_run_id, ttft_ns, self.actor) + + first_chunk = False + + if include_return_message_types is None or chunk.message_type in include_return_message_types: + # filter down returned data + yield f"data: {chunk.model_dump_json()}\n\n" + + stream_end_time_ns = get_utc_timestamp_ns() + + # Some providers that rely on the OpenAI client currently e.g. LMStudio don't get usage metrics back on the last streaming chunk, fall back to manual values + if isinstance(interface, OpenAIStreamingInterface) and not interface.input_tokens and not interface.output_tokens: + logger.warning( + f"No token usage metrics received from OpenAI streaming interface for {agent_state.llm_config.model}, falling back to estimated values. Input tokens: {interface.fallback_input_tokens}, Output tokens: {interface.fallback_output_tokens}" + ) + interface.input_tokens = interface.fallback_input_tokens + interface.output_tokens = interface.fallback_output_tokens + + usage.step_count += 1 + usage.completion_tokens += interface.output_tokens + usage.prompt_tokens += interface.input_tokens + usage.total_tokens += interface.input_tokens + interface.output_tokens + MetricRegistry().message_output_tokens.record( + usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) + ) + + # log LLM request time + llm_request_ns = stream_end_time_ns - provider_request_start_timestamp_ns + step_metrics.llm_request_ns = llm_request_ns + + llm_request_ms = ns_to_ms(llm_request_ns) + agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms}) + MetricRegistry().llm_execution_time_ms_histogram.record( + llm_request_ms, + dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}), + ) + + # Process resulting stream content + try: + tool_call = interface.get_tool_call_object() + except ValueError as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + raise e + except Exception as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) + raise e + reasoning_content = interface.get_reasoning_content() + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( + tool_call, + valid_tool_names, + agent_state, + tool_rules_solver, + UsageStatistics( + completion_tokens=usage.completion_tokens, + prompt_tokens=usage.prompt_tokens, + total_tokens=usage.total_tokens, + ), + reasoning_content=reasoning_content, + pre_computed_assistant_message_id=interface.letta_message_id, + step_id=effective_step_id, + initial_messages=initial_messages, + agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), + step_metrics=step_metrics, + ) + step_progression = StepProgression.STEP_LOGGED + + # Update step with actual usage now that we have it (if step was created) + if logged_step: + await self.step_manager.update_step_success_async( + self.actor, + step_id, + UsageStatistics( + completion_tokens=usage.completion_tokens, + prompt_tokens=usage.prompt_tokens, + total_tokens=usage.total_tokens, + ), + stop_reason, + ) + + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) + + initial_messages = None + + # log total step time + now = get_utc_timestamp_ns() + step_ns = now - step_start + agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) + agent_step_span.end() + + # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream + # log_event("agent.stream.llm_response.processed") # [4^] + + # Log LLM Trace + # We are piecing together the streamed response here. + # Content here does not match the actual response schema as streams come in chunks. + if settings.track_provider_trace: + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json={ + "content": { + "tool_call": tool_call.model_dump_json(), + "reasoning": [content.model_dump_json() for content in reasoning_content], + }, + "id": interface.message_id, + "model": interface.model, + "role": "assistant", + # "stop_reason": "", + # "stop_sequence": None, + "type": "message", + "usage": { + "input_tokens": usage.prompt_tokens, + "output_tokens": usage.completion_tokens, + }, + }, + step_id=step_id, # Use original step_id for telemetry + organization_id=self.actor.organization_id, + ), + ) + step_progression = StepProgression.LOGGED_TRACE + + # yields tool response as this is handled from Letta and not the response from the LLM provider + tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] + if not (use_assistant_message and tool_return.name == "send_message"): + # Apply message type filtering if specified + if include_return_message_types is None or tool_return.message_type in include_return_message_types: + yield f"data: {tool_return.model_dump_json()}\n\n" + + # TODO (cliandy): consolidate and expand with trace + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + step_progression = StepProgression.FINISHED + + # Record step metrics for successful completion + if logged_step and step_metrics: try: - # Calculate total step time up to the failure point - step_metrics.step_ns = get_utc_timestamp_ns() - step_start + # Set the step_ns that was already calculated + step_metrics.step_ns = step_ns # Get context attributes for project and template IDs ctx_attrs = get_ctx_attributes() @@ -1176,16 +1177,107 @@ class LettaAgent(BaseAgent): agent_state=agent_state, step_metrics=step_metrics, ctx_attrs=ctx_attrs, - job_id=locals().get("run_id", self.current_run_id), + job_id=self.current_run_id, ) except Exception as metrics_error: self.logger.warning(f"Failed to record step metrics: {metrics_error}") except Exception as e: - self.logger.error("Failed to update step: %s", e) + # Handle any unexpected errors during step processing + self.logger.error(f"Error during step processing: {e}") + job_update_metadata = {"error": str(e)} - if not should_continue: - break + # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. + if not stop_reason: + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): + self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) + elif stop_reason.stop_reason not in ( + StopReasonType.no_tool_call, + StopReasonType.invalid_tool_call, + StopReasonType.invalid_llm_response, + ): + self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", stop_reason.stop_reason) + + # Send error stop reason to client and re-raise with expected response code + yield f"data: {stop_reason.model_dump_json()}\n\n", 500 + raise + + # Update step if it needs to be updated + finally: + if step_progression == StepProgression.FINISHED and should_continue: + continue + + self.logger.debug("Running cleanup for agent loop run: %s", self.current_run_id) + self.logger.info("Running final update. Step Progression: %s", step_progression) + try: + if step_progression == StepProgression.FINISHED and not should_continue: + # Successfully completed - update with final usage and stop reason + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + # Note: step already updated with success status after _handle_ai_response + if logged_step: + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + break + + # Handle error cases + if step_progression < StepProgression.STEP_LOGGED: + # Error occurred before step was fully logged + import traceback + + if logged_step: + await self.step_manager.update_step_error_async( + actor=self.actor, + step_id=step_id, # Use original step_id for telemetry + error_type=type(e).__name__ if "e" in locals() else "Unknown", + error_message=str(e) if "e" in locals() else "Unknown error", + error_traceback=traceback.format_exc(), + stop_reason=stop_reason, + ) + + if step_progression <= StepProgression.STREAM_RECEIVED: + if first_chunk and settings.track_errored_messages and initial_messages: + for message in initial_messages: + message.is_err = True + message.step_id = effective_step_id + await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + elif step_progression <= StepProgression.LOGGED_TRACE: + if stop_reason is None: + self.logger.error("Error in step after logging step") + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + if logged_step: + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + + # Do tracking for failure cases. Can consolidate with success conditions later. + if settings.track_stop_reason: + await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True) + + # Record partial step metrics on failure (capture whatever timing data we have) + if logged_step and step_metrics and step_progression < StepProgression.FINISHED: + try: + # Calculate total step time up to the failure point + step_metrics.step_ns = get_utc_timestamp_ns() - step_start + + # Get context attributes for project and template IDs + ctx_attrs = get_ctx_attributes() + + await self._record_step_metrics( + step_id=step_id, + agent_state=agent_state, + step_metrics=step_metrics, + ctx_attrs=ctx_attrs, + job_id=locals().get("run_id", self.current_run_id), + ) + except Exception as metrics_error: + self.logger.warning(f"Failed to record step metrics: {metrics_error}") + + except Exception as e: + self.logger.error("Failed to update step: %s", e) + + if not should_continue: + break # Extend the in context message ids if not agent_state.message_buffer_autoclear: await self._rebuild_context_window( @@ -1522,6 +1614,7 @@ class LettaAgent(BaseAgent): is_final_step: bool | None = None, run_id: str | None = None, step_metrics: StepMetrics = None, + is_approval: bool | None = None, ) -> tuple[list[Message], bool, LettaStopReason | None]: """ Handle the final AI response once streaming completes, execute / validate the @@ -1543,7 +1636,7 @@ class LettaAgent(BaseAgent): request_heartbeat=request_heartbeat, ) - if tool_rules_solver.is_requires_approval_tool(tool_call_name): + if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name): approval_message = create_approval_request_message_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, @@ -1561,7 +1654,7 @@ class LettaAgent(BaseAgent): stop_reason = LettaStopReason(stop_reason=StopReasonType.requires_approval.value) else: # 2. Execute the tool (or synthesize an error result if disallowed) - tool_rule_violated = tool_call_name not in valid_tool_names + tool_rule_violated = tool_call_name not in valid_tool_names and not is_approval if tool_rule_violated: tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver) else: @@ -1630,6 +1723,7 @@ class LettaAgent(BaseAgent): reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=step_id, + is_approval=is_approval, ) messages_to_persist = (initial_messages or []) + tool_call_messages diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 192fa3c6..ed0af739 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -769,11 +769,11 @@ class Message(BaseMessage): "role": self.role, } - elif self.role == "assistant": + elif self.role == "assistant" or self.role == "approval": assert self.tool_calls is not None or text_content is not None openai_message = { "content": None if (put_inner_thoughts_in_kwargs and self.tool_calls is not None) else text_content, - "role": self.role, + "role": "assistant", } if self.tool_calls is not None: diff --git a/letta/schemas/openai/chat_completion_request.py b/letta/schemas/openai/chat_completion_request.py index 8e7a69b0..26d3a4ca 100644 --- a/letta/schemas/openai/chat_completion_request.py +++ b/letta/schemas/openai/chat_completion_request.py @@ -50,7 +50,7 @@ def cast_message_to_subtype(m_dict: dict) -> ChatMessage: return SystemMessage(**m_dict) elif role == "user": return UserMessage(**m_dict) - elif role == "assistant": + elif role == "assistant" or role == "approval": return AssistantMessage(**m_dict) elif role == "tool": return ToolMessage(**m_dict) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index b20c9db3..eb63d02c 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -233,34 +233,36 @@ def create_letta_messages_from_llm_response( pre_computed_assistant_message_id: Optional[str] = None, llm_batch_item_id: Optional[str] = None, step_id: str | None = None, + is_approval: bool | None = None, ) -> List[Message]: messages = [] - # Construct the tool call with the assistant's message - # Force set request_heartbeat in tool_args to calculated continue_stepping - function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping - tool_call = OpenAIToolCall( - id=tool_call_id, - function=OpenAIFunction( - name=function_name, - arguments=json.dumps(function_arguments), - ), - type="function", - ) - # TODO: Use ToolCallContent instead of tool_calls - # TODO: This helps preserve ordering - assistant_message = Message( - role=MessageRole.assistant, - content=reasoning_content if reasoning_content else [], - agent_id=agent_id, - model=model, - tool_calls=[tool_call], - tool_call_id=tool_call_id, - created_at=get_utc_time(), - batch_item_id=llm_batch_item_id, - ) - if pre_computed_assistant_message_id: - assistant_message.id = pre_computed_assistant_message_id - messages.append(assistant_message) + if not is_approval: + # Construct the tool call with the assistant's message + # Force set request_heartbeat in tool_args to calculated continue_stepping + function_arguments[REQUEST_HEARTBEAT_PARAM] = continue_stepping + tool_call = OpenAIToolCall( + id=tool_call_id, + function=OpenAIFunction( + name=function_name, + arguments=json.dumps(function_arguments), + ), + type="function", + ) + # TODO: Use ToolCallContent instead of tool_calls + # TODO: This helps preserve ordering + assistant_message = Message( + role=MessageRole.assistant, + content=reasoning_content if reasoning_content else [], + agent_id=agent_id, + model=model, + tool_calls=[tool_call], + tool_call_id=tool_call_id, + created_at=get_utc_time(), + batch_item_id=llm_batch_item_id, + ) + if pre_computed_assistant_message_id: + assistant_message.id = pre_computed_assistant_message_id + messages.append(assistant_message) # TODO: Use ToolReturnContent instead of TextContent # TODO: This helps preserve ordering diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index bdbe710f..9c32c196 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -160,6 +160,8 @@ def test_send_message_with_approval_tool( assert len(response.messages) == 2 assert response.messages[0].message_type == "reasoning_message" assert response.messages[1].message_type == "approval_request_message" + approval_request_id = response.messages[0].id + tool_call_id = response.messages[1].tool_call.tool_call_id # Attempt to send user message - should fail with pytest.raises(ApiError, match="Please approve or deny the pending request before continuing"): @@ -174,3 +176,22 @@ def test_send_message_with_approval_tool( agent_id=agent.id, messages=[ApprovalCreate(approve=True, approval_request_id="fake_id")], ) + + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, + approval_request_id=approval_request_id, + ), + ], + ) + + # Basic assertion that we got a response with tool call return + assert response.messages is not None + assert len(response.messages) == 3 + assert response.messages[0].message_type == "tool_return_message" + assert response.messages[0].tool_call_id == tool_call_id + assert response.messages[0].status == "success" + assert response.messages[1].message_type == "reasoning_message" + assert response.messages[2].message_type == "assistant_message"