feat: stop reasons and error messages and sentry fixes

This commit is contained in:
Andy Li
2025-07-18 11:56:20 -07:00
committed by GitHub
parent 904d9ba5a2
commit b7b678db4e
18 changed files with 746 additions and 442 deletions

View File

@@ -0,0 +1,42 @@
"""add stop reasons to steps and message error flag
Revision ID: cce9a6174366
Revises: 2c059cad97cc
Create Date: 2025-07-10 13:56:17.383612
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "cce9a6174366"
down_revision: Union[str, None] = "2c059cad97cc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the schema changes: add `messages.is_err` and `steps.stop_reason`.

    The `stop_reason` column uses a new `stopreasontype` enum, which must be
    created explicitly before the column that references it.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("messages", sa.Column("is_err", sa.Boolean(), nullable=True))

    # manually added to handle non-table creation enums
    stop_reason_values = (
        "end_turn",
        "error",
        "invalid_tool_call",
        "max_steps",
        "no_tool_call",
        "tool_rule",
        "cancelled",
    )
    stop_reason_enum = sa.Enum(*stop_reason_values, name="stopreasontype")
    # Create the enum type in the database before adding the column that uses it.
    stop_reason_enum.create(op.get_bind())
    op.add_column("steps", sa.Column("stop_reason", stop_reason_enum, nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert the schema changes: drop `steps.stop_reason`, `messages.is_err`,
    and the `stopreasontype` enum type.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("steps", "stop_reason")
    op.drop_column("messages", "is_err")
    # Dropping the column does not remove the enum type itself; drop it
    # explicitly by name (no values needed for a drop).
    sa.Enum(name="stopreasontype").drop(op.get_bind())
    # ### end Alembic commands ###

View File

@@ -43,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import StepProgression
from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User from letta.schemas.user import User
@@ -238,100 +239,164 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id}) agent_step_span.set_attributes({"step_id": step_id})
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( step_progression = StepProgression.START
await self._build_and_request_from_llm( should_continue = False
current_in_context_messages, try:
new_in_context_messages, request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
agent_state, await self._build_and_request_from_llm(
llm_client, current_in_context_messages,
tool_rules_solver, new_in_context_messages,
agent_step_span, agent_state,
) llm_client,
) tool_rules_solver,
in_context_messages = current_in_context_messages + new_in_context_messages agent_step_span,
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
) )
] )
elif response.choices[0].message.omitted_reasoning_content: in_context_messages = current_in_context_messages + new_in_context_messages
reasoning = [OmittedReasoningContent()]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
else:
self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response( step_progression = StepProgression.RESPONSE_RECEIVED
tool_call, log_event("agent.stream_no_tokens.llm_response.received") # [3^]
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
# log step time # update usage
now = get_utc_timestamp_ns() usage.step_count += 1
step_ns = now - step_start usage.completion_tokens += response.usage.completion_tokens
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) usage.prompt_tokens += response.usage.prompt_tokens
agent_step_span.end() usage.total_tokens += response.usage.total_tokens
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# Log LLM Trace if not response.choices[0].message.tool_calls:
await self.telemetry_manager.create_provider_trace_async( stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
actor=self.actor, raise ValueError("No tool calls found in response, model must make a tool call")
provider_trace_create=ProviderTraceCreate( tool_call = response.choices[0].message.tool_calls[0]
request_json=request_data, if response.choices[0].message.reasoning_content:
response_json=response_data, reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
)
]
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
else:
self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id, step_id=step_id,
organization_id=self.actor.organization_id, initial_messages=initial_messages,
), agent_step_span=agent_step_span,
) is_final_step=(i == max_steps - 1),
)
step_progression = StepProgression.STEP_LOGGED
# stream step # TODO (cliandy): handle message contexts with larger refactor and dedupe logic
# TODO: improve TTFT new_message_idx = len(initial_messages) if initial_messages else 0
filter_user_messages = [m for m in persisted_messages if m.role != "user"] self.response_messages.extend(persisted_messages[new_message_idx:])
letta_messages = Message.to_letta_messages_from_list( new_in_context_messages.extend(persisted_messages[new_message_idx:])
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False initial_messages = None
) log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
for message in letta_messages: # log step time
if include_return_message_types is None or message.message_type in include_return_message_types: now = get_utc_timestamp_ns()
yield f"data: {message.model_dump_json()}\n\n" step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) # Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# stream step
# TODO: improve TTFT
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
letta_messages = Message.to_letta_messages_from_list(
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
for message in letta_messages:
if include_return_message_types is None or message.message_type in include_return_message_types:
yield f"data: {message.model_dump_json()}\n\n"
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
# Send error stop reason to client and re-raise
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
raise
# Update step if it needs to be updated
finally:
if settings.track_stop_reason:
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
provider_id=None,
job_id=self.current_run_id if self.current_run_id else None,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
if step_progression <= StepProgression.RESPONSE_RECEIVED:
# TODO (cliandy): persist response if we get it back
if settings.track_errored_messages:
for message in initial_messages:
message.is_err = True
message.step_id = step_id
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue: if not should_continue:
break break
@@ -396,6 +461,16 @@ class LettaAgent(BaseAgent):
stop_reason = None stop_reason = None
usage = LettaUsageStatistics() usage = LettaUsageStatistics()
for i in range(max_steps): for i in range(max_steps):
# If dry run, build request data and return it without making LLM call
if dry_run:
request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
)
return request_data
# Check for job cancellation at the start of each step # Check for job cancellation at the start of each step
if await self._check_run_cancellation(): if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
@@ -407,94 +482,148 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id}) agent_step_span.set_attributes({"step_id": step_id})
# If dry run, build request data and return it without making LLM call step_progression = StepProgression.START
if dry_run: should_continue = False
request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
)
return request_data
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( try:
await self._build_and_request_from_llm( request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span await self._build_and_request_from_llm(
) current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
)
in_context_messages = current_in_context_messages + new_in_context_messages
log_event("agent.step.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
usage.run_ids = [run_id] if run_id else None
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
) )
] )
elif response.choices[0].message.content: in_context_messages = current_in_context_messages + new_in_context_messages
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
else:
self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response( step_progression = StepProgression.RESPONSE_RECEIVED
tool_call, log_event("agent.step.llm_response.received") # [3^]
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
run_id=run_id,
)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
log_event("agent.step.llm_response.processed") # [4^]
# log step time usage.step_count += 1
now = get_utc_timestamp_ns() usage.completion_tokens += response.usage.completion_tokens
step_ns = now - step_start usage.prompt_tokens += response.usage.prompt_tokens
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) usage.total_tokens += response.usage.total_tokens
agent_step_span.end() usage.run_ids = [run_id] if run_id else None
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# Log LLM Trace if not response.choices[0].message.tool_calls:
await self.telemetry_manager.create_provider_trace_async( stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
actor=self.actor, raise ValueError("No tool calls found in response, model must make a tool call")
provider_trace_create=ProviderTraceCreate( tool_call = response.choices[0].message.tool_calls[0]
request_json=request_data, if response.choices[0].message.reasoning_content:
response_json=response_data, reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
)
]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
else:
self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id, step_id=step_id,
organization_id=self.actor.organization_id, initial_messages=initial_messages,
), agent_step_span=agent_step_span,
) is_final_step=(i == max_steps - 1),
run_id=run_id,
)
step_progression = StepProgression.STEP_LOGGED
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
log_event("agent.step.llm_response.processed") # [4^]
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
raise
# Update step if it needs to be updated
finally:
if settings.track_stop_reason:
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
provider_id=None,
job_id=self.current_run_id if self.current_run_id else None,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
if step_progression <= StepProgression.RESPONSE_RECEIVED:
# TODO (cliandy): persist response if we get it back
if settings.track_errored_messages:
for message in initial_messages:
message.is_err = True
message.step_id = step_id
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue: if not should_continue:
break break
@@ -576,6 +705,7 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for i in range(max_steps): for i in range(max_steps):
step_id = generate_step_id()
# Check for job cancellation at the start of each step # Check for job cancellation at the start of each step
if await self._check_run_cancellation(): if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
@@ -583,163 +713,230 @@ class LettaAgent(BaseAgent):
yield f"data: {stop_reason.model_dump_json()}\n\n" yield f"data: {stop_reason.model_dump_json()}\n\n"
break break
step_id = generate_step_id()
step_start = get_utc_timestamp_ns() step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id}) agent_step_span.set_attributes({"step_id": step_id})
( step_progression = StepProgression.START
request_data, should_continue = False
stream,
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
provider_request_start_timestamp_ns,
) = await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
log_event("agent.stream.llm_response.received") # [3^]
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
else:
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
async for chunk in interface.process(
stream,
ttft_span=request_span,
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
):
# Measure time to first token
if first_chunk and request_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = agent_state.llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
if include_return_message_types is None or chunk.message_type in include_return_message_types:
# filter down returned data
yield f"data: {chunk.model_dump_json()}\n\n"
stream_end_time_ns = get_utc_timestamp_ns()
# update usage
usage.step_count += 1
usage.completion_tokens += interface.output_tokens
usage.prompt_tokens += interface.input_tokens
usage.total_tokens += interface.input_tokens + interface.output_tokens
MetricRegistry().message_output_tokens.record(
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# log LLM request time
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
MetricRegistry().llm_execution_time_ms_histogram.record(
llm_request_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Process resulting stream content
try: try:
tool_call = interface.get_tool_call_object() (
except ValueError as e: request_data,
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value) stream,
yield f"data: {stop_reason.model_dump_json()}\n\n" current_in_context_messages,
raise e new_in_context_messages,
except Exception as e: valid_tool_names,
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) provider_request_start_timestamp_ns,
yield f"data: {stop_reason.model_dump_json()}\n\n" ) = await self._build_and_request_from_llm_streaming(
raise e first_chunk,
reasoning_content = interface.get_reasoning_content() agent_step_span,
persisted_messages, should_continue, stop_reason = await self._handle_ai_response( request_start_timestamp_ns,
tool_call, current_in_context_messages,
valid_tool_names, new_in_context_messages,
agent_state, agent_state,
tool_rules_solver, llm_client,
UsageStatistics( tool_rules_solver,
completion_tokens=interface.output_tokens, )
prompt_tokens=interface.input_tokens,
total_tokens=interface.input_tokens + interface.output_tokens,
),
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_message_id,
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None step_progression = StepProgression.STREAM_RECEIVED
log_event("agent.stream.llm_response.received") # [3^]
# log total step time # TODO: THIS IS INCREDIBLY UGLY
now = get_utc_timestamp_ns() # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
step_ns = now - step_start if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)}) interface = AnthropicStreamingInterface(
agent_step_span.end() use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
else:
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream async for chunk in interface.process(
# log_event("agent.stream.llm_response.processed") # [4^] stream,
ttft_span=request_span,
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
):
# Measure time to first token
if first_chunk and request_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = agent_state.llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
# Log LLM Trace if include_return_message_types is None or chunk.message_type in include_return_message_types:
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema. # filter down returned data
await self.telemetry_manager.create_provider_trace_async( yield f"data: {chunk.model_dump_json()}\n\n"
actor=self.actor,
provider_trace_create=ProviderTraceCreate( stream_end_time_ns = get_utc_timestamp_ns()
request_json=request_data,
response_json={ # update usage
"content": { usage.step_count += 1
"tool_call": tool_call.model_dump_json(), usage.completion_tokens += interface.output_tokens
"reasoning": [content.model_dump_json() for content in reasoning_content], usage.prompt_tokens += interface.input_tokens
}, usage.total_tokens += interface.input_tokens + interface.output_tokens
"id": interface.message_id, MetricRegistry().message_output_tokens.record(
"model": interface.model, interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
"role": "assistant", )
# "stop_reason": "",
# "stop_sequence": None, # log LLM request time
"type": "message", llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens}, agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
}, MetricRegistry().llm_execution_time_ms_histogram.record(
llm_request_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Process resulting stream content
try:
tool_call = interface.get_tool_call_object()
except ValueError as e:
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
raise e
except Exception as e:
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
raise e
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
UsageStatistics(
completion_tokens=interface.output_tokens,
prompt_tokens=interface.input_tokens,
total_tokens=interface.input_tokens + interface.output_tokens,
),
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_message_id,
step_id=step_id, step_id=step_id,
organization_id=self.actor.organization_id, initial_messages=initial_messages,
), agent_step_span=agent_step_span,
) is_final_step=(i == max_steps - 1),
)
step_progression = StepProgression.STEP_LOGGED
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] new_message_idx = len(initial_messages) if initial_messages else 0
if not (use_assistant_message and tool_return.name == "send_message"): self.response_messages.extend(persisted_messages[new_message_idx:])
# Apply message type filtering if specified new_in_context_messages.extend(persisted_messages[new_message_idx:])
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
# TODO (cliandy): consolidate and expand with trace initial_messages = None
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
# log total step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
# Log LLM Trace
# We are piecing together the streamed response here.
# Content here does not match the actual response schema as streams come in chunks.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": interface.input_tokens,
"output_tokens": interface.output_tokens,
},
},
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# yields tool response as this is handled from Letta and not the response from the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
if not (use_assistant_message and tool_return.name == "send_message"):
# Apply message type filtering if specified
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
# TODO (cliandy): consolidate and expand with trace
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
# Send error stop reason to client and re-raise with expected response code
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
raise
# Update step if it needs to be updated
finally:
if settings.track_stop_reason:
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
provider_id=None,
job_id=self.current_run_id if self.current_run_id else None,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
if step_progression <= StepProgression.STREAM_RECEIVED:
if first_chunk and settings.track_errored_messages:
for message in initial_messages:
message.is_err = True
message.step_id = step_id
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue: if not should_continue:
break break
# Extend the in context message ids # Extend the in context message ids
if not agent_state.message_buffer_autoclear: if not agent_state.message_buffer_autoclear:
await self._rebuild_context_window( await self._rebuild_context_window(
@@ -1106,6 +1303,7 @@ class LettaAgent(BaseAgent):
job_id=run_id if run_id else self.current_run_id, job_id=run_id if run_id else self.current_run_id,
step_id=step_id, step_id=step_id,
project_id=agent_state.project_id, project_id=agent_state.project_id,
stop_reason=stop_reason,
) )
tool_call_messages = create_letta_messages_from_llm_response( tool_call_messages = create_letta_messages_from_llm_response(

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Set, Union from typing import List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -107,25 +107,20 @@ class ToolRulesSolver(BaseModel):
self.tool_call_history.clear() self.tool_call_history.clear()
def get_allowed_tool_names( def get_allowed_tool_names(
self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None
) -> List[str]: ) -> List[str]:
"""Get a list of tool names allowed based on the last tool called.""" """Get a list of tool names allowed based on the last tool called.
The logic is as follows:
1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
"""
# TODO: This piece of code here is quite ugly and deserves a refactor # TODO: This piece of code here is quite ugly and deserves a refactor
# TODO: There's some weird logic encoded here:
# TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules
# TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based
# TODO: -> Tool rules should probably be refactored to take in a set of tool names? # TODO: -> Tool rules should probably be refactored to take in a set of tool names?
# If no tool has been called yet, return InitToolRules additively if not self.tool_call_history and self.init_tool_rules:
if not self.tool_call_history: return [rule.tool_name for rule in self.init_tool_rules]
if self.init_tool_rules:
# If there are init tool rules, only return those defined in the init tool rules
return [rule.tool_name for rule in self.init_tool_rules]
else:
# Otherwise, return all tools besides those constrained by parent tool rules
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
return list(available_tools)
else: else:
# Collect valid tools from all child-based rules
valid_tool_sets = [] valid_tool_sets = []
for rule in self.child_based_tool_rules + self.parent_tool_rules: for rule in self.child_based_tool_rules + self.parent_tool_rules:
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
@@ -151,11 +146,11 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules.""" """Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def has_required_tools_been_called(self, available_tools: Set[str]) -> bool: def has_required_tools_been_called(self, available_tools: set[str]) -> bool:
"""Check if all required-before-exit tools have been called.""" """Check if all required-before-exit tools have been called."""
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0 return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]: def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]:
"""Get the list of required-before-exit tools that have not been called yet.""" """Get the list of required-before-exit tools that have not been called yet."""
if not self.required_before_exit_tool_rules: if not self.required_before_exit_tool_rules:
return [] # No required tools means no uncalled tools return [] # No required tools means no uncalled tools

View File

@@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
nullable=True, nullable=True,
doc="The id of the LLMBatchItem that this message is associated with", doc="The id of the LLMBatchItem that this message is associated with",
) )
is_err: Mapped[Optional[bool]] = mapped_column(
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
)
# Monotonically increasing sequence for efficient/correct listing # Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column( sequence_id: Mapped[int] = mapped_column(

View File

@@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.step import Step as PydanticStep from letta.schemas.step import Step as PydanticStep
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -45,6 +46,7 @@ class Step(SqlalchemyBase):
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt") prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent") total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.") completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.") tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.") tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.") trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")

View File

@@ -40,15 +40,18 @@ class LettaMessage(BaseModel):
message_type (MessageType): The type of the message message_type (MessageType): The type of the message
otid (Optional[str]): The offline threading id associated with this message otid (Optional[str]): The offline threading id associated with this message
sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
step_id (Optional[str]): The step id associated with the message
is_err (Optional[bool]): Whether the message is an errored message or not. Used for debugging purposes only.
""" """
id: str id: str
date: datetime date: datetime
name: Optional[str] = None name: str | None = None
message_type: MessageType = Field(..., description="The type of the message.") message_type: MessageType = Field(..., description="The type of the message.")
otid: Optional[str] = None otid: str | None = None
sender_id: Optional[str] = None sender_id: str | None = None
step_id: Optional[str] = None step_id: str | None = None
is_err: bool | None = None
@field_serializer("date") @field_serializer("date")
def serialize_datetime(self, dt: datetime, _info): def serialize_datetime(self, dt: datetime, _info):
@@ -60,6 +63,14 @@ class LettaMessage(BaseModel):
dt = dt.replace(tzinfo=timezone.utc) dt = dt.replace(tzinfo=timezone.utc)
return dt.isoformat(timespec="seconds") return dt.isoformat(timespec="seconds")
@field_serializer("is_err", when_used="unless-none")
def serialize_is_err(self, value: bool | None, _info):
"""
Only serialize is_err field when it's True (for debugging purposes).
When is_err is None or False, this field will be excluded from the JSON output.
"""
return value if value is True else None
class SystemMessage(LettaMessage): class SystemMessage(LettaMessage):
""" """

View File

@@ -172,6 +172,9 @@ class Message(BaseMessage):
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in") group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id") sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with") batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
is_err: Optional[bool] = Field(
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
)
# This overrides the optional base orm schema, created_at MUST exist on all messages objects # This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -191,6 +194,7 @@ class Message(BaseMessage):
if not is_utc_datetime(self.created_at): if not is_utc_datetime(self.created_at):
self.created_at = self.created_at.replace(tzinfo=timezone.utc) self.created_at = self.created_at.replace(tzinfo=timezone.utc)
json_message["created_at"] = self.created_at.isoformat() json_message["created_at"] = self.created_at.isoformat()
json_message.pop("is_err", None) # make sure we don't include this debugging information
return json_message return json_message
@staticmethod @staticmethod
@@ -204,6 +208,7 @@ class Message(BaseMessage):
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
reverse: bool = True, reverse: bool = True,
include_err: Optional[bool] = None,
) -> List[LettaMessage]: ) -> List[LettaMessage]:
if use_assistant_message: if use_assistant_message:
message_ids_to_remove = [] message_ids_to_remove = []
@@ -234,6 +239,7 @@ class Message(BaseMessage):
assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg, assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse, reverse=reverse,
include_err=include_err,
) )
] ]
@@ -243,6 +249,7 @@ class Message(BaseMessage):
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
reverse: bool = True, reverse: bool = True,
include_err: Optional[bool] = None,
) -> List[LettaMessage]: ) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API""" """Convert message object (in DB format) to the style used by the original Letta API"""
messages = [] messages = []
@@ -682,14 +689,13 @@ class Message(BaseMessage):
# since the only "parts" we have are for supporting various COT # since the only "parts" we have are for supporting various COT
if self.role == "system": if self.role == "system":
assert all([v is not None for v in [self.role]]), vars(self)
openai_message = { openai_message = {
"content": text_content, "content": text_content,
"role": "developer" if use_developer_message else self.role, "role": "developer" if use_developer_message else self.role,
} }
elif self.role == "user": elif self.role == "user":
assert all([v is not None for v in [text_content, self.role]]), vars(self) assert text_content is not None, vars(self)
openai_message = { openai_message = {
"content": text_content, "content": text_content,
"role": self.role, "role": self.role,
@@ -720,7 +726,7 @@ class Message(BaseMessage):
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length] tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
elif self.role == "tool": elif self.role == "tool":
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) assert self.tool_call_id is not None, vars(self)
openai_message = { openai_message = {
"content": text_content, "content": text_content,
"role": self.role, "role": self.role,
@@ -776,7 +782,7 @@ class Message(BaseMessage):
if self.role == "system": if self.role == "system":
# NOTE: this is not for system instructions, but instead system "events" # NOTE: this is not for system instructions, but instead system "events"
assert all([v is not None for v in [text_content, self.role]]), vars(self) assert text_content is not None, vars(self)
# Two options here, we would use system.package_system_message, # Two options here, we would use system.package_system_message,
# or use a more Anthropic-specific packaging ie xml tags # or use a more Anthropic-specific packaging ie xml tags
user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event") user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event")
@@ -875,7 +881,7 @@ class Message(BaseMessage):
elif self.role == "tool": elif self.role == "tool":
# NOTE: Anthropic uses role "user" for "tool" responses # NOTE: Anthropic uses role "user" for "tool" responses
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) assert self.tool_call_id is not None, vars(self)
anthropic_message = { anthropic_message = {
"role": "user", # NOTE: diff "role": "user", # NOTE: diff
"content": [ "content": [
@@ -988,7 +994,7 @@ class Message(BaseMessage):
elif self.role == "tool": elif self.role == "tool":
# NOTE: Significantly different tool calling format, more similar to function calling format # NOTE: Significantly different tool calling format, more similar to function calling format
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) assert self.tool_call_id is not None, vars(self)
if self.name is None: if self.name is None:
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.") warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")

View File

@@ -1,8 +1,10 @@
from enum import Enum, auto
from typing import Dict, List, Literal, Optional from typing import Dict, List, Literal, Optional
from pydantic import Field from pydantic import Field
from letta.schemas.letta_base import LettaBase from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message from letta.schemas.message import Message
@@ -28,6 +30,7 @@ class Step(StepBase):
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.") prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.") total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.") completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.")
tags: List[str] = Field([], description="Metadata tags.") tags: List[str] = Field([], description="Metadata tags.")
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.") tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.") trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
@@ -36,3 +39,12 @@ class Step(StepBase):
None, description="The feedback for this step. Must be either 'positive' or 'negative'." None, description="The feedback for this step. Must be either 'positive' or 'negative'."
) )
project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).") project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).")
class StepProgression(int, Enum):
START = auto()
STREAM_RECEIVED = auto()
RESPONSE_RECEIVED = auto()
STEP_LOGGED = auto()
LOGGED_TRACE = auto()
FINISHED = auto()

