feat: add step stop reasons, message error flags, and Sentry fixes

This commit is contained in:
Andy Li
2025-07-18 11:56:20 -07:00
committed by GitHub
parent 904d9ba5a2
commit b7b678db4e
18 changed files with 746 additions and 442 deletions

View File

@@ -0,0 +1,42 @@
"""add stop reasons to steps and message error flag
Revision ID: cce9a6174366
Revises: 2c059cad97cc
Create Date: 2025-07-10 13:56:17.383612
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
# `revision` is this migration's unique id; `down_revision` points at the
# migration this one builds on (the downgrade target).
revision: str = "cce9a6174366"
down_revision: Union[str, None] = "2c059cad97cc"
# No named branches and no cross-branch dependencies for this migration.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the ``messages.is_err`` flag and the ``steps.stop_reason`` enum column."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("messages", sa.Column("is_err", sa.Boolean(), nullable=True))

    # Manually added to handle non-table creation enums: the enum type must be
    # created on the connection before a column can reference it.
    stop_reason_values = (
        "end_turn",
        "error",
        "invalid_tool_call",
        "max_steps",
        "no_tool_call",
        "tool_rule",
        "cancelled",
    )
    stop_reason_enum = sa.Enum(*stop_reason_values, name="stopreasontype")
    stop_reason_enum.create(op.get_bind())
    op.add_column("steps", sa.Column("stop_reason", stop_reason_enum, nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Remove the stop-reason column, the error flag, and the enum type."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop the dependent columns first, then the enum type itself.
    op.drop_column("steps", "stop_reason")
    op.drop_column("messages", "is_err")
    sa.Enum(name="stopreasontype").drop(op.get_bind())
    # ### end Alembic commands ###

View File

@@ -43,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import StepProgression
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -238,100 +239,164 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
agent_step_span,
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
step_progression = StepProgression.START
should_continue = False
try:
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
agent_step_span,
)
]
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
else:
self.logger.info("No reasoning content found.")
reasoning = None
)
in_context_messages = current_in_context_messages + new_in_context_messages
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
step_progression = StepProgression.RESPONSE_RECEIVED
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# update usage
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
if not response.choices[0].message.tool_calls:
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
)
]
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
else:
self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
step_progression = StepProgression.STEP_LOGGED
# stream step
# TODO: improve TTFT
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
letta_messages = Message.to_letta_messages_from_list(
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
for message in letta_messages:
if include_return_message_types is None or message.message_type in include_return_message_types:
yield f"data: {message.model_dump_json()}\n\n"
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# stream step
# TODO: improve TTFT
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
letta_messages = Message.to_letta_messages_from_list(
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
for message in letta_messages:
if include_return_message_types is None or message.message_type in include_return_message_types:
yield f"data: {message.model_dump_json()}\n\n"
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
# Send error stop reason to client and re-raise
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
raise
# Update step if it needs to be updated
finally:
if settings.track_stop_reason:
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
provider_id=None,
job_id=self.current_run_id if self.current_run_id else None,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
if step_progression <= StepProgression.RESPONSE_RECEIVED:
# TODO (cliandy): persist response if we get it back
if settings.track_errored_messages:
for message in initial_messages:
message.is_err = True
message.step_id = step_id
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue:
break
@@ -396,6 +461,16 @@ class LettaAgent(BaseAgent):
stop_reason = None
usage = LettaUsageStatistics()
for i in range(max_steps):
# If dry run, build request data and return it without making LLM call
if dry_run:
request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
)
return request_data
# Check for job cancellation at the start of each step
if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
@@ -407,94 +482,148 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
# If dry run, build request data and return it without making LLM call
if dry_run:
request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
)
return request_data
step_progression = StepProgression.START
should_continue = False
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
log_event("agent.step.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
usage.run_ids = [run_id] if run_id else None
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
try:
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
)
]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
else:
self.logger.info("No reasoning content found.")
reasoning = None
)
in_context_messages = current_in_context_messages + new_in_context_messages
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
run_id=run_id,
)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
step_progression = StepProgression.RESPONSE_RECEIVED
log_event("agent.step.llm_response.received") # [3^]
initial_messages = None
log_event("agent.step.llm_response.processed") # [4^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
usage.run_ids = [run_id] if run_id else None
MetricRegistry().message_output_tokens.record(
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
if not response.choices[0].message.tool_calls:
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
)
]
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
else:
self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
run_id=run_id,
)
step_progression = StepProgression.STEP_LOGGED
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
log_event("agent.step.llm_response.processed") # [4^]
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
raise
# Update step if it needs to be updated
finally:
if settings.track_stop_reason:
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
provider_id=None,
job_id=self.current_run_id if self.current_run_id else None,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
if step_progression <= StepProgression.RESPONSE_RECEIVED:
# TODO (cliandy): persist response if we get it back
if settings.track_errored_messages:
for message in initial_messages:
message.is_err = True
message.step_id = step_id
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue:
break
@@ -576,6 +705,7 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for i in range(max_steps):
step_id = generate_step_id()
# Check for job cancellation at the start of each step
if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
@@ -583,163 +713,230 @@ class LettaAgent(BaseAgent):
yield f"data: {stop_reason.model_dump_json()}\n\n"
break
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
(
request_data,
stream,
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
provider_request_start_timestamp_ns,
) = await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
log_event("agent.stream.llm_response.received") # [3^]
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
else:
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
async for chunk in interface.process(
stream,
ttft_span=request_span,
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
):
# Measure time to first token
if first_chunk and request_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = agent_state.llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
if include_return_message_types is None or chunk.message_type in include_return_message_types:
# filter down returned data
yield f"data: {chunk.model_dump_json()}\n\n"
stream_end_time_ns = get_utc_timestamp_ns()
# update usage
usage.step_count += 1
usage.completion_tokens += interface.output_tokens
usage.prompt_tokens += interface.input_tokens
usage.total_tokens += interface.input_tokens + interface.output_tokens
MetricRegistry().message_output_tokens.record(
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# log LLM request time
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
MetricRegistry().llm_execution_time_ms_histogram.record(
llm_request_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Process resulting stream content
step_progression = StepProgression.START
should_continue = False
try:
tool_call = interface.get_tool_call_object()
except ValueError as e:
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
yield f"data: {stop_reason.model_dump_json()}\n\n"
raise e
except Exception as e:
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
yield f"data: {stop_reason.model_dump_json()}\n\n"
raise e
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
UsageStatistics(
completion_tokens=interface.output_tokens,
prompt_tokens=interface.input_tokens,
total_tokens=interface.input_tokens + interface.output_tokens,
),
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_message_id,
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
(
request_data,
stream,
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
provider_request_start_timestamp_ns,
) = await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
initial_messages = None
step_progression = StepProgression.STREAM_RECEIVED
log_event("agent.stream.llm_response.received") # [3^]
# log total step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
else:
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
async for chunk in interface.process(
stream,
ttft_span=request_span,
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
):
# Measure time to first token
if first_chunk and request_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
metric_attributes = get_ctx_attributes()
metric_attributes["model.name"] = agent_state.llm_config.model
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
first_chunk = False
# Log LLM Trace
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
},
if include_return_message_types is None or chunk.message_type in include_return_message_types:
# filter down returned data
yield f"data: {chunk.model_dump_json()}\n\n"
stream_end_time_ns = get_utc_timestamp_ns()
# update usage
usage.step_count += 1
usage.completion_tokens += interface.output_tokens
usage.prompt_tokens += interface.input_tokens
usage.total_tokens += interface.input_tokens + interface.output_tokens
MetricRegistry().message_output_tokens.record(
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
)
# log LLM request time
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
MetricRegistry().llm_execution_time_ms_histogram.record(
llm_request_ms,
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Process resulting stream content
try:
tool_call = interface.get_tool_call_object()
except ValueError as e:
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
raise e
except Exception as e:
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
raise e
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
UsageStatistics(
completion_tokens=interface.output_tokens,
prompt_tokens=interface.input_tokens,
total_tokens=interface.input_tokens + interface.output_tokens,
),
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_message_id,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
step_progression = StepProgression.STEP_LOGGED
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
if not (use_assistant_message and tool_return.name == "send_message"):
# Apply message type filtering if specified
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
# TODO (cliandy): consolidate and expand with trace
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
initial_messages = None
# log total step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
agent_step_span.end()
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
# Log LLM Trace
# We are piecing together the streamed response here.
# Content here does not match the actual response schema as streams come in chunks.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": interface.input_tokens,
"output_tokens": interface.output_tokens,
},
},
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# yields tool response as this is handled from Letta and not the response from the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
if not (use_assistant_message and tool_return.name == "send_message"):
# Apply message type filtering if specified
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
# TODO (cliandy): consolidate and expand with trace
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
# Send error stop reason to client and re-raise with expected response code
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
raise
# Update step if it needs to be updated
finally:
if settings.track_stop_reason:
self.logger.info("Running final update. Step Progression: %s", step_progression)
try:
if step_progression < StepProgression.STEP_LOGGED:
await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
provider_id=None,
job_id=self.current_run_id if self.current_run_id else None,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
if step_progression <= StepProgression.STREAM_RECEIVED:
if first_chunk and settings.track_errored_messages:
for message in initial_messages:
message.is_err = True
message.step_id = step_id
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
self.logger.error("Error in step after logging step")
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
elif step_progression == StepProgression.FINISHED and not should_continue:
if stop_reason is None:
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
else:
self.logger.error("Invalid StepProgression value")
except Exception as e:
self.logger.error("Failed to update step: %s", e)
if not should_continue:
break
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
await self._rebuild_context_window(
@@ -1106,6 +1303,7 @@ class LettaAgent(BaseAgent):
job_id=run_id if run_id else self.current_run_id,
step_id=step_id,
project_id=agent_state.project_id,
stop_reason=stop_reason,
)
tool_call_messages = create_letta_messages_from_llm_response(

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Set, Union
from typing import List, Optional, Union
from pydantic import BaseModel, Field
@@ -107,25 +107,20 @@ class ToolRulesSolver(BaseModel):
self.tool_call_history.clear()
def get_allowed_tool_names(
self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None
self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None
) -> List[str]:
"""Get a list of tool names allowed based on the last tool called."""
"""Get a list of tool names allowed based on the last tool called.
The logic is as follows:
1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
"""
# TODO: This piece of code here is quite ugly and deserves a refactor
# TODO: There's some weird logic encoded here:
# TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules
# TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
# If no tool has been called yet, return InitToolRules additively
if not self.tool_call_history:
if self.init_tool_rules:
# If there are init tool rules, only return those defined in the init tool rules
return [rule.tool_name for rule in self.init_tool_rules]
else:
# Otherwise, return all tools besides those constrained by parent tool rules
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
return list(available_tools)
if not self.tool_call_history and self.init_tool_rules:
return [rule.tool_name for rule in self.init_tool_rules]
else:
# Collect valid tools from all child-based rules
valid_tool_sets = []
for rule in self.child_based_tool_rules + self.parent_tool_rules:
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
@@ -151,11 +146,11 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def has_required_tools_been_called(self, available_tools: Set[str]) -> bool:
def has_required_tools_been_called(self, available_tools: set[str]) -> bool:
"""Check if all required-before-exit tools have been called."""
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]:
def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]:
"""Get the list of required-before-exit tools that have not been called yet."""
if not self.required_before_exit_tool_rules:
return [] # No required tools means no uncalled tools

View File

@@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
nullable=True,
doc="The id of the LLMBatchItem that this message is associated with",
)
is_err: Mapped[Optional[bool]] = mapped_column(
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
)
# Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column(

View File

@@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.step import Step as PydanticStep
if TYPE_CHECKING:
@@ -45,6 +46,7 @@ class Step(SqlalchemyBase):
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")

View File

@@ -40,15 +40,18 @@ class LettaMessage(BaseModel):
message_type (MessageType): The type of the message
otid (Optional[str]): The offline threading id associated with this message
sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
step_id (Optional[str]): The step id associated with the message
is_err (Optional[bool]): Whether the message is an errored message or not. Used for debugging purposes only.
"""
id: str
date: datetime
name: Optional[str] = None
name: str | None = None
message_type: MessageType = Field(..., description="The type of the message.")
otid: Optional[str] = None
sender_id: Optional[str] = None
step_id: Optional[str] = None
otid: str | None = None
sender_id: str | None = None
step_id: str | None = None
is_err: bool | None = None
@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):
@@ -60,6 +63,14 @@ class LettaMessage(BaseModel):
dt = dt.replace(tzinfo=timezone.utc)
return dt.isoformat(timespec="seconds")
@field_serializer("is_err", when_used="unless-none")
def serialize_is_err(self, value: bool | None, _info):
"""
Only serialize is_err field when it's True (for debugging purposes).
When is_err is None or False, this field will be excluded from the JSON output.
"""
return value if value is True else None
class SystemMessage(LettaMessage):
"""

View File

@@ -172,6 +172,9 @@ class Message(BaseMessage):
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
is_err: Optional[bool] = Field(
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
)
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -191,6 +194,7 @@ class Message(BaseMessage):
if not is_utc_datetime(self.created_at):
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
json_message["created_at"] = self.created_at.isoformat()
json_message.pop("is_err", None) # make sure we don't include this debugging information
return json_message
@staticmethod
@@ -204,6 +208,7 @@ class Message(BaseMessage):
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
reverse: bool = True,
include_err: Optional[bool] = None,
) -> List[LettaMessage]:
if use_assistant_message:
message_ids_to_remove = []
@@ -234,6 +239,7 @@ class Message(BaseMessage):
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse,
include_err=include_err,
)
]
@@ -243,6 +249,7 @@ class Message(BaseMessage):
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
reverse: bool = True,
include_err: Optional[bool] = None,
) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API"""
messages = []
@@ -682,14 +689,13 @@ class Message(BaseMessage):
# since the only "parts" we have are for supporting various COT
if self.role == "system":
assert all([v is not None for v in [self.role]]), vars(self)
openai_message = {
"content": text_content,
"role": "developer" if use_developer_message else self.role,
}
elif self.role == "user":
assert all([v is not None for v in [text_content, self.role]]), vars(self)
assert text_content is not None, vars(self)
openai_message = {
"content": text_content,
"role": self.role,
@@ -720,7 +726,7 @@ class Message(BaseMessage):
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
elif self.role == "tool":
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
assert self.tool_call_id is not None, vars(self)
openai_message = {
"content": text_content,
"role": self.role,
@@ -776,7 +782,7 @@ class Message(BaseMessage):
if self.role == "system":
# NOTE: this is not for system instructions, but instead system "events"
assert all([v is not None for v in [text_content, self.role]]), vars(self)
assert text_content is not None, vars(self)
# Two options here, we would use system.package_system_message,
# or use a more Anthropic-specific packaging ie xml tags
user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event")
@@ -875,7 +881,7 @@ class Message(BaseMessage):
elif self.role == "tool":
# NOTE: Anthropic uses role "user" for "tool" responses
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
assert self.tool_call_id is not None, vars(self)
anthropic_message = {
"role": "user", # NOTE: diff
"content": [
@@ -988,7 +994,7 @@ class Message(BaseMessage):
elif self.role == "tool":
# NOTE: Significantly different tool calling format, more similar to function calling format
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
assert self.tool_call_id is not None, vars(self)
if self.name is None:
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")

View File

@@ -1,8 +1,10 @@
from enum import Enum, auto
from typing import Dict, List, Literal, Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message
@@ -28,6 +30,7 @@ class Step(StepBase):
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.")
tags: List[str] = Field([], description="Metadata tags.")
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
@@ -36,3 +39,12 @@ class Step(StepBase):
None, description="The feedback for this step. Must be either 'positive' or 'negative'."
)
project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).")
class StepProgression(int, Enum):
    """Ordered checkpoints reached while processing a single agent step.

    Because members are ints, callers can compare how far a step got
    (e.g. ``progression < StepProgression.STEP_LOGGED``) to decide which
    cleanup or bookkeeping still needs to run after a failure.
    """

    START = 1
    STREAM_RECEIVED = 2
    RESPONSE_RECEIVED = 3
    STEP_LOGGED = 4
    LOGGED_TRACE = 5
    FINISHED = 6

View File

@@ -2,8 +2,10 @@ import importlib.util
import json
import logging
import os
import platform
import sys
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from typing import Optional
@@ -34,32 +36,25 @@ from letta.server.db import db_registry
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
from letta.server.rest_api.static_files import mount_static_files
from letta.server.rest_api.utils import SENTRY_ENABLED
from letta.server.server import SyncServer
from letta.settings import settings
# TODO(ethan)
if SENTRY_ENABLED:
import sentry_sdk
IS_WINDOWS = platform.system() == "Windows"
# NOTE(charles): @ethan I had to add this to get the global at the bottom to work
interface: StreamingServerInterface = StreamingServerInterface
interface: type = StreamingServerInterface
server = SyncServer(default_interface_factory=lambda: interface())
logger = get_logger(__name__)
import logging
import platform
from fastapi import FastAPI
is_windows = platform.system() == "Windows"
log = logging.getLogger("uvicorn")
def generate_openapi_schema(app: FastAPI):
# Update the OpenAPI schema
if not app.openapi_schema:
@@ -177,9 +172,7 @@ def create_application() -> "FastAPI":
# server = SyncServer(default_interface_factory=lambda: interface())
print(f"\n[[ Letta server // v{letta_version} ]]")
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk
if SENTRY_ENABLED:
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
traces_sample_rate=1.0,
@@ -187,6 +180,7 @@ def create_application() -> "FastAPI":
"continuous_profiling_auto_start": True,
},
)
logger.info("Sentry enabled.")
debug_mode = "--debug" in sys.argv
app = FastAPI(
@@ -199,31 +193,13 @@ def create_application() -> "FastAPI":
lifespan=lifespan,
)
@app.exception_handler(IncompatibleAgentType)
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
return JSONResponse(
status_code=400,
content={
"detail": str(exc),
"expected_type": exc.expected_type,
"actual_type": exc.actual_type,
},
)
# === Exception Handlers ===
# TODO (cliandy): move to separate file
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
# Log the actual error for debugging
log.error(f"Unhandled error: {str(exc)}", exc_info=True)
print(f"Unhandled error: {str(exc)}")
import traceback
# Print the stack trace
print(f"Stack trace: {traceback.format_exc()}")
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk
logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse(
@@ -235,62 +211,70 @@ def create_application() -> "FastAPI":
},
)
@app.exception_handler(NoResultFound)
async def no_result_found_handler(request: Request, exc: NoResultFound):
logger.error(f"NoResultFound: {exc}")
async def error_handler_with_code(request: Request, exc: Exception, code: int, detail: str | None = None):
logger.error(f"{type(exc).__name__}", exc_info=exc)
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
if not detail:
detail = str(exc)
return JSONResponse(
status_code=404,
content={"detail": str(exc)},
status_code=code,
content={"detail": detail},
)
@app.exception_handler(ForeignKeyConstraintViolationError)
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
logger.error(f"ForeignKeyConstraintViolationError: {exc}")
_error_handler_400 = partial(error_handler_with_code, code=400)
_error_handler_404 = partial(error_handler_with_code, code=404)
_error_handler_404_agent = partial(_error_handler_404, detail="Agent not found")
_error_handler_404_user = partial(_error_handler_404, detail="User not found")
_error_handler_409 = partial(error_handler_with_code, code=409)
app.add_exception_handler(ValueError, _error_handler_400)
app.add_exception_handler(NoResultFound, _error_handler_404)
app.add_exception_handler(LettaAgentNotFoundError, _error_handler_404_agent)
app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user)
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
@app.exception_handler(IncompatibleAgentType)
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
logger.error("Incompatible agent types. Expected: %s, Actual: %s", exc.expected_type, exc.actual_type)
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse(
status_code=409,
content={"detail": str(exc)},
)
@app.exception_handler(UniqueConstraintViolationError)
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
logger.error(f"UniqueConstraintViolationError: {exc}")
return JSONResponse(
status_code=409,
content={"detail": str(exc)},
status_code=400,
content={
"detail": str(exc),
"expected_type": exc.expected_type,
"actual_type": exc.actual_type,
},
)
@app.exception_handler(DatabaseTimeoutError)
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse(
status_code=503,
content={"detail": "The database is temporarily unavailable. Please try again later."},
)
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
return JSONResponse(status_code=400, content={"detail": str(exc)})
@app.exception_handler(LettaAgentNotFoundError)
async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError):
return JSONResponse(status_code=404, content={"detail": "Agent not found"})
@app.exception_handler(LettaUserNotFoundError)
async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
return JSONResponse(status_code=404, content={"detail": "User not found"})
@app.exception_handler(BedrockPermissionError)
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
logger.error(f"Bedrock permission denied.")
if SENTRY_ENABLED:
sentry_sdk.capture_exception(exc)
return JSONResponse(
status_code=403,
content={
"error": {
"type": "bedrock_permission_denied",
"message": "Unable to access the required AI model. Please check your Bedrock permissions or contact support.",
"details": {"model_arn": exc.model_arn, "reason": str(exc)},
"detail": {str(exc)},
}
},
)
@@ -301,6 +285,9 @@ def create_application() -> "FastAPI":
print(f"▶ Using secure mode with password: {random_password}")
app.add_middleware(CheckPasswordMiddleware)
# Add reverse proxy middleware to handle X-Forwarded-* headers
# app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
@@ -442,7 +429,7 @@ def start_server(
)
else:
if is_windows:
if IS_WINDOWS:
# Windows doesn't show the fancy unicode characters
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")

