diff --git a/alembic/versions/cce9a6174366_add_stop_reasons_to_steps_and_message_.py b/alembic/versions/cce9a6174366_add_stop_reasons_to_steps_and_message_.py new file mode 100644 index 00000000..14ac1cdd --- /dev/null +++ b/alembic/versions/cce9a6174366_add_stop_reasons_to_steps_and_message_.py @@ -0,0 +1,42 @@ +"""add stop reasons to steps and message error flag + +Revision ID: cce9a6174366 +Revises: 2c059cad97cc +Create Date: 2025-07-10 13:56:17.383612 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "cce9a6174366" +down_revision: Union[str, None] = "2c059cad97cc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("messages", sa.Column("is_err", sa.Boolean(), nullable=True)) + + # manually added to handle non-table creation enums + stopreasontype = sa.Enum( + "end_turn", "error", "invalid_tool_call", "max_steps", "no_tool_call", "tool_rule", "cancelled", name="stopreasontype" + ) + stopreasontype.create(op.get_bind()) + op.add_column("steps", sa.Column("stop_reason", stopreasontype, nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("steps", "stop_reason") + op.drop_column("messages", "is_err") + + stopreasontype = sa.Enum(name="stopreasontype") + stopreasontype.drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 0bd21f37..41543d44 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -43,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.schemas.provider_trace import ProviderTraceCreate +from letta.schemas.step import StepProgression from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -238,100 +239,164 @@ class LettaAgent(BaseAgent): agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) - request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( - await self._build_and_request_from_llm( - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, - agent_step_span, - ) - ) - in_context_messages = current_in_context_messages + new_in_context_messages - - log_event("agent.stream_no_tokens.llm_response.received") # [3^] - - response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - - # update usage - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens - MetricRegistry().message_output_tokens.record( - response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) - ) - - if not 
response.choices[0].message.tool_calls: - # TODO: make into a real error - raise ValueError("No tool calls found in response, model must make a tool call") - tool_call = response.choices[0].message.tool_calls[0] - if response.choices[0].message.reasoning_content: - reasoning = [ - ReasoningContent( - reasoning=response.choices[0].message.reasoning_content, - is_native=True, - signature=response.choices[0].message.reasoning_content_signature, + step_progression = StepProgression.START + should_continue = False + try: + request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( + await self._build_and_request_from_llm( + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + agent_step_span, ) - ] - elif response.choices[0].message.omitted_reasoning_content: - reasoning = [OmittedReasoningContent()] - elif response.choices[0].message.content: - reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons - else: - self.logger.info("No reasoning content found.") - reasoning = None + ) + in_context_messages = current_in_context_messages + new_in_context_messages - persisted_messages, should_continue, stop_reason = await self._handle_ai_response( - tool_call, - valid_tool_names, - agent_state, - tool_rules_solver, - response.usage, - reasoning_content=reasoning, - step_id=step_id, - initial_messages=initial_messages, - agent_step_span=agent_step_span, - is_final_step=(i == max_steps - 1), - ) + step_progression = StepProgression.RESPONSE_RECEIVED + log_event("agent.stream_no_tokens.llm_response.received") # [3^] - # TODO (cliandy): handle message contexts with larger refactor and dedupe logic - new_message_idx = len(initial_messages) if initial_messages else 0 - self.response_messages.extend(persisted_messages[new_message_idx:]) - new_in_context_messages.extend(persisted_messages[new_message_idx:]) - initial_messages 
= None - log_event("agent.stream_no_tokens.llm_response.processed") # [4^] + response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - # log step time - now = get_utc_timestamp_ns() - step_ns = now - step_start - agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) - agent_step_span.end() + # update usage + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + MetricRegistry().message_output_tokens.record( + response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) + ) - # Log LLM Trace - await self.telemetry_manager.create_provider_trace_async( - actor=self.actor, - provider_trace_create=ProviderTraceCreate( - request_json=request_data, - response_json=response_data, + if not response.choices[0].message.tool_calls: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + raise ValueError("No tool calls found in response, model must make a tool call") + tool_call = response.choices[0].message.tool_calls[0] + if response.choices[0].message.reasoning_content: + reasoning = [ + ReasoningContent( + reasoning=response.choices[0].message.reasoning_content, + is_native=True, + signature=response.choices[0].message.reasoning_content_signature, + ) + ] + elif response.choices[0].message.omitted_reasoning_content: + reasoning = [OmittedReasoningContent()] + elif response.choices[0].message.content: + reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons + else: + self.logger.info("No reasoning content found.") + reasoning = None + + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( + tool_call, + valid_tool_names, + agent_state, + tool_rules_solver, + response.usage, + 
reasoning_content=reasoning, step_id=step_id, - organization_id=self.actor.organization_id, - ), - ) + initial_messages=initial_messages, + agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), + ) + step_progression = StepProgression.STEP_LOGGED - # stream step - # TODO: improve TTFT - filter_user_messages = [m for m in persisted_messages if m.role != "user"] - letta_messages = Message.to_letta_messages_from_list( - filter_user_messages, use_assistant_message=use_assistant_message, reverse=False - ) + # TODO (cliandy): handle message contexts with larger refactor and dedupe logic + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) + initial_messages = None + log_event("agent.stream_no_tokens.llm_response.processed") # [4^] - for message in letta_messages: - if include_return_message_types is None or message.message_type in include_return_message_types: - yield f"data: {message.model_dump_json()}\n\n" + # log step time + now = get_utc_timestamp_ns() + step_ns = now - step_start + agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) + agent_step_span.end() - MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + # Log LLM Trace + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + step_progression = StepProgression.LOGGED_TRACE + + # stream step + # TODO: improve TTFT + filter_user_messages = [m for m in persisted_messages if m.role != "user"] + letta_messages = Message.to_letta_messages_from_list( + filter_user_messages, use_assistant_message=use_assistant_message, reverse=False + ) + + for 
message in letta_messages: + if include_return_message_types is None or message.message_type in include_return_message_types: + yield f"data: {message.model_dump_json()}\n\n" + + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + step_progression = StepProgression.FINISHED + except Exception as e: + # Handle any unexpected errors during step processing + self.logger.error(f"Error during step processing: {e}") + + # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. + if not stop_reason: + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): + self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) + elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call): + raise ValueError(f"Invalid Stop Reason: {stop_reason}") + + # Send error stop reason to client and re-raise + yield f"data: {stop_reason.model_dump_json()}\n\n", 500 + raise + + # Update step if it needs to be updated + finally: + if settings.track_stop_reason: + self.logger.info("Running final update. 
Step Progression: %s", step_progression) + try: + if step_progression < StepProgression.STEP_LOGGED: + await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), + provider_id=None, + job_id=self.current_run_id if self.current_run_id else None, + step_id=step_id, + project_id=agent_state.project_id, + stop_reason=stop_reason, + ) + if step_progression <= StepProgression.RESPONSE_RECEIVED: + # TODO (cliandy): persist response if we get it back + if settings.track_errored_messages: + for message in initial_messages: + message.is_err = True + message.step_id = step_id + await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + elif step_progression <= StepProgression.LOGGED_TRACE: + if stop_reason is None: + self.logger.error("Error in step after logging step") + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + elif step_progression == StepProgression.FINISHED and not should_continue: + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + except Exception as e: + self.logger.error("Failed to update step: %s", e) if not should_continue: break @@ -396,6 +461,16 @@ class LettaAgent(BaseAgent): stop_reason = None usage = LettaUsageStatistics() for i in range(max_steps): + # If dry run, build request data and return it without making 
LLM call + if dry_run: + request_data, valid_tool_names = await self._create_llm_request_data_async( + llm_client=llm_client, + in_context_messages=current_in_context_messages + new_in_context_messages, + agent_state=agent_state, + tool_rules_solver=tool_rules_solver, + ) + return request_data + # Check for job cancellation at the start of each step if await self._check_run_cancellation(): stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) @@ -407,94 +482,148 @@ class LettaAgent(BaseAgent): agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) - # If dry run, build request data and return it without making LLM call - if dry_run: - request_data, valid_tool_names = await self._create_llm_request_data_async( - llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, - agent_state=agent_state, - tool_rules_solver=tool_rules_solver, - ) - return request_data + step_progression = StepProgression.START + should_continue = False - request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( - await self._build_and_request_from_llm( - current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span - ) - ) - in_context_messages = current_in_context_messages + new_in_context_messages - - log_event("agent.step.llm_response.received") # [3^] - - response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens - usage.run_ids = [run_id] if run_id else None - MetricRegistry().message_output_tokens.record( - response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": 
agent_state.llm_config.model}) - ) - - if not response.choices[0].message.tool_calls: - # TODO: make into a real error - raise ValueError("No tool calls found in response, model must make a tool call") - tool_call = response.choices[0].message.tool_calls[0] - if response.choices[0].message.reasoning_content: - reasoning = [ - ReasoningContent( - reasoning=response.choices[0].message.reasoning_content, - is_native=True, - signature=response.choices[0].message.reasoning_content_signature, + try: + request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( + await self._build_and_request_from_llm( + current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span ) - ] - elif response.choices[0].message.content: - reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons - elif response.choices[0].message.omitted_reasoning_content: - reasoning = [OmittedReasoningContent()] - else: - self.logger.info("No reasoning content found.") - reasoning = None + ) + in_context_messages = current_in_context_messages + new_in_context_messages - persisted_messages, should_continue, stop_reason = await self._handle_ai_response( - tool_call, - valid_tool_names, - agent_state, - tool_rules_solver, - response.usage, - reasoning_content=reasoning, - step_id=step_id, - initial_messages=initial_messages, - agent_step_span=agent_step_span, - is_final_step=(i == max_steps - 1), - run_id=run_id, - ) - new_message_idx = len(initial_messages) if initial_messages else 0 - self.response_messages.extend(persisted_messages[new_message_idx:]) - new_in_context_messages.extend(persisted_messages[new_message_idx:]) + step_progression = StepProgression.RESPONSE_RECEIVED + log_event("agent.step.llm_response.received") # [3^] - initial_messages = None - log_event("agent.step.llm_response.processed") # [4^] + response = 
llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) - # log step time - now = get_utc_timestamp_ns() - step_ns = now - step_start - agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) - agent_step_span.end() + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + usage.run_ids = [run_id] if run_id else None + MetricRegistry().message_output_tokens.record( + response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) + ) - # Log LLM Trace - await self.telemetry_manager.create_provider_trace_async( - actor=self.actor, - provider_trace_create=ProviderTraceCreate( - request_json=request_data, - response_json=response_data, + if not response.choices[0].message.tool_calls: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + raise ValueError("No tool calls found in response, model must make a tool call") + tool_call = response.choices[0].message.tool_calls[0] + if response.choices[0].message.reasoning_content: + reasoning = [ + ReasoningContent( + reasoning=response.choices[0].message.reasoning_content, + is_native=True, + signature=response.choices[0].message.reasoning_content_signature, + ) + ] + elif response.choices[0].message.content: + reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons + elif response.choices[0].message.omitted_reasoning_content: + reasoning = [OmittedReasoningContent()] + else: + self.logger.info("No reasoning content found.") + reasoning = None + + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( + tool_call, + valid_tool_names, + agent_state, + tool_rules_solver, + response.usage, + reasoning_content=reasoning, step_id=step_id, - 
organization_id=self.actor.organization_id, - ), - ) + initial_messages=initial_messages, + agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), + run_id=run_id, + ) + step_progression = StepProgression.STEP_LOGGED - MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) + + initial_messages = None + log_event("agent.step.llm_response.processed") # [4^] + + # log step time + now = get_utc_timestamp_ns() + step_ns = now - step_start + agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) + agent_step_span.end() + + # Log LLM Trace + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + + step_progression = StepProgression.LOGGED_TRACE + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + step_progression = StepProgression.FINISHED + + except Exception as e: + # Handle any unexpected errors during step processing + self.logger.error(f"Error during step processing: {e}") + + # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. 
+ if not stop_reason: + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): + self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) + elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call): + raise ValueError(f"Invalid Stop Reason: {stop_reason}") + raise + + # Update step if it needs to be updated + finally: + if settings.track_stop_reason: + self.logger.info("Running final update. Step Progression: %s", step_progression) + try: + if step_progression < StepProgression.STEP_LOGGED: + await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), + provider_id=None, + job_id=self.current_run_id if self.current_run_id else None, + step_id=step_id, + project_id=agent_state.project_id, + stop_reason=stop_reason, + ) + if step_progression <= StepProgression.RESPONSE_RECEIVED: + # TODO (cliandy): persist response if we get it back + if settings.track_errored_messages: + for message in initial_messages: + message.is_err = True + message.step_id = step_id + await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + elif step_progression <= StepProgression.LOGGED_TRACE: + if stop_reason is None: + self.logger.error("Error in step after logging step") + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + elif 
step_progression == StepProgression.FINISHED and not should_continue: + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + except Exception as e: + self.logger.error("Failed to update step: %s", e) if not should_continue: break @@ -576,6 +705,7 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): + step_id = generate_step_id() # Check for job cancellation at the start of each step if await self._check_run_cancellation(): stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) @@ -583,163 +713,230 @@ class LettaAgent(BaseAgent): yield f"data: {stop_reason.model_dump_json()}\n\n" break - step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) - ( - request_data, - stream, - current_in_context_messages, - new_in_context_messages, - valid_tool_names, - provider_request_start_timestamp_ns, - ) = await self._build_and_request_from_llm_streaming( - first_chunk, - agent_step_span, - request_start_timestamp_ns, - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, - ) - log_event("agent.stream.llm_response.received") # [3^] - - # TODO: THIS IS INCREDIBLY UGLY - # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED - if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: - interface = AnthropicStreamingInterface( - use_assistant_message=use_assistant_message, - put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, - ) - elif 
agent_state.llm_config.model_endpoint_type == ProviderType.openai: - interface = OpenAIStreamingInterface( - use_assistant_message=use_assistant_message, - put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, - ) - else: - raise ValueError(f"Streaming not supported for {agent_state.llm_config}") - - async for chunk in interface.process( - stream, - ttft_span=request_span, - provider_request_start_timestamp_ns=provider_request_start_timestamp_ns, - ): - # Measure time to first token - if first_chunk and request_span is not None: - now = get_utc_timestamp_ns() - ttft_ns = now - request_start_timestamp_ns - request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)}) - metric_attributes = get_ctx_attributes() - metric_attributes["model.name"] = agent_state.llm_config.model - MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) - first_chunk = False - - if include_return_message_types is None or chunk.message_type in include_return_message_types: - # filter down returned data - yield f"data: {chunk.model_dump_json()}\n\n" - - stream_end_time_ns = get_utc_timestamp_ns() - - # update usage - usage.step_count += 1 - usage.completion_tokens += interface.output_tokens - usage.prompt_tokens += interface.input_tokens - usage.total_tokens += interface.input_tokens + interface.output_tokens - MetricRegistry().message_output_tokens.record( - interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) - ) - - # log LLM request time - llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns) - agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms}) - MetricRegistry().llm_execution_time_ms_histogram.record( - llm_request_ms, - dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}), - ) - - # Process resulting stream content + step_progression = StepProgression.START + 
should_continue = False try: - tool_call = interface.get_tool_call_object() - except ValueError as e: - stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) - yield f"data: {stop_reason.model_dump_json()}\n\n" - raise e - except Exception as e: - stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) - yield f"data: {stop_reason.model_dump_json()}\n\n" - raise e - reasoning_content = interface.get_reasoning_content() - persisted_messages, should_continue, stop_reason = await self._handle_ai_response( - tool_call, - valid_tool_names, - agent_state, - tool_rules_solver, - UsageStatistics( - completion_tokens=interface.output_tokens, - prompt_tokens=interface.input_tokens, - total_tokens=interface.input_tokens + interface.output_tokens, - ), - reasoning_content=reasoning_content, - pre_computed_assistant_message_id=interface.letta_message_id, - step_id=step_id, - initial_messages=initial_messages, - agent_step_span=agent_step_span, - is_final_step=(i == max_steps - 1), - ) - new_message_idx = len(initial_messages) if initial_messages else 0 - self.response_messages.extend(persisted_messages[new_message_idx:]) - new_in_context_messages.extend(persisted_messages[new_message_idx:]) + ( + request_data, + stream, + current_in_context_messages, + new_in_context_messages, + valid_tool_names, + provider_request_start_timestamp_ns, + ) = await self._build_and_request_from_llm_streaming( + first_chunk, + agent_step_span, + request_start_timestamp_ns, + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + ) - initial_messages = None + step_progression = StepProgression.STREAM_RECEIVED + log_event("agent.stream.llm_response.received") # [3^] - # log total step time - now = get_utc_timestamp_ns() - step_ns = now - step_start - agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) - agent_step_span.end() + # TODO: THIS IS INCREDIBLY UGLY + # 
TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED + if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: + interface = AnthropicStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) + elif agent_state.llm_config.model_endpoint_type == ProviderType.openai: + interface = OpenAIStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) + else: + raise ValueError(f"Streaming not supported for {agent_state.llm_config}") - # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream - # log_event("agent.stream.llm_response.processed") # [4^] + async for chunk in interface.process( + stream, + ttft_span=request_span, + provider_request_start_timestamp_ns=provider_request_start_timestamp_ns, + ): + # Measure time to first token + if first_chunk and request_span is not None: + now = get_utc_timestamp_ns() + ttft_ns = now - request_start_timestamp_ns + request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)}) + metric_attributes = get_ctx_attributes() + metric_attributes["model.name"] = agent_state.llm_config.model + MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes) + first_chunk = False - # Log LLM Trace - # TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema. 
- await self.telemetry_manager.create_provider_trace_async( - actor=self.actor, - provider_trace_create=ProviderTraceCreate( - request_json=request_data, - response_json={ - "content": { - "tool_call": tool_call.model_dump_json(), - "reasoning": [content.model_dump_json() for content in reasoning_content], - }, - "id": interface.message_id, - "model": interface.model, - "role": "assistant", - # "stop_reason": "", - # "stop_sequence": None, - "type": "message", - "usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens}, - }, + if include_return_message_types is None or chunk.message_type in include_return_message_types: + # filter down returned data + yield f"data: {chunk.model_dump_json()}\n\n" + + stream_end_time_ns = get_utc_timestamp_ns() + + # update usage + usage.step_count += 1 + usage.completion_tokens += interface.output_tokens + usage.prompt_tokens += interface.input_tokens + usage.total_tokens += interface.input_tokens + interface.output_tokens + MetricRegistry().message_output_tokens.record( + interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}) + ) + + # log LLM request time + llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns) + agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms}) + MetricRegistry().llm_execution_time_ms_histogram.record( + llm_request_ms, + dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}), + ) + + # Process resulting stream content + try: + tool_call = interface.get_tool_call_object() + except ValueError as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) + raise e + except Exception as e: + stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) + raise e + reasoning_content = interface.get_reasoning_content() + persisted_messages, should_continue, stop_reason = await self._handle_ai_response( + 
tool_call, + valid_tool_names, + agent_state, + tool_rules_solver, + UsageStatistics( + completion_tokens=interface.output_tokens, + prompt_tokens=interface.input_tokens, + total_tokens=interface.input_tokens + interface.output_tokens, + ), + reasoning_content=reasoning_content, + pre_computed_assistant_message_id=interface.letta_message_id, step_id=step_id, - organization_id=self.actor.organization_id, - ), - ) + initial_messages=initial_messages, + agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), + ) + step_progression = StepProgression.STEP_LOGGED - tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] - if not (use_assistant_message and tool_return.name == "send_message"): - # Apply message type filtering if specified - if include_return_message_types is None or tool_return.message_type in include_return_message_types: - yield f"data: {tool_return.model_dump_json()}\n\n" + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) - # TODO (cliandy): consolidate and expand with trace - MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + initial_messages = None + + # log total step time + now = get_utc_timestamp_ns() + step_ns = now - step_start + agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) + agent_step_span.end() + + # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream + # log_event("agent.stream.llm_response.processed") # [4^] + + # Log LLM Trace + # We are piecing together the streamed response here. + # Content here does not match the actual response schema as streams come in chunks. 
+ await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json={ + "content": { + "tool_call": tool_call.model_dump_json(), + "reasoning": [content.model_dump_json() for content in reasoning_content], + }, + "id": interface.message_id, + "model": interface.model, + "role": "assistant", + # "stop_reason": "", + # "stop_sequence": None, + "type": "message", + "usage": { + "input_tokens": interface.input_tokens, + "output_tokens": interface.output_tokens, + }, + }, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + step_progression = StepProgression.LOGGED_TRACE + + # yields tool response as this is handled from Letta and not the response from the LLM provider + tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] + if not (use_assistant_message and tool_return.name == "send_message"): + # Apply message type filtering if specified + if include_return_message_types is None or tool_return.message_type in include_return_message_types: + yield f"data: {tool_return.model_dump_json()}\n\n" + + # TODO (cliandy): consolidate and expand with trace + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) + step_progression = StepProgression.FINISHED + + except Exception as e: + # Handle any unexpected errors during step processing + self.logger.error(f"Error during step processing: {e}") + + # This indicates we failed after we decided to stop stepping, which indicates a bug with our flow. 
+ if not stop_reason: + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule): + self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason) + elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call): + raise ValueError(f"Invalid Stop Reason: {stop_reason}") + + # Send error stop reason to client and re-raise with expected response code + yield f"data: {stop_reason.model_dump_json()}\n\n", 500 + raise + + # Update step if it needs to be updated + finally: + if settings.track_stop_reason: + self.logger.info("Running final update. Step Progression: %s", step_progression) + try: + if step_progression < StepProgression.STEP_LOGGED: + await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0), + provider_id=None, + job_id=self.current_run_id if self.current_run_id else None, + step_id=step_id, + project_id=agent_state.project_id, + stop_reason=stop_reason, + ) + if step_progression <= StepProgression.STREAM_RECEIVED: + if first_chunk and settings.track_errored_messages: + for message in initial_messages: + message.is_err = True + message.step_id = step_id + await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + elif step_progression <= StepProgression.LOGGED_TRACE: + if stop_reason is None: + self.logger.error("Error in step after logging step") + stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + await 
self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + elif step_progression == StepProgression.FINISHED and not should_continue: + if stop_reason is None: + stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason) + else: + self.logger.error("Invalid StepProgression value") + except Exception as e: + self.logger.error("Failed to update step: %s", e) if not should_continue: break - # Extend the in context message ids if not agent_state.message_buffer_autoclear: await self._rebuild_context_window( @@ -1106,6 +1303,7 @@ class LettaAgent(BaseAgent): job_id=run_id if run_id else self.current_run_id, step_id=step_id, project_id=agent_state.project_id, + stop_reason=stop_reason, ) tool_call_messages = create_letta_messages_from_llm_response( diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index e9a2dd71..acf7c2dd 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Set, Union +from typing import List, Optional, Union from pydantic import BaseModel, Field @@ -107,25 +107,20 @@ class ToolRulesSolver(BaseModel): self.tool_call_history.clear() def get_allowed_tool_names( - self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None + self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None ) -> List[str]: - """Get a list of tool names allowed based on the last tool called.""" + """Get a list of tool names allowed based on the last tool called. + + The logic is as follows: + 1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call + 2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options + 3. 
Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools + """ # TODO: This piece of code here is quite ugly and deserves a refactor - # TODO: There's some weird logic encoded here: - # TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules - # TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based # TODO: -> Tool rules should probably be refactored to take in a set of tool names? - # If no tool has been called yet, return InitToolRules additively - if not self.tool_call_history: - if self.init_tool_rules: - # If there are init tool rules, only return those defined in the init tool rules - return [rule.tool_name for rule in self.init_tool_rules] - else: - # Otherwise, return all tools besides those constrained by parent tool rules - available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules)) - return list(available_tools) + if not self.tool_call_history and self.init_tool_rules: + return [rule.tool_name for rule in self.init_tool_rules] else: - # Collect valid tools from all child-based rules valid_tool_sets = [] for rule in self.child_based_tool_rules + self.parent_tool_rules: tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) @@ -151,11 +146,11 @@ class ToolRulesSolver(BaseModel): """Check if the tool is defined as a continue tool in the tool rules.""" return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) - def has_required_tools_been_called(self, available_tools: Set[str]) -> bool: + def has_required_tools_been_called(self, available_tools: set[str]) -> bool: """Check if all required-before-exit tools have been called.""" return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0 - def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]: + def 
get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]: """Get the list of required-before-exit tools that have not been called yet.""" if not self.required_before_exit_tool_rules: return [] # No required tools means no uncalled tools diff --git a/letta/orm/message.py b/letta/orm/message.py index ad4ed743..ba3acd82 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): nullable=True, doc="The id of the LLMBatchItem that this message is associated with", ) + is_err: Mapped[Optional[bool]] = mapped_column( + nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes." + ) # Monotonically increasing sequence for efficient/correct listing sequence_id: Mapped[int] = mapped_column( diff --git a/letta/orm/step.py b/letta/orm/step.py index 85d2afae..05da631c 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.step import Step as PydanticStep if TYPE_CHECKING: @@ -45,6 +46,7 @@ class Step(SqlalchemyBase): prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt") total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent") completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.") + stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.") tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.") tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.") trace_id: Mapped[Optional[str]] = 
mapped_column(None, nullable=True, doc="The trace id of the agent step.") diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index cfae6b38..d1e9b2a7 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -40,15 +40,18 @@ class LettaMessage(BaseModel): message_type (MessageType): The type of the message otid (Optional[str]): The offline threading id associated with this message sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id + step_id (Optional[str]): The step id associated with the message + is_err (Optional[bool]): Whether the message is an errored message or not. Used for debugging purposes only. """ id: str date: datetime - name: Optional[str] = None + name: str | None = None message_type: MessageType = Field(..., description="The type of the message.") - otid: Optional[str] = None - sender_id: Optional[str] = None - step_id: Optional[str] = None + otid: str | None = None + sender_id: str | None = None + step_id: str | None = None + is_err: bool | None = None @field_serializer("date") def serialize_datetime(self, dt: datetime, _info): @@ -60,6 +63,14 @@ class LettaMessage(BaseModel): dt = dt.replace(tzinfo=timezone.utc) return dt.isoformat(timespec="seconds") + @field_serializer("is_err", when_used="unless-none") + def serialize_is_err(self, value: bool | None, _info): + """ + Only serialize is_err field when it's True (for debugging purposes). + When is_err is None or False, this field will be excluded from the JSON output. 
+ """ + return value if value is True else None + class SystemMessage(LettaMessage): """ diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 59390b41..417540c8 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -172,6 +172,9 @@ class Message(BaseMessage): group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in") sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id") batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with") + is_err: Optional[bool] = Field( + default=None, description="Whether this message is part of an error step. Used only for debugging purposes." + ) # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") @@ -191,6 +194,7 @@ class Message(BaseMessage): if not is_utc_datetime(self.created_at): self.created_at = self.created_at.replace(tzinfo=timezone.utc) json_message["created_at"] = self.created_at.isoformat() + json_message.pop("is_err", None) # make sure we don't include this debugging information return json_message @staticmethod @@ -204,6 +208,7 @@ class Message(BaseMessage): assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, reverse: bool = True, + include_err: Optional[bool] = None, ) -> List[LettaMessage]: if use_assistant_message: message_ids_to_remove = [] @@ -234,6 +239,7 @@ class Message(BaseMessage): assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, reverse=reverse, + include_err=include_err, ) ] @@ -243,6 +249,7 @@ class Message(BaseMessage): assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, 
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, reverse: bool = True, + include_err: Optional[bool] = None, ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" messages = [] @@ -682,14 +689,13 @@ class Message(BaseMessage): # since the only "parts" we have are for supporting various COT if self.role == "system": - assert all([v is not None for v in [self.role]]), vars(self) openai_message = { "content": text_content, "role": "developer" if use_developer_message else self.role, } elif self.role == "user": - assert all([v is not None for v in [text_content, self.role]]), vars(self) + assert text_content is not None, vars(self) openai_message = { "content": text_content, "role": self.role, @@ -720,7 +726,7 @@ class Message(BaseMessage): tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length] elif self.role == "tool": - assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) + assert self.tool_call_id is not None, vars(self) openai_message = { "content": text_content, "role": self.role, @@ -776,7 +782,7 @@ class Message(BaseMessage): if self.role == "system": # NOTE: this is not for system instructions, but instead system "events" - assert all([v is not None for v in [text_content, self.role]]), vars(self) + assert text_content is not None, vars(self) # Two options here, we would use system.package_system_message, # or use a more Anthropic-specific packaging ie xml tags user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event") @@ -875,7 +881,7 @@ class Message(BaseMessage): elif self.role == "tool": # NOTE: Anthropic uses role "user" for "tool" responses - assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) + assert self.tool_call_id is not None, vars(self) anthropic_message = { "role": "user", # NOTE: diff "content": [ @@ -988,7 +994,7 @@ class Message(BaseMessage): elif self.role == "tool": # NOTE: 
Significantly different tool calling format, more similar to function calling format - assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) + assert self.tool_call_id is not None, vars(self) if self.name is None: warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.") diff --git a/letta/schemas/step.py b/letta/schemas/step.py index 37153a56..bc5cd204 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -1,8 +1,10 @@ +from enum import Enum, auto from typing import Dict, List, Literal, Optional from pydantic import Field from letta.schemas.letta_base import LettaBase +from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.message import Message @@ -28,6 +30,7 @@ class Step(StepBase): prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.") total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.") completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.") + stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.") tags: List[str] = Field([], description="Metadata tags.") tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.") trace_id: Optional[str] = Field(None, description="The trace id of the agent step.") @@ -36,3 +39,12 @@ class Step(StepBase): None, description="The feedback for this step. Must be either 'positive' or 'negative'." 
) project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).") + + +class StepProgression(int, Enum): + START = auto() + STREAM_RECEIVED = auto() + RESPONSE_RECEIVED = auto() + STEP_LOGGED = auto() + LOGGED_TRACE = auto() + FINISHED = auto() diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 68b21934..759a0c46 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -2,8 +2,10 @@ import importlib.util import json import logging import os +import platform import sys from contextlib import asynccontextmanager +from functools import partial from pathlib import Path from typing import Optional @@ -34,32 +36,25 @@ from letta.server.db import db_registry from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right? from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router - -# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes from letta.server.rest_api.routers.v1.organizations import router as organizations_router from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin from letta.server.rest_api.static_files import mount_static_files +from letta.server.rest_api.utils import SENTRY_ENABLED from letta.server.server import SyncServer from letta.settings import settings -# TODO(ethan) +if SENTRY_ENABLED: + import sentry_sdk + +IS_WINDOWS = platform.system() == "Windows" + # NOTE(charles): @ethan I had to add this to get the global as the bottom to work -interface: StreamingServerInterface = StreamingServerInterface +interface: type = StreamingServerInterface server = SyncServer(default_interface_factory=lambda: interface()) logger = 
get_logger(__name__) -import logging -import platform - -from fastapi import FastAPI - -is_windows = platform.system() == "Windows" - -log = logging.getLogger("uvicorn") - - def generate_openapi_schema(app: FastAPI): # Update the OpenAPI schema if not app.openapi_schema: @@ -177,9 +172,7 @@ def create_application() -> "FastAPI": # server = SyncServer(default_interface_factory=lambda: interface()) print(f"\n[[ Letta server // v{letta_version} ]]") - if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): - import sentry_sdk - + if SENTRY_ENABLED: sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), traces_sample_rate=1.0, @@ -187,6 +180,7 @@ def create_application() -> "FastAPI": "continuous_profiling_auto_start": True, }, ) + logger.info("Sentry enabled.") debug_mode = "--debug" in sys.argv app = FastAPI( @@ -199,31 +193,13 @@ def create_application() -> "FastAPI": lifespan=lifespan, ) - @app.exception_handler(IncompatibleAgentType) - async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType): - return JSONResponse( - status_code=400, - content={ - "detail": str(exc), - "expected_type": exc.expected_type, - "actual_type": exc.actual_type, - }, - ) + # === Exception Handlers === + # TODO (cliandy): move to separate file @app.exception_handler(Exception) async def generic_error_handler(request: Request, exc: Exception): - # Log the actual error for debugging - log.error(f"Unhandled error: {str(exc)}", exc_info=True) - print(f"Unhandled error: {str(exc)}") - - import traceback - - # Print the stack trace - print(f"Stack trace: {traceback.format_exc()}") - - if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): - import sentry_sdk - + logger.error(f"Unhandled error: {str(exc)}", exc_info=True) + if SENTRY_ENABLED: sentry_sdk.capture_exception(exc) return JSONResponse( @@ -235,62 +211,70 @@ def create_application() -> "FastAPI": }, ) - @app.exception_handler(NoResultFound) - async def 
no_result_found_handler(request: Request, exc: NoResultFound): - logger.error(f"NoResultFound: {exc}") + async def error_handler_with_code(request: Request, exc: Exception, code: int, detail: str | None = None): + logger.error(f"{type(exc).__name__}", exc_info=exc) + if SENTRY_ENABLED: + sentry_sdk.capture_exception(exc) + if not detail: + detail = str(exc) return JSONResponse( - status_code=404, - content={"detail": str(exc)}, + status_code=code, + content={"detail": detail}, ) - @app.exception_handler(ForeignKeyConstraintViolationError) - async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError): - logger.error(f"ForeignKeyConstraintViolationError: {exc}") + _error_handler_400 = partial(error_handler_with_code, code=400) + _error_handler_404 = partial(error_handler_with_code, code=404) + _error_handler_404_agent = partial(_error_handler_404, detail="Agent not found") + _error_handler_404_user = partial(_error_handler_404, detail="User not found") + _error_handler_409 = partial(error_handler_with_code, code=409) + + app.add_exception_handler(ValueError, _error_handler_400) + app.add_exception_handler(NoResultFound, _error_handler_404) + app.add_exception_handler(LettaAgentNotFoundError, _error_handler_404_agent) + app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user) + app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409) + app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409) + + @app.exception_handler(IncompatibleAgentType) + async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType): + logger.error("Incompatible agent types. 
Expected: %s, Actual: %s", exc.expected_type, exc.actual_type) + if SENTRY_ENABLED: + sentry_sdk.capture_exception(exc) return JSONResponse( - status_code=409, - content={"detail": str(exc)}, - ) - - @app.exception_handler(UniqueConstraintViolationError) - async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError): - logger.error(f"UniqueConstraintViolationError: {exc}") - - return JSONResponse( - status_code=409, - content={"detail": str(exc)}, + status_code=400, + content={ + "detail": str(exc), + "expected_type": exc.expected_type, + "actual_type": exc.actual_type, + }, ) @app.exception_handler(DatabaseTimeoutError) async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError): logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}") + if SENTRY_ENABLED: + sentry_sdk.capture_exception(exc) + return JSONResponse( status_code=503, content={"detail": "The database is temporarily unavailable. Please try again later."}, ) - @app.exception_handler(ValueError) - async def value_error_handler(request: Request, exc: ValueError): - return JSONResponse(status_code=400, content={"detail": str(exc)}) - - @app.exception_handler(LettaAgentNotFoundError) - async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError): - return JSONResponse(status_code=404, content={"detail": "Agent not found"}) - - @app.exception_handler(LettaUserNotFoundError) - async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError): - return JSONResponse(status_code=404, content={"detail": "User not found"}) - @app.exception_handler(BedrockPermissionError) async def bedrock_permission_error_handler(request, exc: BedrockPermissionError): + logger.error(f"Bedrock permission denied.") + if SENTRY_ENABLED: + sentry_sdk.capture_exception(exc) + return JSONResponse( status_code=403, content={ "error": { "type": "bedrock_permission_denied", "message": "Unable to access the required 
AI model. Please check your Bedrock permissions or contact support.",
-                    "details": {"model_arn": exc.model_arn, "reason": str(exc)},
+                    "detail": str(exc),
                 }
             },
         )
@@ -301,6 +285,9 @@ def create_application() -> "FastAPI":
         print(f"▶ Using secure mode with password: {random_password}")
         app.add_middleware(CheckPasswordMiddleware)

+    # Add reverse proxy middleware to handle X-Forwarded-* headers
+    # app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=settings.cors_origins,
@@ -442,7 +429,7 @@ def start_server(
         )

     else:
-        if is_windows:
+        if IS_WINDOWS:
             # Windows doesn't those the fancy unicode characters
             print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
             print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 78355f66..07fbcee5 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -636,6 +636,9 @@ async def list_messages(
     use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
     assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
     assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
+    include_err: bool | None = Query(
+        None, description="Whether to include error messages and error statuses. For debugging purposes only."
+    ),
     actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
@@ -654,6 +657,7 @@ async def list_messages(
         use_assistant_message=use_assistant_message,
         assistant_message_tool_name=assistant_message_tool_name,
         assistant_message_tool_kwarg=assistant_message_tool_kwarg,
+        include_err=include_err,
         actor=actor,
     )

@@ -1156,7 +1160,7 @@ async def list_agent_groups(
 ):
     """Lists the groups for an agent"""
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
-    print("in list agents with manager_type", manager_type)
+    logger.info("in list agents with manager_type: %s", manager_type)
     return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)

diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py
index ac22c469..4daf65a7 100644
--- a/letta/server/rest_api/streaming_response.py
+++ b/letta/server/rest_api/streaming_response.py
@@ -12,6 +12,7 @@ from starlette.types import Send
 from letta.log import get_logger
 from letta.schemas.enums import JobStatus
 from letta.schemas.user import User
+from letta.server.rest_api.utils import capture_sentry_exception
 from letta.services.job_manager import JobManager

 logger = get_logger(__name__)
@@ -92,6 +93,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
         more_body = True
         try:
             first_chunk = await self.body_iterator.__anext__()
+            logger.debug("stream_response first chunk: %s", first_chunk)
             if isinstance(first_chunk, tuple):
                 first_chunk_content, self.status_code = first_chunk
             else:
@@ -130,7 +132,7 @@
                             "more_body": more_body,
                         }
                     )
-                    return
+                    raise Exception(f"An exception occurred mid-stream with status code {status_code}: {content}")
                 else:
                     content = chunk
@@ -146,8 +148,8 @@
             )

         # This should be handled properly upstream?
-        except asyncio.CancelledError:
-            logger.info("Stream was cancelled by client or job cancellation")
+        except asyncio.CancelledError as exc:
+            logger.warning("Stream was cancelled by client or job cancellation")
             # Handle cancellation gracefully
             more_body = False
             cancellation_resp = {"error": {"message": "Stream cancelled"}}
@@ -160,6 +162,7 @@
                         "headers": self.raw_headers,
                     }
                 )
+                raise
             await send(
                 {
                     "type": "http.response.body",
@@ -167,13 +170,15 @@
                     "more_body": more_body,
                 }
             )
+            capture_sentry_exception(exc)
             return

-        except Exception:
-            logger.exception("unhandled_streaming_error")
+        except Exception as exc:
+            logger.exception("Unhandled Streaming Error")
             more_body = False
             error_resp = {"error": {"message": "Internal Server Error"}}
             error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
+            logger.debug("response_started: %s", self.response_started)
             if not self.response_started:
                 await send(
                     {
@@ -182,6 +187,7 @@
                         "headers": self.raw_headers,
                     }
                 )
+                raise
             await send(
                 {
                     "type": "http.response.body",
@@ -189,5 +195,7 @@
                     "more_body": more_body,
                 }
             )
+            capture_sentry_exception(exc)
+            return
         if more_body:
             await send({"type": "http.response.body", "body": b"", "more_body": False})
diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index 91f75012..1d47647e 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -2,7 +2,6 @@ import asyncio
 import json
 import os
 import uuid
-import warnings
 from enum import Enum
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
@@ -34,12 +33,15 @@
 from letta.schemas.message import Message, MessageCreate, ToolReturn
 from letta.schemas.tool_execution_result import ToolExecutionResult
 from
letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User -from letta.server.rest_api.interface import StreamingServerInterface from letta.system import get_heartbeat, package_function_response if TYPE_CHECKING: from letta.server.server import SyncServer +SENTRY_ENABLED = bool(os.getenv("SENTRY_DSN")) + +if SENTRY_ENABLED: + import sentry_sdk SSE_PREFIX = "data: " SSE_SUFFIX = "\n\n" @@ -157,21 +159,9 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio return user_id -def get_current_interface() -> StreamingServerInterface: - return StreamingServerInterface - - -def log_error_to_sentry(e): - import traceback - - traceback.print_exc() - warnings.warn(f"SSE stream generator failed: {e}") - - # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response - # Print the stack trace - if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): - import sentry_sdk - +def capture_sentry_exception(e: BaseException): + """This will capture the exception in sentry, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response""" + if SENTRY_ENABLED: sentry_sdk.capture_exception(e) diff --git a/letta/server/server.py b/letta/server/server.py index af996ffc..1afeccaf 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -105,6 +105,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import model_settings, settings, tool_settings +from letta.streaming_interface import AgentChunkStreamingInterface from letta.utils import get_friendly_error_msg, get_persona_text, make_key config = LettaConfig.load() @@ -176,7 +177,7 @@ class SyncServer(Server): self, chaining: bool = True, max_chaining_steps: Optional[int] = 100, - default_interface_factory: 
Callable[[], AgentInterface] = lambda: CLIInterface(), + default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(), init_with_default_org_and_user: bool = True, # default_interface: AgentInterface = CLIInterface(), # default_persistence_manager_cls: PersistenceManager = LocalStateManager, @@ -1244,6 +1245,7 @@ class SyncServer(Server): use_assistant_message: bool = True, assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, + include_err: Optional[bool] = None, ) -> Union[List[Message], List[LettaMessage]]: records = await self.message_manager.list_messages_for_agent_async( agent_id=agent_id, @@ -1253,6 +1255,7 @@ class SyncServer(Server): limit=limit, ascending=not reverse, group_id=group_id, + include_err=include_err, ) if not return_message_object: @@ -1262,6 +1265,7 @@ class SyncServer(Server): assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, reverse=reverse, + include_err=include_err, ) if reverse: diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index af4ce2c6..6770157f 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -520,6 +520,7 @@ class MessageManager: limit: Optional[int] = 50, ascending: bool = True, group_id: Optional[str] = None, + include_err: Optional[bool] = None, ) -> List[PydanticMessage]: """ Most performant query to list messages for an agent by directly querying the Message table. @@ -539,6 +540,7 @@ class MessageManager: limit: Maximum number of messages to return. ascending: If True, sort by sequence_id ascending; if False, sort descending. group_id: Optional group ID to filter messages by group_id. + include_err: Optional boolean to include errors and error statuses. Used for debugging only. Returns: List[PydanticMessage]: A list of messages (converted via .to_pydantic()). 
@@ -558,6 +560,9 @@ class MessageManager: if group_id: query = query.where(MessageModel.group_id == group_id) + if not include_err: + query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None))) + # If query_text is provided, filter messages using database-specific JSON search. if query_text: if settings.letta_pg_uri_no_default: diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index c16af2fe..c1034ebb 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -12,6 +12,7 @@ from letta.orm.job import Job as JobModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step as StepModel from letta.otel.tracing import get_trace_id, trace_method +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.step import Step as PydanticStep from letta.schemas.user import User as PydanticUser @@ -131,6 +132,7 @@ class StepManager: job_id: Optional[str] = None, step_id: Optional[str] = None, project_id: Optional[str] = None, + stop_reason: Optional[LettaStopReason] = None, ) -> PydanticStep: step_data = { "origin": None, @@ -153,6 +155,8 @@ class StepManager: } if step_id: step_data["id"] = step_id + if stop_reason: + step_data["stop_reason"] = stop_reason.stop_reason async with db_registry.async_session() as session: if job_id: await self._verify_job_access_async(session, job_id, actor, access=["write"]) @@ -207,6 +211,33 @@ class StepManager: await session.commit() return step.to_pydantic() + @enforce_types + @trace_method + async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep: + """Update the stop reason for a step. 
+ + Args: + actor: The user making the request + step_id: The ID of the step to update + stop_reason: The stop reason to set + + Returns: + The updated step + + Raises: + NoResultFound: If the step does not exist + """ + async with db_registry.async_session() as session: + step = await session.get(StepModel, step_id) + if not step: + raise NoResultFound(f"Step with id {step_id} does not exist") + if step.organization_id != actor.organization_id: + raise Exception("Unauthorized") + + step.stop_reason = stop_reason + await session.commit() + return step + def _verify_job_access( self, session: Session, @@ -309,5 +340,6 @@ class NoopStepManager(StepManager): job_id: Optional[str] = None, step_id: Optional[str] = None, project_id: Optional[str] = None, + stop_reason: Optional[LettaStopReason] = None, ) -> PydanticStep: return diff --git a/letta/settings.py b/letta/settings.py index f5b749f5..74dce759 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -220,13 +220,15 @@ class Settings(BaseSettings): multi_agent_concurrent_sends: int = 50 # telemetry logging - otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317" - otel_preferred_temporality: Optional[int] = Field( + otel_exporter_otlp_endpoint: str | None = None # otel default: "http://localhost:4317" + otel_preferred_temporality: int | None = Field( default=1, ge=0, le=2, description="Exported metric temporality. 
{0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}" ) disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing") llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step") track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics") + track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages") + track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.") # uvicorn settings uvicorn_workers: int = 1 diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 9f6a250e..4493c6b5 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -752,14 +752,15 @@ def test_step_stream_agent_loop_error( """ last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1) agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config) + response = client.agents.messages.create_stream( + agent_id=agent_state_no_tools.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) with pytest.raises(Exception) as exc_info: - response = client.agents.messages.create_stream( - agent_id=agent_state_no_tools.id, - messages=USER_MESSAGE_FORCE_REPLY, - ) - list(response) + for chunk in response: + print(chunk) + print("error info:", exc_info) assert type(exc_info.value) in (ApiError, ValueError) - print(exc_info.value) messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id) assert len(messages_from_db) == 0 diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index d81b2011..a228e250 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -138,7 +138,9 @@ def test_max_count_per_step_tool_rule(): assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 
'start_tool'" solver.register_tool_call(START_TOOL) - assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed" + assert ( + solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [] + ), "After reaching max count, 'start_tool' should no longer be allowed" def test_max_count_per_step_tool_rule_allows_usage_up_to_limit(): @@ -155,7 +157,7 @@ def test_max_count_per_step_tool_rule_allows_usage_up_to_limit(): assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses" solver.register_tool_call(START_TOOL) - assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses" + assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should no longer allow 'start_tool' after 3 uses" def test_max_count_per_step_tool_rule_does_not_affect_other_tools(): @@ -180,7 +182,7 @@ def test_max_count_per_step_tool_rule_resets_on_clear(): solver.register_tool_call(START_TOOL) solver.register_tool_call(START_TOOL) - assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit" + assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should not allow 'start_tool' after reaching limit" solver.clear_tool_history()