View File

@@ -2,8 +2,10 @@ import importlib.util
import json import json
import logging import logging
import os import os
import platform
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -34,32 +36,25 @@ from letta.server.db import db_registry
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right? from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import router as organizations_router from letta.server.rest_api.routers.v1.organizations import router as organizations_router
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
from letta.server.rest_api.static_files import mount_static_files from letta.server.rest_api.static_files import mount_static_files
from letta.server.rest_api.utils import SENTRY_ENABLED
from letta.server.server import SyncServer from letta.server.server import SyncServer
from letta.settings import settings from letta.settings import settings
# TODO(ethan) if SENTRY_ENABLED:
import sentry_sdk
IS_WINDOWS = platform.system() == "Windows"
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work # NOTE(charles): @ethan I had to add this to get the global as the bottom to work
interface: StreamingServerInterface = StreamingServerInterface interface: type = StreamingServerInterface
server = SyncServer(default_interface_factory=lambda: interface()) server = SyncServer(default_interface_factory=lambda: interface())
logger = get_logger(__name__) logger = get_logger(__name__)
import logging
import platform
from fastapi import FastAPI
is_windows = platform.system() == "Windows"
log = logging.getLogger("uvicorn")
def generate_openapi_schema(app: FastAPI): def generate_openapi_schema(app: FastAPI):
# Update the OpenAPI schema # Update the OpenAPI schema
if not app.openapi_schema: if not app.openapi_schema:
@@ -177,9 +172,7 @@ def create_application() -> "FastAPI":
# server = SyncServer(default_interface_factory=lambda: interface()) # server = SyncServer(default_interface_factory=lambda: interface())
print(f"\n[[ Letta server // v{letta_version} ]]") print(f"\n[[ Letta server // v{letta_version} ]]")
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): if SENTRY_ENABLED:
import sentry_sdk
sentry_sdk.init( sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"), dsn=os.getenv("SENTRY_DSN"),
traces_sample_rate=1.0, traces_sample_rate=1.0,
@@ -187,6 +180,7 @@ def create_application() -> "FastAPI":
"continuous_profiling_auto_start": True, "continuous_profiling_auto_start": True,
}, },
) )
logger.info("Sentry enabled.")
debug_mode = "--debug" in sys.argv debug_mode = "--debug" in sys.argv
app = FastAPI( app = FastAPI(
@@ -199,31 +193,13 @@ def create_application() -> "FastAPI":
lifespan=lifespan, lifespan=lifespan,
) )
@app.exception_handler(IncompatibleAgentType) # === Exception Handlers ===
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType): # TODO (cliandy): move to separate file
return JSONResponse(
status_code=400,
content={
"detail": str(exc),
"expected_type": exc.expected_type,
"actual_type": exc.actual_type,
},
)
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception): async def generic_error_handler(request: Request, exc: Exception):
# Log the actual error for debugging logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
log.error(f"Unhandled error: {str(exc)}", exc_info=True) if SENTRY_ENABLED:
print(f"Unhandled error: {str(exc)}")
import traceback
# Print the stack trace
print(f"Stack trace: {traceback.format_exc()}")
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk
sentry_sdk.capture_exception(exc) sentry_sdk.capture_exception(exc)
return JSONResponse( return JSONResponse(
@@ -235,62 +211,70 @@ def create_application() -> "FastAPI":
}, },
) )
@app.exception_handler(NoResultFound) async def error_handler_with_code(request: Request, exc: Exception, code: int, detail: str | None = None):
async def no_result_found_handler(request: Request, exc: NoResultFound): logger.error(f"{type(exc).__name__}", exc_info=exc)
logger.error(f"NoResultFound: {exc}") if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
if not detail:
detail = str(exc)
return JSONResponse( return JSONResponse(
status_code=404, status_code=code,
content={"detail": str(exc)}, content={"detail": detail},
) )
@app.exception_handler(ForeignKeyConstraintViolationError) _error_handler_400 = partial(error_handler_with_code, code=400)
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError): _error_handler_404 = partial(error_handler_with_code, code=404)
logger.error(f"ForeignKeyConstraintViolationError: {exc}") _error_handler_404_agent = partial(_error_handler_404, detail="Agent not found")
_error_handler_404_user = partial(_error_handler_404, detail="User not found")
_error_handler_409 = partial(error_handler_with_code, code=409)
app.add_exception_handler(ValueError, _error_handler_400)
app.add_exception_handler(NoResultFound, _error_handler_404)
app.add_exception_handler(LettaAgentNotFoundError, _error_handler_404_agent)
app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user)
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
@app.exception_handler(IncompatibleAgentType)
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
logger.error("Incompatible agent types. Expected: %s, Actual: %s", exc.expected_type, exc.actual_type)
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse( return JSONResponse(
status_code=409, status_code=400,
content={"detail": str(exc)}, content={
) "detail": str(exc),
"expected_type": exc.expected_type,
@app.exception_handler(UniqueConstraintViolationError) "actual_type": exc.actual_type,
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError): },
logger.error(f"UniqueConstraintViolationError: {exc}")
return JSONResponse(
status_code=409,
content={"detail": str(exc)},
) )
@app.exception_handler(DatabaseTimeoutError) @app.exception_handler(DatabaseTimeoutError)
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError): async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}") logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse( return JSONResponse(
status_code=503, status_code=503,
content={"detail": "The database is temporarily unavailable. Please try again later."}, content={"detail": "The database is temporarily unavailable. Please try again later."},
) )
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
return JSONResponse(status_code=400, content={"detail": str(exc)})
@app.exception_handler(LettaAgentNotFoundError)
async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError):
return JSONResponse(status_code=404, content={"detail": "Agent not found"})
@app.exception_handler(LettaUserNotFoundError)
async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
return JSONResponse(status_code=404, content={"detail": "User not found"})
@app.exception_handler(BedrockPermissionError) @app.exception_handler(BedrockPermissionError)
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError): async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
logger.error(f"Bedrock permission denied.")
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse( return JSONResponse(
status_code=403, status_code=403,
content={ content={
"error": { "error": {
"type": "bedrock_permission_denied", "type": "bedrock_permission_denied",
"message": "Unable to access the required AI model. Please check your Bedrock permissions or contact support.", "message": "Unable to access the required AI model. Please check your Bedrock permissions or contact support.",
"details": {"model_arn": exc.model_arn, "reason": str(exc)}, "detail": {str(exc)},
} }
}, },
) )
@@ -301,6 +285,9 @@ def create_application() -> "FastAPI":
print(f"▶ Using secure mode with password: {random_password}") print(f"▶ Using secure mode with password: {random_password}")
app.add_middleware(CheckPasswordMiddleware) app.add_middleware(CheckPasswordMiddleware)
# Add reverse proxy middleware to handle X-Forwarded-* headers
# app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.cors_origins, allow_origins=settings.cors_origins,
@@ -442,7 +429,7 @@ def start_server(
) )
else: else:
if is_windows: if IS_WINDOWS:
# Windows doesn't those the fancy unicode characters # Windows doesn't those the fancy unicode characters
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}") print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n") print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")