View File

@@ -636,6 +636,9 @@ async def list_messages(
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
include_err: bool | None = Query(
None, description="Whether to include error messages and error statuses. For debugging purposes only."
),
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
@@ -654,6 +657,7 @@ async def list_messages(
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
include_err=include_err,
actor=actor,
)
@@ -1156,7 +1160,7 @@ async def list_agent_groups(
):
"""Lists the groups for an agent"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
print("in list agents with manager_type", manager_type)
logger.info("in list agents with manager_type", manager_type)
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)

View File

@@ -12,6 +12,7 @@ from starlette.types import Send
from letta.log import get_logger
from letta.schemas.enums import JobStatus
from letta.schemas.user import User
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.job_manager import JobManager
logger = get_logger(__name__)
@@ -92,6 +93,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
more_body = True
try:
first_chunk = await self.body_iterator.__anext__()
logger.debug("stream_response first chunk:", first_chunk)
if isinstance(first_chunk, tuple):
first_chunk_content, self.status_code = first_chunk
else:
@@ -130,7 +132,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"more_body": more_body,
}
)
return
raise Exception(f"An exception occurred mid-stream with status code {status_code}", detail={"content": content})
else:
content = chunk
@@ -146,8 +148,8 @@ class StreamingResponseWithStatusCode(StreamingResponse):
)
# This should be handled properly upstream?
except asyncio.CancelledError:
logger.info("Stream was cancelled by client or job cancellation")
except asyncio.CancelledError as exc:
logger.warning("Stream was cancelled by client or job cancellation")
# Handle cancellation gracefully
more_body = False
cancellation_resp = {"error": {"message": "Stream cancelled"}}
@@ -160,6 +162,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"headers": self.raw_headers,
}
)
raise
await send(
{
"type": "http.response.body",
@@ -167,13 +170,15 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"more_body": more_body,
}
)
capture_sentry_exception(exc)
return
except Exception:
logger.exception("unhandled_streaming_error")
except Exception as exc:
logger.exception("Unhandled Streaming Error")
more_body = False
error_resp = {"error": {"message": "Internal Server Error"}}
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
logger.debug("response_started:", self.response_started)
if not self.response_started:
await send(
{
@@ -182,6 +187,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"headers": self.raw_headers,
}
)
raise
await send(
{
"type": "http.response.body",
@@ -189,5 +195,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
"more_body": more_body,
}
)
capture_sentry_exception(exc)
return
if more_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})

View File

@@ -2,7 +2,6 @@ import asyncio
import json
import os
import uuid
import warnings
from enum import Enum
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
@@ -34,12 +33,15 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.interface import StreamingServerInterface
from letta.system import get_heartbeat, package_function_response
if TYPE_CHECKING:
from letta.server.server import SyncServer
SENTRY_ENABLED = bool(os.getenv("SENTRY_DSN"))
if SENTRY_ENABLED:
import sentry_sdk
SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n"
@@ -157,21 +159,9 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
return user_id
def get_current_interface() -> StreamingServerInterface:
return StreamingServerInterface
def log_error_to_sentry(e):
import traceback
traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")
# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk
def capture_sentry_exception(e: BaseException) -> None:
    """Manually report *e* to Sentry (no-op when Sentry is not configured).

    Exceptions raised while a streaming response is in flight never reach
    FastAPI's exception handlers upstack — the response may already be a 200 —
    so they must be captured explicitly here.
    """
    if not SENTRY_ENABLED:
        return
    sentry_sdk.capture_exception(e)

View File

@@ -105,6 +105,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings, tool_settings
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.utils import get_friendly_error_msg, get_persona_text, make_key
config = LettaConfig.load()
@@ -176,7 +177,7 @@ class SyncServer(Server):
self,
chaining: bool = True,
max_chaining_steps: Optional[int] = 100,
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(),
init_with_default_org_and_user: bool = True,
# default_interface: AgentInterface = CLIInterface(),
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
@@ -1244,6 +1245,7 @@ class SyncServer(Server):
use_assistant_message: bool = True,
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
include_err: Optional[bool] = None,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages_for_agent_async(
agent_id=agent_id,
@@ -1253,6 +1255,7 @@ class SyncServer(Server):
limit=limit,
ascending=not reverse,
group_id=group_id,
include_err=include_err,
)
if not return_message_object:
@@ -1262,6 +1265,7 @@ class SyncServer(Server):
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse,
include_err=include_err,
)
if reverse:

View File

@@ -520,6 +520,7 @@ class MessageManager:
limit: Optional[int] = 50,
ascending: bool = True,
group_id: Optional[str] = None,
include_err: Optional[bool] = None,
) -> List[PydanticMessage]:
"""
Most performant query to list messages for an agent by directly querying the Message table.
@@ -539,6 +540,7 @@ class MessageManager:
limit: Maximum number of messages to return.
ascending: If True, sort by sequence_id ascending; if False, sort descending.
group_id: Optional group ID to filter messages by group_id.
include_err: Optional boolean to include errors and error statuses. Used for debugging only.
Returns:
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
@@ -558,6 +560,9 @@ class MessageManager:
if group_id:
query = query.where(MessageModel.group_id == group_id)
if not include_err:
query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))
# If query_text is provided, filter messages using database-specific JSON search.
if query_text:
if settings.letta_pg_uri_no_default:

