feat: stop reasons and error messages and sentry fixes
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
"""add stop reasons to steps and message error flag
|
||||
|
||||
Revision ID: cce9a6174366
|
||||
Revises: 2c059cad97cc
|
||||
Create Date: 2025-07-10 13:56:17.383612
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "cce9a6174366"
|
||||
down_revision: Union[str, None] = "2c059cad97cc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the migration: add ``messages.is_err`` and ``steps.stop_reason``.

    Adds a nullable boolean error flag to the ``messages`` table and a
    nullable ``stop_reason`` enum column to the ``steps`` table, creating
    the ``stopreasontype`` enum type explicitly beforehand.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("messages", sa.Column("is_err", sa.Boolean(), nullable=True))

    # manually added to handle non-table creation enums
    stop_reason_enum = sa.Enum(
        "end_turn",
        "error",
        "invalid_tool_call",
        "max_steps",
        "no_tool_call",
        "tool_rule",
        "cancelled",
        name="stopreasontype",
    )
    # Create the enum type up front so the column definition below can
    # reference an already-existing type (add_column does not emit CREATE TYPE).
    stop_reason_enum.create(op.get_bind())
    op.add_column("steps", sa.Column("stop_reason", stop_reason_enum, nullable=True))
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration: drop ``steps.stop_reason``, ``messages.is_err``, and the enum type.

    Columns are dropped first so the ``stopreasontype`` enum type is no
    longer referenced when it is removed.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("steps", "stop_reason")
    op.drop_column("messages", "is_err")

    # Drop the enum type only after the column that referenced it is gone.
    stop_reason_enum = sa.Enum(name="stopreasontype")
    stop_reason_enum.drop(op.get_bind())
    # ### end Alembic commands ###
|
||||
@@ -43,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.schemas.step import StepProgression
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -238,100 +239,164 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
agent_step_span,
|
||||
)
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
||||
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# update usage
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
|
||||
if not response.choices[0].message.tool_calls:
|
||||
# TODO: make into a real error
|
||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
if response.choices[0].message.reasoning_content:
|
||||
reasoning = [
|
||||
ReasoningContent(
|
||||
reasoning=response.choices[0].message.reasoning_content,
|
||||
is_native=True,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
step_progression = StepProgression.START
|
||||
should_continue = False
|
||||
try:
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
agent_step_span,
|
||||
)
|
||||
]
|
||||
elif response.choices[0].message.omitted_reasoning_content:
|
||||
reasoning = [OmittedReasoningContent()]
|
||||
elif response.choices[0].message.content:
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
else:
|
||||
self.logger.info("No reasoning content found.")
|
||||
reasoning = None
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
step_id=step_id,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
step_progression = StepProgression.RESPONSE_RECEIVED
|
||||
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
||||
|
||||
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
initial_messages = None
|
||||
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# log step time
|
||||
now = get_utc_timestamp_ns()
|
||||
step_ns = now - step_start
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
# update usage
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
if not response.choices[0].message.tool_calls:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
if response.choices[0].message.reasoning_content:
|
||||
reasoning = [
|
||||
ReasoningContent(
|
||||
reasoning=response.choices[0].message.reasoning_content,
|
||||
is_native=True,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
]
|
||||
elif response.choices[0].message.omitted_reasoning_content:
|
||||
reasoning = [OmittedReasoningContent()]
|
||||
elif response.choices[0].message.content:
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
else:
|
||||
self.logger.info("No reasoning content found.")
|
||||
reasoning = None
|
||||
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
step_progression = StepProgression.STEP_LOGGED
|
||||
|
||||
# stream step
|
||||
# TODO: improve TTFT
|
||||
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
)
|
||||
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
initial_messages = None
|
||||
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
|
||||
|
||||
for message in letta_messages:
|
||||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||||
yield f"data: {message.model_dump_json()}\n\n"
|
||||
# log step time
|
||||
now = get_utc_timestamp_ns()
|
||||
step_ns = now - step_start
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
# stream step
|
||||
# TODO: improve TTFT
|
||||
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
)
|
||||
|
||||
for message in letta_messages:
|
||||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||||
yield f"data: {message.model_dump_json()}\n\n"
|
||||
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
step_progression = StepProgression.FINISHED
|
||||
except Exception as e:
|
||||
# Handle any unexpected errors during step processing
|
||||
self.logger.error(f"Error during step processing: {e}")
|
||||
|
||||
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||||
if not stop_reason:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||||
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||||
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
|
||||
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
|
||||
|
||||
# Send error stop reason to client and re-raise
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
|
||||
raise
|
||||
|
||||
# Update step if it needs to be updated
|
||||
finally:
|
||||
if settings.track_stop_reason:
|
||||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||
try:
|
||||
if step_progression < StepProgression.STEP_LOGGED:
|
||||
await self.step_manager.log_step_async(
|
||||
actor=self.actor,
|
||||
agent_id=agent_state.id,
|
||||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||
provider_category=agent_state.llm_config.provider_category or "base",
|
||||
model=agent_state.llm_config.model,
|
||||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||||
provider_id=None,
|
||||
job_id=self.current_run_id if self.current_run_id else None,
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
if step_progression <= StepProgression.RESPONSE_RECEIVED:
|
||||
# TODO (cliandy): persist response if we get it back
|
||||
if settings.track_errored_messages:
|
||||
for message in initial_messages:
|
||||
message.is_err = True
|
||||
message.step_id = step_id
|
||||
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
self.logger.error("Error in step after logging step")
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||
elif step_progression == StepProgression.FINISHED and not should_continue:
|
||||
if stop_reason is None:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||
else:
|
||||
self.logger.error("Invalid StepProgression value")
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to update step: %s", e)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
@@ -396,6 +461,16 @@ class LettaAgent(BaseAgent):
|
||||
stop_reason = None
|
||||
usage = LettaUsageStatistics()
|
||||
for i in range(max_steps):
|
||||
# If dry run, build request data and return it without making LLM call
|
||||
if dry_run:
|
||||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
)
|
||||
return request_data
|
||||
|
||||
# Check for job cancellation at the start of each step
|
||||
if await self._check_run_cancellation():
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
@@ -407,94 +482,148 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
# If dry run, build request data and return it without making LLM call
|
||||
if dry_run:
|
||||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
)
|
||||
return request_data
|
||||
step_progression = StepProgression.START
|
||||
should_continue = False
|
||||
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
|
||||
)
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
log_event("agent.step.llm_response.received") # [3^]
|
||||
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
usage.run_ids = [run_id] if run_id else None
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
|
||||
if not response.choices[0].message.tool_calls:
|
||||
# TODO: make into a real error
|
||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
if response.choices[0].message.reasoning_content:
|
||||
reasoning = [
|
||||
ReasoningContent(
|
||||
reasoning=response.choices[0].message.reasoning_content,
|
||||
is_native=True,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
try:
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
|
||||
)
|
||||
]
|
||||
elif response.choices[0].message.content:
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
elif response.choices[0].message.omitted_reasoning_content:
|
||||
reasoning = [OmittedReasoningContent()]
|
||||
else:
|
||||
self.logger.info("No reasoning content found.")
|
||||
reasoning = None
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
step_id=step_id,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
run_id=run_id,
|
||||
)
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
step_progression = StepProgression.RESPONSE_RECEIVED
|
||||
log_event("agent.step.llm_response.received") # [3^]
|
||||
|
||||
initial_messages = None
|
||||
log_event("agent.step.llm_response.processed") # [4^]
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# log step time
|
||||
now = get_utc_timestamp_ns()
|
||||
step_ns = now - step_start
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
usage.run_ids = [run_id] if run_id else None
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
if not response.choices[0].message.tool_calls:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
if response.choices[0].message.reasoning_content:
|
||||
reasoning = [
|
||||
ReasoningContent(
|
||||
reasoning=response.choices[0].message.reasoning_content,
|
||||
is_native=True,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
]
|
||||
elif response.choices[0].message.content:
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
elif response.choices[0].message.omitted_reasoning_content:
|
||||
reasoning = [OmittedReasoningContent()]
|
||||
else:
|
||||
self.logger.info("No reasoning content found.")
|
||||
reasoning = None
|
||||
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
run_id=run_id,
|
||||
)
|
||||
step_progression = StepProgression.STEP_LOGGED
|
||||
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
|
||||
initial_messages = None
|
||||
log_event("agent.step.llm_response.processed") # [4^]
|
||||
|
||||
# log step time
|
||||
now = get_utc_timestamp_ns()
|
||||
step_ns = now - step_start
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
step_progression = StepProgression.FINISHED
|
||||
|
||||
except Exception as e:
|
||||
# Handle any unexpected errors during step processing
|
||||
self.logger.error(f"Error during step processing: {e}")
|
||||
|
||||
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||||
if not stop_reason:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||||
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||||
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
|
||||
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
|
||||
raise
|
||||
|
||||
# Update step if it needs to be updated
|
||||
finally:
|
||||
if settings.track_stop_reason:
|
||||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||
try:
|
||||
if step_progression < StepProgression.STEP_LOGGED:
|
||||
await self.step_manager.log_step_async(
|
||||
actor=self.actor,
|
||||
agent_id=agent_state.id,
|
||||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||
provider_category=agent_state.llm_config.provider_category or "base",
|
||||
model=agent_state.llm_config.model,
|
||||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||||
provider_id=None,
|
||||
job_id=self.current_run_id if self.current_run_id else None,
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
if step_progression <= StepProgression.RESPONSE_RECEIVED:
|
||||
# TODO (cliandy): persist response if we get it back
|
||||
if settings.track_errored_messages:
|
||||
for message in initial_messages:
|
||||
message.is_err = True
|
||||
message.step_id = step_id
|
||||
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
self.logger.error("Error in step after logging step")
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||
elif step_progression == StepProgression.FINISHED and not should_continue:
|
||||
if stop_reason is None:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||
else:
|
||||
self.logger.error("Invalid StepProgression value")
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to update step: %s", e)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
@@ -576,6 +705,7 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
for i in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
# Check for job cancellation at the start of each step
|
||||
if await self._check_run_cancellation():
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
@@ -583,163 +713,230 @@ class LettaAgent(BaseAgent):
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||||
break
|
||||
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
(
|
||||
request_data,
|
||||
stream,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
valid_tool_names,
|
||||
provider_request_start_timestamp_ns,
|
||||
) = await self._build_and_request_from_llm_streaming(
|
||||
first_chunk,
|
||||
agent_step_span,
|
||||
request_start_timestamp_ns,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
)
|
||||
log_event("agent.stream.llm_response.received") # [3^]
|
||||
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
|
||||
interface = OpenAIStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
||||
|
||||
async for chunk in interface.process(
|
||||
stream,
|
||||
ttft_span=request_span,
|
||||
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
|
||||
):
|
||||
# Measure time to first token
|
||||
if first_chunk and request_span is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - request_start_timestamp_ns
|
||||
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
|
||||
metric_attributes = get_ctx_attributes()
|
||||
metric_attributes["model.name"] = agent_state.llm_config.model
|
||||
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
|
||||
first_chunk = False
|
||||
|
||||
if include_return_message_types is None or chunk.message_type in include_return_message_types:
|
||||
# filter down returned data
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
stream_end_time_ns = get_utc_timestamp_ns()
|
||||
|
||||
# update usage
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += interface.output_tokens
|
||||
usage.prompt_tokens += interface.input_tokens
|
||||
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
|
||||
# log LLM request time
|
||||
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
|
||||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
|
||||
MetricRegistry().llm_execution_time_ms_histogram.record(
|
||||
llm_request_ms,
|
||||
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
|
||||
)
|
||||
|
||||
# Process resulting stream content
|
||||
step_progression = StepProgression.START
|
||||
should_continue = False
|
||||
try:
|
||||
tool_call = interface.get_tool_call_object()
|
||||
except ValueError as e:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||||
raise e
|
||||
except Exception as e:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||||
raise e
|
||||
reasoning_content = interface.get_reasoning_content()
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
UsageStatistics(
|
||||
completion_tokens=interface.output_tokens,
|
||||
prompt_tokens=interface.input_tokens,
|
||||
total_tokens=interface.input_tokens + interface.output_tokens,
|
||||
),
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=interface.letta_message_id,
|
||||
step_id=step_id,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
(
|
||||
request_data,
|
||||
stream,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
valid_tool_names,
|
||||
provider_request_start_timestamp_ns,
|
||||
) = await self._build_and_request_from_llm_streaming(
|
||||
first_chunk,
|
||||
agent_step_span,
|
||||
request_start_timestamp_ns,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
)
|
||||
|
||||
initial_messages = None
|
||||
step_progression = StepProgression.STREAM_RECEIVED
|
||||
log_event("agent.stream.llm_response.received") # [3^]
|
||||
|
||||
# log total step time
|
||||
now = get_utc_timestamp_ns()
|
||||
step_ns = now - step_start
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
|
||||
interface = OpenAIStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
||||
|
||||
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
||||
# log_event("agent.stream.llm_response.processed") # [4^]
|
||||
async for chunk in interface.process(
|
||||
stream,
|
||||
ttft_span=request_span,
|
||||
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
|
||||
):
|
||||
# Measure time to first token
|
||||
if first_chunk and request_span is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - request_start_timestamp_ns
|
||||
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
|
||||
metric_attributes = get_ctx_attributes()
|
||||
metric_attributes["model.name"] = agent_state.llm_config.model
|
||||
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
|
||||
first_chunk = False
|
||||
|
||||
# Log LLM Trace
|
||||
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json={
|
||||
"content": {
|
||||
"tool_call": tool_call.model_dump_json(),
|
||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||
},
|
||||
"id": interface.message_id,
|
||||
"model": interface.model,
|
||||
"role": "assistant",
|
||||
# "stop_reason": "",
|
||||
# "stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
|
||||
},
|
||||
if include_return_message_types is None or chunk.message_type in include_return_message_types:
|
||||
# filter down returned data
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
stream_end_time_ns = get_utc_timestamp_ns()
|
||||
|
||||
# update usage
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += interface.output_tokens
|
||||
usage.prompt_tokens += interface.input_tokens
|
||||
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
||||
MetricRegistry().message_output_tokens.record(
|
||||
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||
)
|
||||
|
||||
# log LLM request time
|
||||
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
|
||||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
|
||||
MetricRegistry().llm_execution_time_ms_histogram.record(
|
||||
llm_request_ms,
|
||||
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
|
||||
)
|
||||
|
||||
# Process resulting stream content
|
||||
try:
|
||||
tool_call = interface.get_tool_call_object()
|
||||
except ValueError as e:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||
raise e
|
||||
except Exception as e:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||||
raise e
|
||||
reasoning_content = interface.get_reasoning_content()
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
UsageStatistics(
|
||||
completion_tokens=interface.output_tokens,
|
||||
prompt_tokens=interface.input_tokens,
|
||||
total_tokens=interface.input_tokens + interface.output_tokens,
|
||||
),
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=interface.letta_message_id,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
step_progression = StepProgression.STEP_LOGGED
|
||||
|
||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||
if not (use_assistant_message and tool_return.name == "send_message"):
|
||||
# Apply message type filtering if specified
|
||||
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
|
||||
yield f"data: {tool_return.model_dump_json()}\n\n"
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
|
||||
# TODO (cliandy): consolidate and expand with trace
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
initial_messages = None
|
||||
|
||||
# log total step time
|
||||
now = get_utc_timestamp_ns()
|
||||
step_ns = now - step_start
|
||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||
agent_step_span.end()
|
||||
|
||||
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
||||
# log_event("agent.stream.llm_response.processed") # [4^]
|
||||
|
||||
# Log LLM Trace
|
||||
# We are piecing together the streamed response here.
|
||||
# Content here does not match the actual response schema as streams come in chunks.
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json={
|
||||
"content": {
|
||||
"tool_call": tool_call.model_dump_json(),
|
||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||
},
|
||||
"id": interface.message_id,
|
||||
"model": interface.model,
|
||||
"role": "assistant",
|
||||
# "stop_reason": "",
|
||||
# "stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": interface.input_tokens,
|
||||
"output_tokens": interface.output_tokens,
|
||||
},
|
||||
},
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||
if not (use_assistant_message and tool_return.name == "send_message"):
|
||||
# Apply message type filtering if specified
|
||||
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
|
||||
yield f"data: {tool_return.model_dump_json()}\n\n"
|
||||
|
||||
# TODO (cliandy): consolidate and expand with trace
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
step_progression = StepProgression.FINISHED
|
||||
|
||||
except Exception as e:
|
||||
# Handle any unexpected errors during step processing
|
||||
self.logger.error(f"Error during step processing: {e}")
|
||||
|
||||
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||||
if not stop_reason:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||||
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||||
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
|
||||
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
|
||||
|
||||
# Send error stop reason to client and re-raise with expected response code
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
|
||||
raise
|
||||
|
||||
# Update step if it needs to be updated
|
||||
finally:
|
||||
if settings.track_stop_reason:
|
||||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||
try:
|
||||
if step_progression < StepProgression.STEP_LOGGED:
|
||||
await self.step_manager.log_step_async(
|
||||
actor=self.actor,
|
||||
agent_id=agent_state.id,
|
||||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||
provider_category=agent_state.llm_config.provider_category or "base",
|
||||
model=agent_state.llm_config.model,
|
||||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||||
provider_id=None,
|
||||
job_id=self.current_run_id if self.current_run_id else None,
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
if step_progression <= StepProgression.STREAM_RECEIVED:
|
||||
if first_chunk and settings.track_errored_messages:
|
||||
for message in initial_messages:
|
||||
message.is_err = True
|
||||
message.step_id = step_id
|
||||
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
self.logger.error("Error in step after logging step")
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||
elif step_progression == StepProgression.FINISHED and not should_continue:
|
||||
if stop_reason is None:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||
else:
|
||||
self.logger.error("Invalid StepProgression value")
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to update step: %s", e)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
await self._rebuild_context_window(
|
||||
@@ -1106,6 +1303,7 @@ class LettaAgent(BaseAgent):
|
||||
job_id=run_id if run_id else self.current_run_id,
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Set, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -107,25 +107,20 @@ class ToolRulesSolver(BaseModel):
|
||||
self.tool_call_history.clear()
|
||||
|
||||
def get_allowed_tool_names(
|
||||
self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None
|
||||
self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None
|
||||
) -> List[str]:
|
||||
"""Get a list of tool names allowed based on the last tool called."""
|
||||
"""Get a list of tool names allowed based on the last tool called.
|
||||
|
||||
The logic is as follows:
|
||||
1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call
|
||||
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
|
||||
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
|
||||
"""
|
||||
# TODO: This piece of code here is quite ugly and deserves a refactor
|
||||
# TODO: There's some weird logic encoded here:
|
||||
# TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules
|
||||
# TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based
|
||||
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
|
||||
# If no tool has been called yet, return InitToolRules additively
|
||||
if not self.tool_call_history:
|
||||
if self.init_tool_rules:
|
||||
# If there are init tool rules, only return those defined in the init tool rules
|
||||
return [rule.tool_name for rule in self.init_tool_rules]
|
||||
else:
|
||||
# Otherwise, return all tools besides those constrained by parent tool rules
|
||||
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
|
||||
return list(available_tools)
|
||||
if not self.tool_call_history and self.init_tool_rules:
|
||||
return [rule.tool_name for rule in self.init_tool_rules]
|
||||
else:
|
||||
# Collect valid tools from all child-based rules
|
||||
valid_tool_sets = []
|
||||
for rule in self.child_based_tool_rules + self.parent_tool_rules:
|
||||
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||
@@ -151,11 +146,11 @@ class ToolRulesSolver(BaseModel):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||
|
||||
def has_required_tools_been_called(self, available_tools: Set[str]) -> bool:
|
||||
def has_required_tools_been_called(self, available_tools: set[str]) -> bool:
|
||||
"""Check if all required-before-exit tools have been called."""
|
||||
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
||||
|
||||
def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]:
|
||||
def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]:
|
||||
"""Get the list of required-before-exit tools that have not been called yet."""
|
||||
if not self.required_before_exit_tool_rules:
|
||||
return [] # No required tools means no uncalled tools
|
||||
|
||||
@@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
nullable=True,
|
||||
doc="The id of the LLMBatchItem that this message is associated with",
|
||||
)
|
||||
is_err: Mapped[Optional[bool]] = mapped_column(
|
||||
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -45,6 +46,7 @@ class Step(SqlalchemyBase):
|
||||
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
|
||||
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
|
||||
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
||||
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
|
||||
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
|
||||
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
|
||||
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")
|
||||
|
||||
@@ -40,15 +40,18 @@ class LettaMessage(BaseModel):
|
||||
message_type (MessageType): The type of the message
|
||||
otid (Optional[str]): The offline threading id associated with this message
|
||||
sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
|
||||
step_id (Optional[str]): The step id associated with the message
|
||||
is_err (Optional[bool]): Whether the message is an errored message or not. Used for debugging purposes only.
|
||||
"""
|
||||
|
||||
id: str
|
||||
date: datetime
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
message_type: MessageType = Field(..., description="The type of the message.")
|
||||
otid: Optional[str] = None
|
||||
sender_id: Optional[str] = None
|
||||
step_id: Optional[str] = None
|
||||
otid: str | None = None
|
||||
sender_id: str | None = None
|
||||
step_id: str | None = None
|
||||
is_err: bool | None = None
|
||||
|
||||
@field_serializer("date")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
@@ -60,6 +63,14 @@ class LettaMessage(BaseModel):
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
@field_serializer("is_err", when_used="unless-none")
|
||||
def serialize_is_err(self, value: bool | None, _info):
|
||||
"""
|
||||
Only serialize is_err field when it's True (for debugging purposes).
|
||||
When is_err is None or False, this field will be excluded from the JSON output.
|
||||
"""
|
||||
return value if value is True else None
|
||||
|
||||
|
||||
class SystemMessage(LettaMessage):
|
||||
"""
|
||||
|
||||
@@ -172,6 +172,9 @@ class Message(BaseMessage):
|
||||
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
|
||||
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
|
||||
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
|
||||
is_err: Optional[bool] = Field(
|
||||
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@@ -191,6 +194,7 @@ class Message(BaseMessage):
|
||||
if not is_utc_datetime(self.created_at):
|
||||
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
|
||||
json_message["created_at"] = self.created_at.isoformat()
|
||||
json_message.pop("is_err", None) # make sure we don't include this debugging information
|
||||
return json_message
|
||||
|
||||
@staticmethod
|
||||
@@ -204,6 +208,7 @@ class Message(BaseMessage):
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
reverse: bool = True,
|
||||
include_err: Optional[bool] = None,
|
||||
) -> List[LettaMessage]:
|
||||
if use_assistant_message:
|
||||
message_ids_to_remove = []
|
||||
@@ -234,6 +239,7 @@ class Message(BaseMessage):
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
reverse=reverse,
|
||||
include_err=include_err,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -243,6 +249,7 @@ class Message(BaseMessage):
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
reverse: bool = True,
|
||||
include_err: Optional[bool] = None,
|
||||
) -> List[LettaMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||
messages = []
|
||||
@@ -682,14 +689,13 @@ class Message(BaseMessage):
|
||||
# since the only "parts" we have are for supporting various COT
|
||||
|
||||
if self.role == "system":
|
||||
assert all([v is not None for v in [self.role]]), vars(self)
|
||||
openai_message = {
|
||||
"content": text_content,
|
||||
"role": "developer" if use_developer_message else self.role,
|
||||
}
|
||||
|
||||
elif self.role == "user":
|
||||
assert all([v is not None for v in [text_content, self.role]]), vars(self)
|
||||
assert text_content is not None, vars(self)
|
||||
openai_message = {
|
||||
"content": text_content,
|
||||
"role": self.role,
|
||||
@@ -720,7 +726,7 @@ class Message(BaseMessage):
|
||||
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
|
||||
|
||||
elif self.role == "tool":
|
||||
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
|
||||
assert self.tool_call_id is not None, vars(self)
|
||||
openai_message = {
|
||||
"content": text_content,
|
||||
"role": self.role,
|
||||
@@ -776,7 +782,7 @@ class Message(BaseMessage):
|
||||
if self.role == "system":
|
||||
# NOTE: this is not for system instructions, but instead system "events"
|
||||
|
||||
assert all([v is not None for v in [text_content, self.role]]), vars(self)
|
||||
assert text_content is not None, vars(self)
|
||||
# Two options here, we would use system.package_system_message,
|
||||
# or use a more Anthropic-specific packaging ie xml tags
|
||||
user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event")
|
||||
@@ -875,7 +881,7 @@ class Message(BaseMessage):
|
||||
|
||||
elif self.role == "tool":
|
||||
# NOTE: Anthropic uses role "user" for "tool" responses
|
||||
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
|
||||
assert self.tool_call_id is not None, vars(self)
|
||||
anthropic_message = {
|
||||
"role": "user", # NOTE: diff
|
||||
"content": [
|
||||
@@ -988,7 +994,7 @@ class Message(BaseMessage):
|
||||
|
||||
elif self.role == "tool":
|
||||
# NOTE: Significantly different tool calling format, more similar to function calling format
|
||||
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
|
||||
assert self.tool_call_id is not None, vars(self)
|
||||
|
||||
if self.name is None:
|
||||
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
@@ -28,6 +30,7 @@ class Step(StepBase):
|
||||
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
|
||||
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
|
||||
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
|
||||
stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
|
||||
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
|
||||
@@ -36,3 +39,12 @@ class Step(StepBase):
|
||||
None, description="The feedback for this step. Must be either 'positive' or 'negative'."
|
||||
)
|
||||
project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).")
|
||||
|
||||
|
||||
class StepProgression(int, Enum):
|
||||
START = auto()
|
||||
STREAM_RECEIVED = auto()
|
||||
RESPONSE_RECEIVED = auto()
|
||||
STEP_LOGGED = auto()
|
||||
LOGGED_TRACE = auto()
|
||||
FINISHED = auto()
|
||||
|
||||
@@ -2,8 +2,10 @@ import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -34,32 +36,25 @@ from letta.server.db import db_registry
|
||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
|
||||
|
||||
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
||||
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
|
||||
from letta.server.rest_api.static_files import mount_static_files
|
||||
from letta.server.rest_api.utils import SENTRY_ENABLED
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
# TODO(ethan)
|
||||
if SENTRY_ENABLED:
|
||||
import sentry_sdk
|
||||
|
||||
IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
|
||||
interface: StreamingServerInterface = StreamingServerInterface
|
||||
interface: type = StreamingServerInterface
|
||||
server = SyncServer(default_interface_factory=lambda: interface())
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
import logging
|
||||
import platform
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
is_windows = platform.system() == "Windows"
|
||||
|
||||
log = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def generate_openapi_schema(app: FastAPI):
|
||||
# Update the OpenAPI schema
|
||||
if not app.openapi_schema:
|
||||
@@ -177,9 +172,7 @@ def create_application() -> "FastAPI":
|
||||
# server = SyncServer(default_interface_factory=lambda: interface())
|
||||
print(f"\n[[ Letta server // v{letta_version} ]]")
|
||||
|
||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
||||
import sentry_sdk
|
||||
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.init(
|
||||
dsn=os.getenv("SENTRY_DSN"),
|
||||
traces_sample_rate=1.0,
|
||||
@@ -187,6 +180,7 @@ def create_application() -> "FastAPI":
|
||||
"continuous_profiling_auto_start": True,
|
||||
},
|
||||
)
|
||||
logger.info("Sentry enabled.")
|
||||
|
||||
debug_mode = "--debug" in sys.argv
|
||||
app = FastAPI(
|
||||
@@ -199,31 +193,13 @@ def create_application() -> "FastAPI":
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": str(exc),
|
||||
"expected_type": exc.expected_type,
|
||||
"actual_type": exc.actual_type,
|
||||
},
|
||||
)
|
||||
# === Exception Handlers ===
|
||||
# TODO (cliandy): move to separate file
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
# Log the actual error for debugging
|
||||
log.error(f"Unhandled error: {str(exc)}", exc_info=True)
|
||||
print(f"Unhandled error: {str(exc)}")
|
||||
|
||||
import traceback
|
||||
|
||||
# Print the stack trace
|
||||
print(f"Stack trace: {traceback.format_exc()}")
|
||||
|
||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
||||
import sentry_sdk
|
||||
|
||||
logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -235,62 +211,70 @@ def create_application() -> "FastAPI":
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(NoResultFound)
|
||||
async def no_result_found_handler(request: Request, exc: NoResultFound):
|
||||
logger.error(f"NoResultFound: {exc}")
|
||||
async def error_handler_with_code(request: Request, exc: Exception, code: int, detail: str | None = None):
|
||||
logger.error(f"{type(exc).__name__}", exc_info=exc)
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
if not detail:
|
||||
detail = str(exc)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": str(exc)},
|
||||
status_code=code,
|
||||
content={"detail": detail},
|
||||
)
|
||||
|
||||
@app.exception_handler(ForeignKeyConstraintViolationError)
|
||||
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
|
||||
logger.error(f"ForeignKeyConstraintViolationError: {exc}")
|
||||
_error_handler_400 = partial(error_handler_with_code, code=400)
|
||||
_error_handler_404 = partial(error_handler_with_code, code=404)
|
||||
_error_handler_404_agent = partial(_error_handler_404, detail="Agent not found")
|
||||
_error_handler_404_user = partial(_error_handler_404, detail="User not found")
|
||||
_error_handler_409 = partial(error_handler_with_code, code=409)
|
||||
|
||||
app.add_exception_handler(ValueError, _error_handler_400)
|
||||
app.add_exception_handler(NoResultFound, _error_handler_404)
|
||||
app.add_exception_handler(LettaAgentNotFoundError, _error_handler_404_agent)
|
||||
app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user)
|
||||
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
|
||||
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||
logger.error("Incompatible agent types. Expected: %s, Actual: %s", exc.expected_type, exc.actual_type)
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
|
||||
@app.exception_handler(UniqueConstraintViolationError)
|
||||
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
|
||||
logger.error(f"UniqueConstraintViolationError: {exc}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"detail": str(exc)},
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": str(exc),
|
||||
"expected_type": exc.expected_type,
|
||||
"actual_type": exc.actual_type,
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(DatabaseTimeoutError)
|
||||
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
|
||||
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"detail": "The database is temporarily unavailable. Please try again later."},
|
||||
)
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def value_error_handler(request: Request, exc: ValueError):
|
||||
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
@app.exception_handler(LettaAgentNotFoundError)
|
||||
async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError):
|
||||
return JSONResponse(status_code=404, content={"detail": "Agent not found"})
|
||||
|
||||
@app.exception_handler(LettaUserNotFoundError)
|
||||
async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
|
||||
return JSONResponse(status_code=404, content={"detail": "User not found"})
|
||||
|
||||
@app.exception_handler(BedrockPermissionError)
|
||||
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
|
||||
logger.error(f"Bedrock permission denied.")
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": {
|
||||
"type": "bedrock_permission_denied",
|
||||
"message": "Unable to access the required AI model. Please check your Bedrock permissions or contact support.",
|
||||
"details": {"model_arn": exc.model_arn, "reason": str(exc)},
|
||||
"detail": {str(exc)},
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -301,6 +285,9 @@ def create_application() -> "FastAPI":
|
||||
print(f"▶ Using secure mode with password: {random_password}")
|
||||
app.add_middleware(CheckPasswordMiddleware)
|
||||
|
||||
# Add reverse proxy middleware to handle X-Forwarded-* headers
|
||||
# app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
@@ -442,7 +429,7 @@ def start_server(
|
||||
)
|
||||
|
||||
else:
|
||||
if is_windows:
|
||||
if IS_WINDOWS:
|
||||
# Windows doesn't those the fancy unicode characters
|
||||
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||
|
||||
@@ -636,6 +636,9 @@ async def list_messages(
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||
include_err: bool | None = Query(
|
||||
None, description="Whether to include error messages and error statuses. For debugging purposes only."
|
||||
),
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
@@ -654,6 +657,7 @@ async def list_messages(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
include_err=include_err,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -1156,7 +1160,7 @@ async def list_agent_groups(
|
||||
):
|
||||
"""Lists the groups for an agent"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
print("in list agents with manager_type", manager_type)
|
||||
logger.info("in list agents with manager_type", manager_type)
|
||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from starlette.types import Send
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import capture_sentry_exception
|
||||
from letta.services.job_manager import JobManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -92,6 +93,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
more_body = True
|
||||
try:
|
||||
first_chunk = await self.body_iterator.__anext__()
|
||||
logger.debug("stream_response first chunk:", first_chunk)
|
||||
if isinstance(first_chunk, tuple):
|
||||
first_chunk_content, self.status_code = first_chunk
|
||||
else:
|
||||
@@ -130,7 +132,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
return
|
||||
raise Exception(f"An exception occurred mid-stream with status code {status_code}", detail={"content": content})
|
||||
else:
|
||||
content = chunk
|
||||
|
||||
@@ -146,8 +148,8 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
)
|
||||
|
||||
# This should be handled properly upstream?
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Stream was cancelled by client or job cancellation")
|
||||
except asyncio.CancelledError as exc:
|
||||
logger.warning("Stream was cancelled by client or job cancellation")
|
||||
# Handle cancellation gracefully
|
||||
more_body = False
|
||||
cancellation_resp = {"error": {"message": "Stream cancelled"}}
|
||||
@@ -160,6 +162,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
raise
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
@@ -167,13 +170,15 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
capture_sentry_exception(exc)
|
||||
return
|
||||
|
||||
except Exception:
|
||||
logger.exception("unhandled_streaming_error")
|
||||
except Exception as exc:
|
||||
logger.exception("Unhandled Streaming Error")
|
||||
more_body = False
|
||||
error_resp = {"error": {"message": "Internal Server Error"}}
|
||||
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
|
||||
logger.debug("response_started:", self.response_started)
|
||||
if not self.response_started:
|
||||
await send(
|
||||
{
|
||||
@@ -182,6 +187,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
raise
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
@@ -189,5 +195,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
capture_sentry_exception(exc)
|
||||
return
|
||||
if more_body:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
|
||||
|
||||
@@ -34,12 +33,15 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.system import get_heartbeat, package_function_response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
SENTRY_ENABLED = bool(os.getenv("SENTRY_DSN"))
|
||||
|
||||
if SENTRY_ENABLED:
|
||||
import sentry_sdk
|
||||
|
||||
SSE_PREFIX = "data: "
|
||||
SSE_SUFFIX = "\n\n"
|
||||
@@ -157,21 +159,9 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
|
||||
return user_id
|
||||
|
||||
|
||||
def get_current_interface() -> StreamingServerInterface:
|
||||
return StreamingServerInterface
|
||||
|
||||
|
||||
def log_error_to_sentry(e):
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"SSE stream generator failed: {e}")
|
||||
|
||||
# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
|
||||
# Print the stack trace
|
||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
||||
import sentry_sdk
|
||||
|
||||
def capture_sentry_exception(e: BaseException):
|
||||
"""This will capture the exception in sentry, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response"""
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
||||
|
||||
|
||||
@@ -105,6 +105,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.utils import get_friendly_error_msg, get_persona_text, make_key
|
||||
|
||||
config = LettaConfig.load()
|
||||
@@ -176,7 +177,7 @@ class SyncServer(Server):
|
||||
self,
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = 100,
|
||||
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
|
||||
default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(),
|
||||
init_with_default_org_and_user: bool = True,
|
||||
# default_interface: AgentInterface = CLIInterface(),
|
||||
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
|
||||
@@ -1244,6 +1245,7 @@ class SyncServer(Server):
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
include_err: Optional[bool] = None,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=agent_id,
|
||||
@@ -1253,6 +1255,7 @@ class SyncServer(Server):
|
||||
limit=limit,
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
include_err=include_err,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
@@ -1262,6 +1265,7 @@ class SyncServer(Server):
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
reverse=reverse,
|
||||
include_err=include_err,
|
||||
)
|
||||
|
||||
if reverse:
|
||||
|
||||
@@ -520,6 +520,7 @@ class MessageManager:
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
group_id: Optional[str] = None,
|
||||
include_err: Optional[bool] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Most performant query to list messages for an agent by directly querying the Message table.
|
||||
@@ -539,6 +540,7 @@ class MessageManager:
|
||||
limit: Maximum number of messages to return.
|
||||
ascending: If True, sort by sequence_id ascending; if False, sort descending.
|
||||
group_id: Optional group ID to filter messages by group_id.
|
||||
include_err: Optional boolean to include errors and error statuses. Used for debugging only.
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||
@@ -558,6 +560,9 @@ class MessageManager:
|
||||
if group_id:
|
||||
query = query.where(MessageModel.group_id == group_id)
|
||||
|
||||
if not include_err:
|
||||
query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))
|
||||
|
||||
# If query_text is provided, filter messages using database-specific JSON search.
|
||||
if query_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.orm.job import Job as JobModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.otel.tracing import get_trace_id, trace_method
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -131,6 +132,7 @@ class StepManager:
|
||||
job_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
stop_reason: Optional[LettaStopReason] = None,
|
||||
) -> PydanticStep:
|
||||
step_data = {
|
||||
"origin": None,
|
||||
@@ -153,6 +155,8 @@ class StepManager:
|
||||
}
|
||||
if step_id:
|
||||
step_data["id"] = step_id
|
||||
if stop_reason:
|
||||
step_data["stop_reason"] = stop_reason.stop_reason
|
||||
async with db_registry.async_session() as session:
|
||||
if job_id:
|
||||
await self._verify_job_access_async(session, job_id, actor, access=["write"])
|
||||
@@ -207,6 +211,33 @@ class StepManager:
|
||||
await session.commit()
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
|
||||
"""Update the stop reason for a step.
|
||||
|
||||
Args:
|
||||
actor: The user making the request
|
||||
step_id: The ID of the step to update
|
||||
stop_reason: The stop reason to set
|
||||
|
||||
Returns:
|
||||
The updated step
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the step does not exist
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
step = await session.get(StepModel, step_id)
|
||||
if not step:
|
||||
raise NoResultFound(f"Step with id {step_id} does not exist")
|
||||
if step.organization_id != actor.organization_id:
|
||||
raise Exception("Unauthorized")
|
||||
|
||||
step.stop_reason = stop_reason
|
||||
await session.commit()
|
||||
return step
|
||||
|
||||
def _verify_job_access(
|
||||
self,
|
||||
session: Session,
|
||||
@@ -309,5 +340,6 @@ class NoopStepManager(StepManager):
|
||||
job_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
stop_reason: Optional[LettaStopReason] = None,
|
||||
) -> PydanticStep:
|
||||
return
|
||||
|
||||
@@ -220,13 +220,15 @@ class Settings(BaseSettings):
|
||||
multi_agent_concurrent_sends: int = 50
|
||||
|
||||
# telemetry logging
|
||||
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
|
||||
otel_preferred_temporality: Optional[int] = Field(
|
||||
otel_exporter_otlp_endpoint: str | None = None # otel default: "http://localhost:4317"
|
||||
otel_preferred_temporality: int | None = Field(
|
||||
default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}"
|
||||
)
|
||||
disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing")
|
||||
llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step")
|
||||
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
|
||||
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
|
||||
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
|
||||
|
||||
# uvicorn settings
|
||||
uvicorn_workers: int = 1
|
||||
|
||||
@@ -752,14 +752,15 @@ def test_step_stream_agent_loop_error(
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
|
||||
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state_no_tools.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state_no_tools.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
list(response)
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
print("error info:", exc_info)
|
||||
assert type(exc_info.value) in (ApiError, ValueError)
|
||||
print(exc_info.value)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
|
||||
|
||||
@@ -138,7 +138,9 @@ def test_max_count_per_step_tool_rule():
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"
|
||||
|
||||
solver.register_tool_call(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed"
|
||||
assert (
|
||||
solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == []
|
||||
), "After reaching max count, 'start_tool' should no longer be allowed"
|
||||
|
||||
|
||||
def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
|
||||
@@ -155,7 +157,7 @@ def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"
|
||||
|
||||
solver.register_tool_call(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses"
|
||||
assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should no longer allow 'start_tool' after 3 uses"
|
||||
|
||||
|
||||
def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
|
||||
@@ -180,7 +182,7 @@ def test_max_count_per_step_tool_rule_resets_on_clear():
|
||||
solver.register_tool_call(START_TOOL)
|
||||
solver.register_tool_call(START_TOOL)
|
||||
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit"
|
||||
assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should not allow 'start_tool' after reaching limit"
|
||||
|
||||
solver.clear_tool_history()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user