View File

@@ -636,6 +636,9 @@ async def list_messages(
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"), use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."), assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."), assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
include_err: bool | None = Query(
None, description="Whether to include error messages and error statuses. For debugging purposes only."
),
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
): ):
""" """
@@ -654,6 +657,7 @@ async def list_messages(
use_assistant_message=use_assistant_message, use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg, assistant_message_tool_kwarg=assistant_message_tool_kwarg,
include_err=include_err,
actor=actor, actor=actor,
) )
@@ -1156,7 +1160,7 @@ async def list_agent_groups(
): ):
"""Lists the groups for an agent""" """Lists the groups for an agent"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
print("in list agents with manager_type", manager_type) logger.info("in list agents with manager_type", manager_type)
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor) return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)

View File

@@ -12,6 +12,7 @@ from starlette.types import Send
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import JobStatus from letta.schemas.enums import JobStatus
from letta.schemas.user import User from letta.schemas.user import User
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.job_manager import JobManager from letta.services.job_manager import JobManager
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -92,6 +93,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
more_body = True more_body = True
try: try:
first_chunk = await self.body_iterator.__anext__() first_chunk = await self.body_iterator.__anext__()
logger.debug("stream_response first chunk:", first_chunk)
if isinstance(first_chunk, tuple): if isinstance(first_chunk, tuple):
first_chunk_content, self.status_code = first_chunk first_chunk_content, self.status_code = first_chunk
else: else:
@@ -130,7 +132,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"more_body": more_body, "more_body": more_body,
} }
) )
return raise Exception(f"An exception occurred mid-stream with status code {status_code}", detail={"content": content})
else: else:
content = chunk content = chunk
@@ -146,8 +148,8 @@ class StreamingResponseWithStatusCode(StreamingResponse):
) )
# This should be handled properly upstream? # This should be handled properly upstream?
except asyncio.CancelledError: except asyncio.CancelledError as exc:
logger.info("Stream was cancelled by client or job cancellation") logger.warning("Stream was cancelled by client or job cancellation")
# Handle cancellation gracefully # Handle cancellation gracefully
more_body = False more_body = False
cancellation_resp = {"error": {"message": "Stream cancelled"}} cancellation_resp = {"error": {"message": "Stream cancelled"}}
@@ -160,6 +162,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"headers": self.raw_headers, "headers": self.raw_headers,
} }
) )
raise
await send( await send(
{ {
"type": "http.response.body", "type": "http.response.body",
@@ -167,13 +170,15 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"more_body": more_body, "more_body": more_body,
} }
) )
capture_sentry_exception(exc)
return return
except Exception: except Exception as exc:
logger.exception("unhandled_streaming_error") logger.exception("Unhandled Streaming Error")
more_body = False more_body = False
error_resp = {"error": {"message": "Internal Server Error"}} error_resp = {"error": {"message": "Internal Server Error"}}
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset) error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
logger.debug("response_started:", self.response_started)
if not self.response_started: if not self.response_started:
await send( await send(
{ {
@@ -182,6 +187,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"headers": self.raw_headers, "headers": self.raw_headers,
} }
) )
raise
await send( await send(
{ {
"type": "http.response.body", "type": "http.response.body",
@@ -189,5 +195,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"more_body": more_body, "more_body": more_body,
} }
) )
capture_sentry_exception(exc)
return
if more_body: if more_body:
await send({"type": "http.response.body", "body": b"", "more_body": False}) await send({"type": "http.response.body", "body": b"", "more_body": False})

View File

@@ -2,7 +2,6 @@ import asyncio
import json import json
import os import os
import uuid import uuid
import warnings
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
@@ -34,12 +33,15 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User from letta.schemas.user import User
from letta.server.rest_api.interface import StreamingServerInterface
from letta.system import get_heartbeat, package_function_response from letta.system import get_heartbeat, package_function_response
if TYPE_CHECKING: if TYPE_CHECKING:
from letta.server.server import SyncServer from letta.server.server import SyncServer
SENTRY_ENABLED = bool(os.getenv("SENTRY_DSN"))
if SENTRY_ENABLED:
import sentry_sdk
SSE_PREFIX = "data: " SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n" SSE_SUFFIX = "\n\n"
@@ -157,21 +159,9 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
return user_id return user_id
def get_current_interface() -> StreamingServerInterface: def capture_sentry_exception(e: BaseException):
return StreamingServerInterface """This will capture the exception in sentry, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response"""
if SENTRY_ENABLED:
def log_error_to_sentry(e):
import traceback
traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")
# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk
sentry_sdk.capture_exception(e) sentry_sdk.capture_exception(e)

View File

@@ -105,6 +105,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
from letta.services.tool_manager import ToolManager from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings, tool_settings from letta.settings import model_settings, settings, tool_settings
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.utils import get_friendly_error_msg, get_persona_text, make_key from letta.utils import get_friendly_error_msg, get_persona_text, make_key
config = LettaConfig.load() config = LettaConfig.load()
@@ -176,7 +177,7 @@ class SyncServer(Server):
self, self,
chaining: bool = True, chaining: bool = True,
max_chaining_steps: Optional[int] = 100, max_chaining_steps: Optional[int] = 100,
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(), default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(),
init_with_default_org_and_user: bool = True, init_with_default_org_and_user: bool = True,
# default_interface: AgentInterface = CLIInterface(), # default_interface: AgentInterface = CLIInterface(),
# default_persistence_manager_cls: PersistenceManager = LocalStateManager, # default_persistence_manager_cls: PersistenceManager = LocalStateManager,
@@ -1244,6 +1245,7 @@ class SyncServer(Server):
use_assistant_message: bool = True, use_assistant_message: bool = True,
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
include_err: Optional[bool] = None,
) -> Union[List[Message], List[LettaMessage]]: ) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages_for_agent_async( records = await self.message_manager.list_messages_for_agent_async(
agent_id=agent_id, agent_id=agent_id,
@@ -1253,6 +1255,7 @@ class SyncServer(Server):
limit=limit, limit=limit,
ascending=not reverse, ascending=not reverse,
group_id=group_id, group_id=group_id,
include_err=include_err,
) )
if not return_message_object: if not return_message_object:
@@ -1262,6 +1265,7 @@ class SyncServer(Server):
assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg, assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse, reverse=reverse,
include_err=include_err,
) )
if reverse: if reverse:

View File

@@ -520,6 +520,7 @@ class MessageManager:
limit: Optional[int] = 50, limit: Optional[int] = 50,
ascending: bool = True, ascending: bool = True,
group_id: Optional[str] = None, group_id: Optional[str] = None,
include_err: Optional[bool] = None,
) -> List[PydanticMessage]: ) -> List[PydanticMessage]:
""" """
Most performant query to list messages for an agent by directly querying the Message table. Most performant query to list messages for an agent by directly querying the Message table.
@@ -539,6 +540,7 @@ class MessageManager:
limit: Maximum number of messages to return. limit: Maximum number of messages to return.
ascending: If True, sort by sequence_id ascending; if False, sort descending. ascending: If True, sort by sequence_id ascending; if False, sort descending.
group_id: Optional group ID to filter messages by group_id. group_id: Optional group ID to filter messages by group_id.
include_err: Optional boolean to include errors and error statuses. Used for debugging only.
Returns: Returns:
List[PydanticMessage]: A list of messages (converted via .to_pydantic()). List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
@@ -558,6 +560,9 @@ class MessageManager:
if group_id: if group_id:
query = query.where(MessageModel.group_id == group_id) query = query.where(MessageModel.group_id == group_id)
if not include_err:
query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))
# If query_text is provided, filter messages using database-specific JSON search. # If query_text is provided, filter messages using database-specific JSON search.
if query_text: if query_text:
if settings.letta_pg_uri_no_default: if settings.letta_pg_uri_no_default:

View File

@@ -12,6 +12,7 @@ from letta.orm.job import Job as JobModel
from letta.orm.sqlalchemy_base import AccessType from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel from letta.orm.step import Step as StepModel
from letta.otel.tracing import get_trace_id, trace_method from letta.otel.tracing import get_trace_id, trace_method
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
@@ -131,6 +132,7 @@ class StepManager:
job_id: Optional[str] = None, job_id: Optional[str] = None,
step_id: Optional[str] = None, step_id: Optional[str] = None,
project_id: Optional[str] = None, project_id: Optional[str] = None,
stop_reason: Optional[LettaStopReason] = None,
) -> PydanticStep: ) -> PydanticStep:
step_data = { step_data = {
"origin": None, "origin": None,
@@ -153,6 +155,8 @@ class StepManager:
} }
if step_id: if step_id:
step_data["id"] = step_id step_data["id"] = step_id
if stop_reason:
step_data["stop_reason"] = stop_reason.stop_reason
async with db_registry.async_session() as session: async with db_registry.async_session() as session:
if job_id: if job_id:
await self._verify_job_access_async(session, job_id, actor, access=["write"]) await self._verify_job_access_async(session, job_id, actor, access=["write"])
@@ -207,6 +211,33 @@ class StepManager:
await session.commit() await session.commit()
return step.to_pydantic() return step.to_pydantic()
@enforce_types
@trace_method
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
"""Update the stop reason for a step.
Args:
actor: The user making the request
step_id: The ID of the step to update
stop_reason: The stop reason to set
Returns:
The updated step
Raises:
NoResultFound: If the step does not exist
"""
async with db_registry.async_session() as session:
step = await session.get(StepModel, step_id)
if not step:
raise NoResultFound(f"Step with id {step_id} does not exist")
if step.organization_id != actor.organization_id:
raise Exception("Unauthorized")
step.stop_reason = stop_reason
await session.commit()
return step
def _verify_job_access( def _verify_job_access(
self, self,
session: Session, session: Session,
@@ -309,5 +340,6 @@ class NoopStepManager(StepManager):
job_id: Optional[str] = None, job_id: Optional[str] = None,
step_id: Optional[str] = None, step_id: Optional[str] = None,
project_id: Optional[str] = None, project_id: Optional[str] = None,
stop_reason: Optional[LettaStopReason] = None,
) -> PydanticStep: ) -> PydanticStep:
return return

View File

@@ -220,13 +220,15 @@ class Settings(BaseSettings):
multi_agent_concurrent_sends: int = 50 multi_agent_concurrent_sends: int = 50
# telemetry logging # telemetry logging
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317" otel_exporter_otlp_endpoint: str | None = None # otel default: "http://localhost:4317"
otel_preferred_temporality: Optional[int] = Field( otel_preferred_temporality: int | None = Field(
default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}" default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}"
) )
disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing") disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing")
llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step") llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step")
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics") track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
# uvicorn settings # uvicorn settings
uvicorn_workers: int = 1 uvicorn_workers: int = 1

View File

@@ -752,14 +752,15 @@ def test_step_stream_agent_loop_error(
""" """
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1) last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config) agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state_no_tools.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
response = client.agents.messages.create_stream( for chunk in response:
agent_id=agent_state_no_tools.id, print(chunk)
messages=USER_MESSAGE_FORCE_REPLY, print("error info:", exc_info)
)
list(response)
assert type(exc_info.value) in (ApiError, ValueError) assert type(exc_info.value) in (ApiError, ValueError)
print(exc_info.value)
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id) messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
assert len(messages_from_db) == 0 assert len(messages_from_db) == 0

View File

@@ -138,7 +138,9 @@ def test_max_count_per_step_tool_rule():
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'" assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"
solver.register_tool_call(START_TOOL) solver.register_tool_call(START_TOOL)
assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed" assert (
solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == []
), "After reaching max count, 'start_tool' should no longer be allowed"
def test_max_count_per_step_tool_rule_allows_usage_up_to_limit(): def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
@@ -155,7 +157,7 @@ def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses" assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"
solver.register_tool_call(START_TOOL) solver.register_tool_call(START_TOOL)
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses" assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should no longer allow 'start_tool' after 3 uses"
def test_max_count_per_step_tool_rule_does_not_affect_other_tools(): def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
@@ -180,7 +182,7 @@ def test_max_count_per_step_tool_rule_resets_on_clear():
solver.register_tool_call(START_TOOL) solver.register_tool_call(START_TOOL)
solver.register_tool_call(START_TOOL) solver.register_tool_call(START_TOOL)
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit" assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should not allow 'start_tool' after reaching limit"
solver.clear_tool_history() solver.clear_tool_history()