View File

@@ -12,6 +12,7 @@ from letta.orm.job import Job as JobModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.otel.tracing import get_trace_id, trace_method
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
@@ -131,6 +132,7 @@ class StepManager:
job_id: Optional[str] = None,
step_id: Optional[str] = None,
project_id: Optional[str] = None,
stop_reason: Optional[LettaStopReason] = None,
) -> PydanticStep:
step_data = {
"origin": None,
@@ -153,6 +155,8 @@ class StepManager:
}
if step_id:
step_data["id"] = step_id
if stop_reason:
step_data["stop_reason"] = stop_reason.stop_reason
async with db_registry.async_session() as session:
if job_id:
await self._verify_job_access_async(session, job_id, actor, access=["write"])
@@ -207,6 +211,33 @@ class StepManager:
await session.commit()
return step.to_pydantic()
@enforce_types
@trace_method
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
    """Update the stop reason for a step.

    Args:
        actor: The user making the request
        step_id: The ID of the step to update
        stop_reason: The stop reason to set

    Returns:
        The updated step as a Pydantic model

    Raises:
        NoResultFound: If the step does not exist
        Exception: If the step belongs to a different organization than the actor's
    """
    async with db_registry.async_session() as session:
        step = await session.get(StepModel, step_id)
        if not step:
            raise NoResultFound(f"Step with id {step_id} does not exist")
        # Scope access to the actor's organization; kept as a bare Exception
        # so existing callers catching it continue to work.
        if step.organization_id != actor.organization_id:
            raise Exception("Unauthorized")
        step.stop_reason = stop_reason
        await session.commit()
        # BUG FIX: previously returned the raw ORM StepModel, contradicting the
        # declared PydanticStep return type and the sibling methods' convention
        # (cf. `return step.to_pydantic()` elsewhere in this manager).
        return step.to_pydantic()
def _verify_job_access(
self,
session: Session,
@@ -309,5 +340,6 @@ class NoopStepManager(StepManager):
job_id: Optional[str] = None,
step_id: Optional[str] = None,
project_id: Optional[str] = None,
stop_reason: Optional[LettaStopReason] = None,
) -> PydanticStep:
return

View File

@@ -220,13 +220,15 @@ class Settings(BaseSettings):
multi_agent_concurrent_sends: int = 50
# telemetry logging
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
otel_preferred_temporality: Optional[int] = Field(
otel_exporter_otlp_endpoint: str | None = None # otel default: "http://localhost:4317"
otel_preferred_temporality: int | None = Field(
default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}"
)
disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing")
llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step")
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
# uvicorn settings
uvicorn_workers: int = 1

View File

@@ -752,14 +752,15 @@ def test_step_stream_agent_loop_error(
"""
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state_no_tools.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
with pytest.raises(Exception) as exc_info:
response = client.agents.messages.create_stream(
agent_id=agent_state_no_tools.id,
messages=USER_MESSAGE_FORCE_REPLY,
)
list(response)
for chunk in response:
print(chunk)
print("error info:", exc_info)
assert type(exc_info.value) in (ApiError, ValueError)
print(exc_info.value)
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
assert len(messages_from_db) == 0

View File

@@ -138,7 +138,9 @@ def test_max_count_per_step_tool_rule():
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"
solver.register_tool_call(START_TOOL)
assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed"
assert (
solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == []
), "After reaching max count, 'start_tool' should no longer be allowed"
def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
@@ -155,7 +157,7 @@ def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"
solver.register_tool_call(START_TOOL)
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses"
assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should no longer allow 'start_tool' after 3 uses"
def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
@@ -180,7 +182,7 @@ def test_max_count_per_step_tool_rule_resets_on_clear():
solver.register_tool_call(START_TOOL)
solver.register_tool_call(START_TOOL)
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit"
assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should not allow 'start_tool' after reaching limit"
solver.clear_tool_history()