feat: stop reasons and error messages and sentry fixes
This commit is contained in:
@@ -0,0 +1,42 @@
|
|||||||
|
"""add stop reasons to steps and message error flag
|
||||||
|
|
||||||
|
Revision ID: cce9a6174366
|
||||||
|
Revises: 2c059cad97cc
|
||||||
|
Create Date: 2025-07-10 13:56:17.383612
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "cce9a6174366"
|
||||||
|
down_revision: Union[str, None] = "2c059cad97cc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column("messages", sa.Column("is_err", sa.Boolean(), nullable=True))
|
||||||
|
|
||||||
|
# manually added to handle non-table creation enums
|
||||||
|
stopreasontype = sa.Enum(
|
||||||
|
"end_turn", "error", "invalid_tool_call", "max_steps", "no_tool_call", "tool_rule", "cancelled", name="stopreasontype"
|
||||||
|
)
|
||||||
|
stopreasontype.create(op.get_bind())
|
||||||
|
op.add_column("steps", sa.Column("stop_reason", stopreasontype, nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("steps", "stop_reason")
|
||||||
|
op.drop_column("messages", "is_err")
|
||||||
|
|
||||||
|
stopreasontype = sa.Enum(name="stopreasontype")
|
||||||
|
stopreasontype.drop(op.get_bind())
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -43,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig
|
|||||||
from letta.schemas.message import Message, MessageCreate
|
from letta.schemas.message import Message, MessageCreate
|
||||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||||
|
from letta.schemas.step import StepProgression
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
@@ -238,100 +239,164 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||||
agent_step_span.set_attributes({"step_id": step_id})
|
agent_step_span.set_attributes({"step_id": step_id})
|
||||||
|
|
||||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
step_progression = StepProgression.START
|
||||||
await self._build_and_request_from_llm(
|
should_continue = False
|
||||||
current_in_context_messages,
|
try:
|
||||||
new_in_context_messages,
|
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||||
agent_state,
|
await self._build_and_request_from_llm(
|
||||||
llm_client,
|
current_in_context_messages,
|
||||||
tool_rules_solver,
|
new_in_context_messages,
|
||||||
agent_step_span,
|
agent_state,
|
||||||
)
|
llm_client,
|
||||||
)
|
tool_rules_solver,
|
||||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
agent_step_span,
|
||||||
|
|
||||||
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
|
||||||
|
|
||||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
|
||||||
|
|
||||||
# update usage
|
|
||||||
usage.step_count += 1
|
|
||||||
usage.completion_tokens += response.usage.completion_tokens
|
|
||||||
usage.prompt_tokens += response.usage.prompt_tokens
|
|
||||||
usage.total_tokens += response.usage.total_tokens
|
|
||||||
MetricRegistry().message_output_tokens.record(
|
|
||||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
|
||||||
)
|
|
||||||
|
|
||||||
if not response.choices[0].message.tool_calls:
|
|
||||||
# TODO: make into a real error
|
|
||||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
|
||||||
tool_call = response.choices[0].message.tool_calls[0]
|
|
||||||
if response.choices[0].message.reasoning_content:
|
|
||||||
reasoning = [
|
|
||||||
ReasoningContent(
|
|
||||||
reasoning=response.choices[0].message.reasoning_content,
|
|
||||||
is_native=True,
|
|
||||||
signature=response.choices[0].message.reasoning_content_signature,
|
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
elif response.choices[0].message.omitted_reasoning_content:
|
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||||
reasoning = [OmittedReasoningContent()]
|
|
||||||
elif response.choices[0].message.content:
|
|
||||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
|
||||||
else:
|
|
||||||
self.logger.info("No reasoning content found.")
|
|
||||||
reasoning = None
|
|
||||||
|
|
||||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
step_progression = StepProgression.RESPONSE_RECEIVED
|
||||||
tool_call,
|
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
||||||
valid_tool_names,
|
|
||||||
agent_state,
|
|
||||||
tool_rules_solver,
|
|
||||||
response.usage,
|
|
||||||
reasoning_content=reasoning,
|
|
||||||
step_id=step_id,
|
|
||||||
initial_messages=initial_messages,
|
|
||||||
agent_step_span=agent_step_span,
|
|
||||||
is_final_step=(i == max_steps - 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
|
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
|
||||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
|
||||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
|
||||||
initial_messages = None
|
|
||||||
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
|
|
||||||
|
|
||||||
# log step time
|
# update usage
|
||||||
now = get_utc_timestamp_ns()
|
usage.step_count += 1
|
||||||
step_ns = now - step_start
|
usage.completion_tokens += response.usage.completion_tokens
|
||||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
usage.prompt_tokens += response.usage.prompt_tokens
|
||||||
agent_step_span.end()
|
usage.total_tokens += response.usage.total_tokens
|
||||||
|
MetricRegistry().message_output_tokens.record(
|
||||||
|
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||||
|
)
|
||||||
|
|
||||||
# Log LLM Trace
|
if not response.choices[0].message.tool_calls:
|
||||||
await self.telemetry_manager.create_provider_trace_async(
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||||
actor=self.actor,
|
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||||
provider_trace_create=ProviderTraceCreate(
|
tool_call = response.choices[0].message.tool_calls[0]
|
||||||
request_json=request_data,
|
if response.choices[0].message.reasoning_content:
|
||||||
response_json=response_data,
|
reasoning = [
|
||||||
|
ReasoningContent(
|
||||||
|
reasoning=response.choices[0].message.reasoning_content,
|
||||||
|
is_native=True,
|
||||||
|
signature=response.choices[0].message.reasoning_content_signature,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif response.choices[0].message.omitted_reasoning_content:
|
||||||
|
reasoning = [OmittedReasoningContent()]
|
||||||
|
elif response.choices[0].message.content:
|
||||||
|
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||||
|
else:
|
||||||
|
self.logger.info("No reasoning content found.")
|
||||||
|
reasoning = None
|
||||||
|
|
||||||
|
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||||
|
tool_call,
|
||||||
|
valid_tool_names,
|
||||||
|
agent_state,
|
||||||
|
tool_rules_solver,
|
||||||
|
response.usage,
|
||||||
|
reasoning_content=reasoning,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
organization_id=self.actor.organization_id,
|
initial_messages=initial_messages,
|
||||||
),
|
agent_step_span=agent_step_span,
|
||||||
)
|
is_final_step=(i == max_steps - 1),
|
||||||
|
)
|
||||||
|
step_progression = StepProgression.STEP_LOGGED
|
||||||
|
|
||||||
# stream step
|
# TODO (cliandy): handle message contexts with larger refactor and dedupe logic
|
||||||
# TODO: improve TTFT
|
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||||
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||||
letta_messages = Message.to_letta_messages_from_list(
|
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
initial_messages = None
|
||||||
)
|
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
|
||||||
|
|
||||||
for message in letta_messages:
|
# log step time
|
||||||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
now = get_utc_timestamp_ns()
|
||||||
yield f"data: {message.model_dump_json()}\n\n"
|
step_ns = now - step_start
|
||||||
|
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||||
|
agent_step_span.end()
|
||||||
|
|
||||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
# Log LLM Trace
|
||||||
|
await self.telemetry_manager.create_provider_trace_async(
|
||||||
|
actor=self.actor,
|
||||||
|
provider_trace_create=ProviderTraceCreate(
|
||||||
|
request_json=request_data,
|
||||||
|
response_json=response_data,
|
||||||
|
step_id=step_id,
|
||||||
|
organization_id=self.actor.organization_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
step_progression = StepProgression.LOGGED_TRACE
|
||||||
|
|
||||||
|
# stream step
|
||||||
|
# TODO: improve TTFT
|
||||||
|
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
||||||
|
letta_messages = Message.to_letta_messages_from_list(
|
||||||
|
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||||
|
)
|
||||||
|
|
||||||
|
for message in letta_messages:
|
||||||
|
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||||||
|
yield f"data: {message.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||||
|
step_progression = StepProgression.FINISHED
|
||||||
|
except Exception as e:
|
||||||
|
# Handle any unexpected errors during step processing
|
||||||
|
self.logger.error(f"Error during step processing: {e}")
|
||||||
|
|
||||||
|
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||||||
|
if not stop_reason:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||||
|
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||||||
|
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||||||
|
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
|
||||||
|
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
|
||||||
|
|
||||||
|
# Send error stop reason to client and re-raise
|
||||||
|
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Update step if it needs to be updated
|
||||||
|
finally:
|
||||||
|
if settings.track_stop_reason:
|
||||||
|
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||||
|
try:
|
||||||
|
if step_progression < StepProgression.STEP_LOGGED:
|
||||||
|
await self.step_manager.log_step_async(
|
||||||
|
actor=self.actor,
|
||||||
|
agent_id=agent_state.id,
|
||||||
|
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||||
|
provider_category=agent_state.llm_config.provider_category or "base",
|
||||||
|
model=agent_state.llm_config.model,
|
||||||
|
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||||
|
context_window_limit=agent_state.llm_config.context_window,
|
||||||
|
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||||||
|
provider_id=None,
|
||||||
|
job_id=self.current_run_id if self.current_run_id else None,
|
||||||
|
step_id=step_id,
|
||||||
|
project_id=agent_state.project_id,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
if step_progression <= StepProgression.RESPONSE_RECEIVED:
|
||||||
|
# TODO (cliandy): persist response if we get it back
|
||||||
|
if settings.track_errored_messages:
|
||||||
|
for message in initial_messages:
|
||||||
|
message.is_err = True
|
||||||
|
message.step_id = step_id
|
||||||
|
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||||
|
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||||
|
if stop_reason is None:
|
||||||
|
self.logger.error("Error in step after logging step")
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||||
|
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||||
|
elif step_progression == StepProgression.FINISHED and not should_continue:
|
||||||
|
if stop_reason is None:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||||
|
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||||
|
else:
|
||||||
|
self.logger.error("Invalid StepProgression value")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Failed to update step: %s", e)
|
||||||
|
|
||||||
if not should_continue:
|
if not should_continue:
|
||||||
break
|
break
|
||||||
@@ -396,6 +461,16 @@ class LettaAgent(BaseAgent):
|
|||||||
stop_reason = None
|
stop_reason = None
|
||||||
usage = LettaUsageStatistics()
|
usage = LettaUsageStatistics()
|
||||||
for i in range(max_steps):
|
for i in range(max_steps):
|
||||||
|
# If dry run, build request data and return it without making LLM call
|
||||||
|
if dry_run:
|
||||||
|
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||||||
|
llm_client=llm_client,
|
||||||
|
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||||
|
agent_state=agent_state,
|
||||||
|
tool_rules_solver=tool_rules_solver,
|
||||||
|
)
|
||||||
|
return request_data
|
||||||
|
|
||||||
# Check for job cancellation at the start of each step
|
# Check for job cancellation at the start of each step
|
||||||
if await self._check_run_cancellation():
|
if await self._check_run_cancellation():
|
||||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||||
@@ -407,94 +482,148 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||||
agent_step_span.set_attributes({"step_id": step_id})
|
agent_step_span.set_attributes({"step_id": step_id})
|
||||||
|
|
||||||
# If dry run, build request data and return it without making LLM call
|
step_progression = StepProgression.START
|
||||||
if dry_run:
|
should_continue = False
|
||||||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
|
||||||
llm_client=llm_client,
|
|
||||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
|
||||||
agent_state=agent_state,
|
|
||||||
tool_rules_solver=tool_rules_solver,
|
|
||||||
)
|
|
||||||
return request_data
|
|
||||||
|
|
||||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
try:
|
||||||
await self._build_and_request_from_llm(
|
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||||
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
|
await self._build_and_request_from_llm(
|
||||||
)
|
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
|
||||||
)
|
|
||||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
|
||||||
|
|
||||||
log_event("agent.step.llm_response.received") # [3^]
|
|
||||||
|
|
||||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
|
||||||
|
|
||||||
usage.step_count += 1
|
|
||||||
usage.completion_tokens += response.usage.completion_tokens
|
|
||||||
usage.prompt_tokens += response.usage.prompt_tokens
|
|
||||||
usage.total_tokens += response.usage.total_tokens
|
|
||||||
usage.run_ids = [run_id] if run_id else None
|
|
||||||
MetricRegistry().message_output_tokens.record(
|
|
||||||
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
|
||||||
)
|
|
||||||
|
|
||||||
if not response.choices[0].message.tool_calls:
|
|
||||||
# TODO: make into a real error
|
|
||||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
|
||||||
tool_call = response.choices[0].message.tool_calls[0]
|
|
||||||
if response.choices[0].message.reasoning_content:
|
|
||||||
reasoning = [
|
|
||||||
ReasoningContent(
|
|
||||||
reasoning=response.choices[0].message.reasoning_content,
|
|
||||||
is_native=True,
|
|
||||||
signature=response.choices[0].message.reasoning_content_signature,
|
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
elif response.choices[0].message.content:
|
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
|
||||||
elif response.choices[0].message.omitted_reasoning_content:
|
|
||||||
reasoning = [OmittedReasoningContent()]
|
|
||||||
else:
|
|
||||||
self.logger.info("No reasoning content found.")
|
|
||||||
reasoning = None
|
|
||||||
|
|
||||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
step_progression = StepProgression.RESPONSE_RECEIVED
|
||||||
tool_call,
|
log_event("agent.step.llm_response.received") # [3^]
|
||||||
valid_tool_names,
|
|
||||||
agent_state,
|
|
||||||
tool_rules_solver,
|
|
||||||
response.usage,
|
|
||||||
reasoning_content=reasoning,
|
|
||||||
step_id=step_id,
|
|
||||||
initial_messages=initial_messages,
|
|
||||||
agent_step_span=agent_step_span,
|
|
||||||
is_final_step=(i == max_steps - 1),
|
|
||||||
run_id=run_id,
|
|
||||||
)
|
|
||||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
|
||||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
|
||||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
|
||||||
|
|
||||||
initial_messages = None
|
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||||
log_event("agent.step.llm_response.processed") # [4^]
|
|
||||||
|
|
||||||
# log step time
|
usage.step_count += 1
|
||||||
now = get_utc_timestamp_ns()
|
usage.completion_tokens += response.usage.completion_tokens
|
||||||
step_ns = now - step_start
|
usage.prompt_tokens += response.usage.prompt_tokens
|
||||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
usage.total_tokens += response.usage.total_tokens
|
||||||
agent_step_span.end()
|
usage.run_ids = [run_id] if run_id else None
|
||||||
|
MetricRegistry().message_output_tokens.record(
|
||||||
|
response.usage.completion_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||||
|
)
|
||||||
|
|
||||||
# Log LLM Trace
|
if not response.choices[0].message.tool_calls:
|
||||||
await self.telemetry_manager.create_provider_trace_async(
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||||
actor=self.actor,
|
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||||
provider_trace_create=ProviderTraceCreate(
|
tool_call = response.choices[0].message.tool_calls[0]
|
||||||
request_json=request_data,
|
if response.choices[0].message.reasoning_content:
|
||||||
response_json=response_data,
|
reasoning = [
|
||||||
|
ReasoningContent(
|
||||||
|
reasoning=response.choices[0].message.reasoning_content,
|
||||||
|
is_native=True,
|
||||||
|
signature=response.choices[0].message.reasoning_content_signature,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif response.choices[0].message.content:
|
||||||
|
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||||
|
elif response.choices[0].message.omitted_reasoning_content:
|
||||||
|
reasoning = [OmittedReasoningContent()]
|
||||||
|
else:
|
||||||
|
self.logger.info("No reasoning content found.")
|
||||||
|
reasoning = None
|
||||||
|
|
||||||
|
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||||
|
tool_call,
|
||||||
|
valid_tool_names,
|
||||||
|
agent_state,
|
||||||
|
tool_rules_solver,
|
||||||
|
response.usage,
|
||||||
|
reasoning_content=reasoning,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
organization_id=self.actor.organization_id,
|
initial_messages=initial_messages,
|
||||||
),
|
agent_step_span=agent_step_span,
|
||||||
)
|
is_final_step=(i == max_steps - 1),
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
step_progression = StepProgression.STEP_LOGGED
|
||||||
|
|
||||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||||
|
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||||
|
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||||
|
|
||||||
|
initial_messages = None
|
||||||
|
log_event("agent.step.llm_response.processed") # [4^]
|
||||||
|
|
||||||
|
# log step time
|
||||||
|
now = get_utc_timestamp_ns()
|
||||||
|
step_ns = now - step_start
|
||||||
|
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||||
|
agent_step_span.end()
|
||||||
|
|
||||||
|
# Log LLM Trace
|
||||||
|
await self.telemetry_manager.create_provider_trace_async(
|
||||||
|
actor=self.actor,
|
||||||
|
provider_trace_create=ProviderTraceCreate(
|
||||||
|
request_json=request_data,
|
||||||
|
response_json=response_data,
|
||||||
|
step_id=step_id,
|
||||||
|
organization_id=self.actor.organization_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
step_progression = StepProgression.LOGGED_TRACE
|
||||||
|
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||||
|
step_progression = StepProgression.FINISHED
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle any unexpected errors during step processing
|
||||||
|
self.logger.error(f"Error during step processing: {e}")
|
||||||
|
|
||||||
|
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||||||
|
if not stop_reason:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||||
|
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||||||
|
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||||||
|
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
|
||||||
|
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Update step if it needs to be updated
|
||||||
|
finally:
|
||||||
|
if settings.track_stop_reason:
|
||||||
|
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||||
|
try:
|
||||||
|
if step_progression < StepProgression.STEP_LOGGED:
|
||||||
|
await self.step_manager.log_step_async(
|
||||||
|
actor=self.actor,
|
||||||
|
agent_id=agent_state.id,
|
||||||
|
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||||
|
provider_category=agent_state.llm_config.provider_category or "base",
|
||||||
|
model=agent_state.llm_config.model,
|
||||||
|
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||||
|
context_window_limit=agent_state.llm_config.context_window,
|
||||||
|
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||||||
|
provider_id=None,
|
||||||
|
job_id=self.current_run_id if self.current_run_id else None,
|
||||||
|
step_id=step_id,
|
||||||
|
project_id=agent_state.project_id,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
if step_progression <= StepProgression.RESPONSE_RECEIVED:
|
||||||
|
# TODO (cliandy): persist response if we get it back
|
||||||
|
if settings.track_errored_messages:
|
||||||
|
for message in initial_messages:
|
||||||
|
message.is_err = True
|
||||||
|
message.step_id = step_id
|
||||||
|
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||||
|
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||||
|
if stop_reason is None:
|
||||||
|
self.logger.error("Error in step after logging step")
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||||
|
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||||
|
elif step_progression == StepProgression.FINISHED and not should_continue:
|
||||||
|
if stop_reason is None:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||||
|
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||||
|
else:
|
||||||
|
self.logger.error("Invalid StepProgression value")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Failed to update step: %s", e)
|
||||||
|
|
||||||
if not should_continue:
|
if not should_continue:
|
||||||
break
|
break
|
||||||
@@ -576,6 +705,7 @@ class LettaAgent(BaseAgent):
|
|||||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||||
|
|
||||||
for i in range(max_steps):
|
for i in range(max_steps):
|
||||||
|
step_id = generate_step_id()
|
||||||
# Check for job cancellation at the start of each step
|
# Check for job cancellation at the start of each step
|
||||||
if await self._check_run_cancellation():
|
if await self._check_run_cancellation():
|
||||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||||
@@ -583,163 +713,230 @@ class LettaAgent(BaseAgent):
|
|||||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||||||
break
|
break
|
||||||
|
|
||||||
step_id = generate_step_id()
|
|
||||||
step_start = get_utc_timestamp_ns()
|
step_start = get_utc_timestamp_ns()
|
||||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||||
agent_step_span.set_attributes({"step_id": step_id})
|
agent_step_span.set_attributes({"step_id": step_id})
|
||||||
|
|
||||||
(
|
step_progression = StepProgression.START
|
||||||
request_data,
|
should_continue = False
|
||||||
stream,
|
|
||||||
current_in_context_messages,
|
|
||||||
new_in_context_messages,
|
|
||||||
valid_tool_names,
|
|
||||||
provider_request_start_timestamp_ns,
|
|
||||||
) = await self._build_and_request_from_llm_streaming(
|
|
||||||
first_chunk,
|
|
||||||
agent_step_span,
|
|
||||||
request_start_timestamp_ns,
|
|
||||||
current_in_context_messages,
|
|
||||||
new_in_context_messages,
|
|
||||||
agent_state,
|
|
||||||
llm_client,
|
|
||||||
tool_rules_solver,
|
|
||||||
)
|
|
||||||
log_event("agent.stream.llm_response.received") # [3^]
|
|
||||||
|
|
||||||
# TODO: THIS IS INCREDIBLY UGLY
|
|
||||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
|
||||||
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
|
|
||||||
interface = AnthropicStreamingInterface(
|
|
||||||
use_assistant_message=use_assistant_message,
|
|
||||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
|
||||||
)
|
|
||||||
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
|
|
||||||
interface = OpenAIStreamingInterface(
|
|
||||||
use_assistant_message=use_assistant_message,
|
|
||||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
|
||||||
|
|
||||||
async for chunk in interface.process(
|
|
||||||
stream,
|
|
||||||
ttft_span=request_span,
|
|
||||||
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
|
|
||||||
):
|
|
||||||
# Measure time to first token
|
|
||||||
if first_chunk and request_span is not None:
|
|
||||||
now = get_utc_timestamp_ns()
|
|
||||||
ttft_ns = now - request_start_timestamp_ns
|
|
||||||
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
|
|
||||||
metric_attributes = get_ctx_attributes()
|
|
||||||
metric_attributes["model.name"] = agent_state.llm_config.model
|
|
||||||
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
|
|
||||||
first_chunk = False
|
|
||||||
|
|
||||||
if include_return_message_types is None or chunk.message_type in include_return_message_types:
|
|
||||||
# filter down returned data
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
stream_end_time_ns = get_utc_timestamp_ns()
|
|
||||||
|
|
||||||
# update usage
|
|
||||||
usage.step_count += 1
|
|
||||||
usage.completion_tokens += interface.output_tokens
|
|
||||||
usage.prompt_tokens += interface.input_tokens
|
|
||||||
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
|
||||||
MetricRegistry().message_output_tokens.record(
|
|
||||||
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
|
||||||
)
|
|
||||||
|
|
||||||
# log LLM request time
|
|
||||||
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
|
|
||||||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
|
|
||||||
MetricRegistry().llm_execution_time_ms_histogram.record(
|
|
||||||
llm_request_ms,
|
|
||||||
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process resulting stream content
|
|
||||||
try:
|
try:
|
||||||
tool_call = interface.get_tool_call_object()
|
(
|
||||||
except ValueError as e:
|
request_data,
|
||||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
stream,
|
||||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
current_in_context_messages,
|
||||||
raise e
|
new_in_context_messages,
|
||||||
except Exception as e:
|
valid_tool_names,
|
||||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
provider_request_start_timestamp_ns,
|
||||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
) = await self._build_and_request_from_llm_streaming(
|
||||||
raise e
|
first_chunk,
|
||||||
reasoning_content = interface.get_reasoning_content()
|
agent_step_span,
|
||||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
request_start_timestamp_ns,
|
||||||
tool_call,
|
current_in_context_messages,
|
||||||
valid_tool_names,
|
new_in_context_messages,
|
||||||
agent_state,
|
agent_state,
|
||||||
tool_rules_solver,
|
llm_client,
|
||||||
UsageStatistics(
|
tool_rules_solver,
|
||||||
completion_tokens=interface.output_tokens,
|
)
|
||||||
prompt_tokens=interface.input_tokens,
|
|
||||||
total_tokens=interface.input_tokens + interface.output_tokens,
|
|
||||||
),
|
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
pre_computed_assistant_message_id=interface.letta_message_id,
|
|
||||||
step_id=step_id,
|
|
||||||
initial_messages=initial_messages,
|
|
||||||
agent_step_span=agent_step_span,
|
|
||||||
is_final_step=(i == max_steps - 1),
|
|
||||||
)
|
|
||||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
|
||||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
|
||||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
|
||||||
|
|
||||||
initial_messages = None
|
step_progression = StepProgression.STREAM_RECEIVED
|
||||||
|
log_event("agent.stream.llm_response.received") # [3^]
|
||||||
|
|
||||||
# log total step time
|
# TODO: THIS IS INCREDIBLY UGLY
|
||||||
now = get_utc_timestamp_ns()
|
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||||
step_ns = now - step_start
|
if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
|
||||||
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
interface = AnthropicStreamingInterface(
|
||||||
agent_step_span.end()
|
use_assistant_message=use_assistant_message,
|
||||||
|
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||||
|
)
|
||||||
|
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
|
||||||
|
interface = OpenAIStreamingInterface(
|
||||||
|
use_assistant_message=use_assistant_message,
|
||||||
|
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
||||||
|
|
||||||
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
async for chunk in interface.process(
|
||||||
# log_event("agent.stream.llm_response.processed") # [4^]
|
stream,
|
||||||
|
ttft_span=request_span,
|
||||||
|
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
|
||||||
|
):
|
||||||
|
# Measure time to first token
|
||||||
|
if first_chunk and request_span is not None:
|
||||||
|
now = get_utc_timestamp_ns()
|
||||||
|
ttft_ns = now - request_start_timestamp_ns
|
||||||
|
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ns_to_ms(ttft_ns)})
|
||||||
|
metric_attributes = get_ctx_attributes()
|
||||||
|
metric_attributes["model.name"] = agent_state.llm_config.model
|
||||||
|
MetricRegistry().ttft_ms_histogram.record(ns_to_ms(ttft_ns), metric_attributes)
|
||||||
|
first_chunk = False
|
||||||
|
|
||||||
# Log LLM Trace
|
if include_return_message_types is None or chunk.message_type in include_return_message_types:
|
||||||
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
|
# filter down returned data
|
||||||
await self.telemetry_manager.create_provider_trace_async(
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
actor=self.actor,
|
|
||||||
provider_trace_create=ProviderTraceCreate(
|
stream_end_time_ns = get_utc_timestamp_ns()
|
||||||
request_json=request_data,
|
|
||||||
response_json={
|
# update usage
|
||||||
"content": {
|
usage.step_count += 1
|
||||||
"tool_call": tool_call.model_dump_json(),
|
usage.completion_tokens += interface.output_tokens
|
||||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
usage.prompt_tokens += interface.input_tokens
|
||||||
},
|
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
||||||
"id": interface.message_id,
|
MetricRegistry().message_output_tokens.record(
|
||||||
"model": interface.model,
|
interface.output_tokens, dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model})
|
||||||
"role": "assistant",
|
)
|
||||||
# "stop_reason": "",
|
|
||||||
# "stop_sequence": None,
|
# log LLM request time
|
||||||
"type": "message",
|
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
|
||||||
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
|
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
|
||||||
},
|
MetricRegistry().llm_execution_time_ms_histogram.record(
|
||||||
|
llm_request_ms,
|
||||||
|
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process resulting stream content
|
||||||
|
try:
|
||||||
|
tool_call = interface.get_tool_call_object()
|
||||||
|
except ValueError as e:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.no_tool_call.value)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||||||
|
raise e
|
||||||
|
reasoning_content = interface.get_reasoning_content()
|
||||||
|
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||||
|
tool_call,
|
||||||
|
valid_tool_names,
|
||||||
|
agent_state,
|
||||||
|
tool_rules_solver,
|
||||||
|
UsageStatistics(
|
||||||
|
completion_tokens=interface.output_tokens,
|
||||||
|
prompt_tokens=interface.input_tokens,
|
||||||
|
total_tokens=interface.input_tokens + interface.output_tokens,
|
||||||
|
),
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
pre_computed_assistant_message_id=interface.letta_message_id,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
organization_id=self.actor.organization_id,
|
initial_messages=initial_messages,
|
||||||
),
|
agent_step_span=agent_step_span,
|
||||||
)
|
is_final_step=(i == max_steps - 1),
|
||||||
|
)
|
||||||
|
step_progression = StepProgression.STEP_LOGGED
|
||||||
|
|
||||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||||
if not (use_assistant_message and tool_return.name == "send_message"):
|
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||||
# Apply message type filtering if specified
|
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||||
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
|
|
||||||
yield f"data: {tool_return.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
# TODO (cliandy): consolidate and expand with trace
|
initial_messages = None
|
||||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
|
||||||
|
# log total step time
|
||||||
|
now = get_utc_timestamp_ns()
|
||||||
|
step_ns = now - step_start
|
||||||
|
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": ns_to_ms(step_ns)})
|
||||||
|
agent_step_span.end()
|
||||||
|
|
||||||
|
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
||||||
|
# log_event("agent.stream.llm_response.processed") # [4^]
|
||||||
|
|
||||||
|
# Log LLM Trace
|
||||||
|
# We are piecing together the streamed response here.
|
||||||
|
# Content here does not match the actual response schema as streams come in chunks.
|
||||||
|
await self.telemetry_manager.create_provider_trace_async(
|
||||||
|
actor=self.actor,
|
||||||
|
provider_trace_create=ProviderTraceCreate(
|
||||||
|
request_json=request_data,
|
||||||
|
response_json={
|
||||||
|
"content": {
|
||||||
|
"tool_call": tool_call.model_dump_json(),
|
||||||
|
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||||
|
},
|
||||||
|
"id": interface.message_id,
|
||||||
|
"model": interface.model,
|
||||||
|
"role": "assistant",
|
||||||
|
# "stop_reason": "",
|
||||||
|
# "stop_sequence": None,
|
||||||
|
"type": "message",
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": interface.input_tokens,
|
||||||
|
"output_tokens": interface.output_tokens,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
step_id=step_id,
|
||||||
|
organization_id=self.actor.organization_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
step_progression = StepProgression.LOGGED_TRACE
|
||||||
|
|
||||||
|
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||||||
|
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||||
|
if not (use_assistant_message and tool_return.name == "send_message"):
|
||||||
|
# Apply message type filtering if specified
|
||||||
|
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
|
||||||
|
yield f"data: {tool_return.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
# TODO (cliandy): consolidate and expand with trace
|
||||||
|
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||||
|
step_progression = StepProgression.FINISHED
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle any unexpected errors during step processing
|
||||||
|
self.logger.error(f"Error during step processing: {e}")
|
||||||
|
|
||||||
|
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
|
||||||
|
if not stop_reason:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||||
|
elif stop_reason.stop_reason in (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule):
|
||||||
|
self.logger.error("Error occurred during step processing, with valid stop reason: %s", stop_reason.stop_reason)
|
||||||
|
elif stop_reason.stop_reason not in (StopReasonType.no_tool_call, StopReasonType.invalid_tool_call):
|
||||||
|
raise ValueError(f"Invalid Stop Reason: {stop_reason}")
|
||||||
|
|
||||||
|
# Send error stop reason to client and re-raise with expected response code
|
||||||
|
yield f"data: {stop_reason.model_dump_json()}\n\n", 500
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Update step if it needs to be updated
|
||||||
|
finally:
|
||||||
|
if settings.track_stop_reason:
|
||||||
|
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||||
|
try:
|
||||||
|
if step_progression < StepProgression.STEP_LOGGED:
|
||||||
|
await self.step_manager.log_step_async(
|
||||||
|
actor=self.actor,
|
||||||
|
agent_id=agent_state.id,
|
||||||
|
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||||
|
provider_category=agent_state.llm_config.provider_category or "base",
|
||||||
|
model=agent_state.llm_config.model,
|
||||||
|
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||||
|
context_window_limit=agent_state.llm_config.context_window,
|
||||||
|
usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
|
||||||
|
provider_id=None,
|
||||||
|
job_id=self.current_run_id if self.current_run_id else None,
|
||||||
|
step_id=step_id,
|
||||||
|
project_id=agent_state.project_id,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
)
|
||||||
|
if step_progression <= StepProgression.STREAM_RECEIVED:
|
||||||
|
if first_chunk and settings.track_errored_messages:
|
||||||
|
for message in initial_messages:
|
||||||
|
message.is_err = True
|
||||||
|
message.step_id = step_id
|
||||||
|
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||||
|
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||||
|
if stop_reason is None:
|
||||||
|
self.logger.error("Error in step after logging step")
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||||
|
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||||
|
elif step_progression == StepProgression.FINISHED and not should_continue:
|
||||||
|
if stop_reason is None:
|
||||||
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||||
|
await self.step_manager.update_step_stop_reason(self.actor, step_id, stop_reason.stop_reason)
|
||||||
|
else:
|
||||||
|
self.logger.error("Invalid StepProgression value")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Failed to update step: %s", e)
|
||||||
|
|
||||||
if not should_continue:
|
if not should_continue:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Extend the in context message ids
|
# Extend the in context message ids
|
||||||
if not agent_state.message_buffer_autoclear:
|
if not agent_state.message_buffer_autoclear:
|
||||||
await self._rebuild_context_window(
|
await self._rebuild_context_window(
|
||||||
@@ -1106,6 +1303,7 @@ class LettaAgent(BaseAgent):
|
|||||||
job_id=run_id if run_id else self.current_run_id,
|
job_id=run_id if run_id else self.current_run_id,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
project_id=agent_state.project_id,
|
project_id=agent_state.project_id,
|
||||||
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call_messages = create_letta_messages_from_llm_response(
|
tool_call_messages = create_letta_messages_from_llm_response(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Set, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -107,25 +107,20 @@ class ToolRulesSolver(BaseModel):
|
|||||||
self.tool_call_history.clear()
|
self.tool_call_history.clear()
|
||||||
|
|
||||||
def get_allowed_tool_names(
|
def get_allowed_tool_names(
|
||||||
self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None
|
self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Get a list of tool names allowed based on the last tool called."""
|
"""Get a list of tool names allowed based on the last tool called.
|
||||||
|
|
||||||
|
The logic is as follows:
|
||||||
|
1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call
|
||||||
|
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
|
||||||
|
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
|
||||||
|
"""
|
||||||
# TODO: This piece of code here is quite ugly and deserves a refactor
|
# TODO: This piece of code here is quite ugly and deserves a refactor
|
||||||
# TODO: There's some weird logic encoded here:
|
|
||||||
# TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules
|
|
||||||
# TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based
|
|
||||||
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
|
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
|
||||||
# If no tool has been called yet, return InitToolRules additively
|
if not self.tool_call_history and self.init_tool_rules:
|
||||||
if not self.tool_call_history:
|
return [rule.tool_name for rule in self.init_tool_rules]
|
||||||
if self.init_tool_rules:
|
|
||||||
# If there are init tool rules, only return those defined in the init tool rules
|
|
||||||
return [rule.tool_name for rule in self.init_tool_rules]
|
|
||||||
else:
|
|
||||||
# Otherwise, return all tools besides those constrained by parent tool rules
|
|
||||||
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
|
|
||||||
return list(available_tools)
|
|
||||||
else:
|
else:
|
||||||
# Collect valid tools from all child-based rules
|
|
||||||
valid_tool_sets = []
|
valid_tool_sets = []
|
||||||
for rule in self.child_based_tool_rules + self.parent_tool_rules:
|
for rule in self.child_based_tool_rules + self.parent_tool_rules:
|
||||||
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||||
@@ -151,11 +146,11 @@ class ToolRulesSolver(BaseModel):
|
|||||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||||
|
|
||||||
def has_required_tools_been_called(self, available_tools: Set[str]) -> bool:
|
def has_required_tools_been_called(self, available_tools: set[str]) -> bool:
|
||||||
"""Check if all required-before-exit tools have been called."""
|
"""Check if all required-before-exit tools have been called."""
|
||||||
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
||||||
|
|
||||||
def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]:
|
def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]:
|
||||||
"""Get the list of required-before-exit tools that have not been called yet."""
|
"""Get the list of required-before-exit tools that have not been called yet."""
|
||||||
if not self.required_before_exit_tool_rules:
|
if not self.required_before_exit_tool_rules:
|
||||||
return [] # No required tools means no uncalled tools
|
return [] # No required tools means no uncalled tools
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|||||||
nullable=True,
|
nullable=True,
|
||||||
doc="The id of the LLMBatchItem that this message is associated with",
|
doc="The id of the LLMBatchItem that this message is associated with",
|
||||||
)
|
)
|
||||||
|
is_err: Mapped[Optional[bool]] = mapped_column(
|
||||||
|
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
|
||||||
|
)
|
||||||
|
|
||||||
# Monotonically increasing sequence for efficient/correct listing
|
# Monotonically increasing sequence for efficient/correct listing
|
||||||
sequence_id: Mapped[int] = mapped_column(
|
sequence_id: Mapped[int] = mapped_column(
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String
|
|||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
|
from letta.schemas.letta_stop_reason import StopReasonType
|
||||||
from letta.schemas.step import Step as PydanticStep
|
from letta.schemas.step import Step as PydanticStep
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -45,6 +46,7 @@ class Step(SqlalchemyBase):
|
|||||||
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
|
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
|
||||||
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
|
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
|
||||||
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
||||||
|
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
|
||||||
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
|
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
|
||||||
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
|
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
|
||||||
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")
|
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")
|
||||||
|
|||||||
@@ -40,15 +40,18 @@ class LettaMessage(BaseModel):
|
|||||||
message_type (MessageType): The type of the message
|
message_type (MessageType): The type of the message
|
||||||
otid (Optional[str]): The offline threading id associated with this message
|
otid (Optional[str]): The offline threading id associated with this message
|
||||||
sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
|
sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
|
||||||
|
step_id (Optional[str]): The step id associated with the message
|
||||||
|
is_err (Optional[bool]): Whether the message is an errored message or not. Used for debugging purposes only.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
date: datetime
|
date: datetime
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
message_type: MessageType = Field(..., description="The type of the message.")
|
message_type: MessageType = Field(..., description="The type of the message.")
|
||||||
otid: Optional[str] = None
|
otid: str | None = None
|
||||||
sender_id: Optional[str] = None
|
sender_id: str | None = None
|
||||||
step_id: Optional[str] = None
|
step_id: str | None = None
|
||||||
|
is_err: bool | None = None
|
||||||
|
|
||||||
@field_serializer("date")
|
@field_serializer("date")
|
||||||
def serialize_datetime(self, dt: datetime, _info):
|
def serialize_datetime(self, dt: datetime, _info):
|
||||||
@@ -60,6 +63,14 @@ class LettaMessage(BaseModel):
|
|||||||
dt = dt.replace(tzinfo=timezone.utc)
|
dt = dt.replace(tzinfo=timezone.utc)
|
||||||
return dt.isoformat(timespec="seconds")
|
return dt.isoformat(timespec="seconds")
|
||||||
|
|
||||||
|
@field_serializer("is_err", when_used="unless-none")
|
||||||
|
def serialize_is_err(self, value: bool | None, _info):
|
||||||
|
"""
|
||||||
|
Only serialize is_err field when it's True (for debugging purposes).
|
||||||
|
When is_err is None or False, this field will be excluded from the JSON output.
|
||||||
|
"""
|
||||||
|
return value if value is True else None
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(LettaMessage):
|
class SystemMessage(LettaMessage):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -172,6 +172,9 @@ class Message(BaseMessage):
|
|||||||
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
|
group_id: Optional[str] = Field(default=None, description="The multi-agent group that the message was sent in")
|
||||||
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
|
sender_id: Optional[str] = Field(default=None, description="The id of the sender of the message, can be an identity id or agent id")
|
||||||
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
|
batch_item_id: Optional[str] = Field(default=None, description="The id of the LLMBatchItem that this message is associated with")
|
||||||
|
is_err: Optional[bool] = Field(
|
||||||
|
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
|
||||||
|
)
|
||||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||||
|
|
||||||
@@ -191,6 +194,7 @@ class Message(BaseMessage):
|
|||||||
if not is_utc_datetime(self.created_at):
|
if not is_utc_datetime(self.created_at):
|
||||||
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
|
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
|
||||||
json_message["created_at"] = self.created_at.isoformat()
|
json_message["created_at"] = self.created_at.isoformat()
|
||||||
|
json_message.pop("is_err", None) # make sure we don't include this debugging information
|
||||||
return json_message
|
return json_message
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -204,6 +208,7 @@ class Message(BaseMessage):
|
|||||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||||
reverse: bool = True,
|
reverse: bool = True,
|
||||||
|
include_err: Optional[bool] = None,
|
||||||
) -> List[LettaMessage]:
|
) -> List[LettaMessage]:
|
||||||
if use_assistant_message:
|
if use_assistant_message:
|
||||||
message_ids_to_remove = []
|
message_ids_to_remove = []
|
||||||
@@ -234,6 +239,7 @@ class Message(BaseMessage):
|
|||||||
assistant_message_tool_name=assistant_message_tool_name,
|
assistant_message_tool_name=assistant_message_tool_name,
|
||||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||||
reverse=reverse,
|
reverse=reverse,
|
||||||
|
include_err=include_err,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -243,6 +249,7 @@ class Message(BaseMessage):
|
|||||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||||
reverse: bool = True,
|
reverse: bool = True,
|
||||||
|
include_err: Optional[bool] = None,
|
||||||
) -> List[LettaMessage]:
|
) -> List[LettaMessage]:
|
||||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||||
messages = []
|
messages = []
|
||||||
@@ -682,14 +689,13 @@ class Message(BaseMessage):
|
|||||||
# since the only "parts" we have are for supporting various COT
|
# since the only "parts" we have are for supporting various COT
|
||||||
|
|
||||||
if self.role == "system":
|
if self.role == "system":
|
||||||
assert all([v is not None for v in [self.role]]), vars(self)
|
|
||||||
openai_message = {
|
openai_message = {
|
||||||
"content": text_content,
|
"content": text_content,
|
||||||
"role": "developer" if use_developer_message else self.role,
|
"role": "developer" if use_developer_message else self.role,
|
||||||
}
|
}
|
||||||
|
|
||||||
elif self.role == "user":
|
elif self.role == "user":
|
||||||
assert all([v is not None for v in [text_content, self.role]]), vars(self)
|
assert text_content is not None, vars(self)
|
||||||
openai_message = {
|
openai_message = {
|
||||||
"content": text_content,
|
"content": text_content,
|
||||||
"role": self.role,
|
"role": self.role,
|
||||||
@@ -720,7 +726,7 @@ class Message(BaseMessage):
|
|||||||
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
|
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
|
||||||
|
|
||||||
elif self.role == "tool":
|
elif self.role == "tool":
|
||||||
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
|
assert self.tool_call_id is not None, vars(self)
|
||||||
openai_message = {
|
openai_message = {
|
||||||
"content": text_content,
|
"content": text_content,
|
||||||
"role": self.role,
|
"role": self.role,
|
||||||
@@ -776,7 +782,7 @@ class Message(BaseMessage):
|
|||||||
if self.role == "system":
|
if self.role == "system":
|
||||||
# NOTE: this is not for system instructions, but instead system "events"
|
# NOTE: this is not for system instructions, but instead system "events"
|
||||||
|
|
||||||
assert all([v is not None for v in [text_content, self.role]]), vars(self)
|
assert text_content is not None, vars(self)
|
||||||
# Two options here, we would use system.package_system_message,
|
# Two options here, we would use system.package_system_message,
|
||||||
# or use a more Anthropic-specific packaging ie xml tags
|
# or use a more Anthropic-specific packaging ie xml tags
|
||||||
user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event")
|
user_system_event = add_xml_tag(string=f"SYSTEM ALERT: {text_content}", xml_tag="event")
|
||||||
@@ -875,7 +881,7 @@ class Message(BaseMessage):
|
|||||||
|
|
||||||
elif self.role == "tool":
|
elif self.role == "tool":
|
||||||
# NOTE: Anthropic uses role "user" for "tool" responses
|
# NOTE: Anthropic uses role "user" for "tool" responses
|
||||||
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
|
assert self.tool_call_id is not None, vars(self)
|
||||||
anthropic_message = {
|
anthropic_message = {
|
||||||
"role": "user", # NOTE: diff
|
"role": "user", # NOTE: diff
|
||||||
"content": [
|
"content": [
|
||||||
@@ -988,7 +994,7 @@ class Message(BaseMessage):
|
|||||||
|
|
||||||
elif self.role == "tool":
|
elif self.role == "tool":
|
||||||
# NOTE: Significantly different tool calling format, more similar to function calling format
|
# NOTE: Significantly different tool calling format, more similar to function calling format
|
||||||
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
|
assert self.tool_call_id is not None, vars(self)
|
||||||
|
|
||||||
if self.name is None:
|
if self.name is None:
|
||||||
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
|
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
from enum import Enum, auto
|
||||||
from typing import Dict, List, Literal, Optional
|
from typing import Dict, List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from letta.schemas.letta_base import LettaBase
|
from letta.schemas.letta_base import LettaBase
|
||||||
|
from letta.schemas.letta_stop_reason import StopReasonType
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
|
|
||||||
|
|
||||||
@@ -28,6 +30,7 @@ class Step(StepBase):
|
|||||||
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
|
prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
|
||||||
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
|
total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
|
||||||
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
|
completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
|
||||||
|
stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.")
|
||||||
tags: List[str] = Field([], description="Metadata tags.")
|
tags: List[str] = Field([], description="Metadata tags.")
|
||||||
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
|
tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
|
||||||
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
|
trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
|
||||||
@@ -36,3 +39,12 @@ class Step(StepBase):
|
|||||||
None, description="The feedback for this step. Must be either 'positive' or 'negative'."
|
None, description="The feedback for this step. Must be either 'positive' or 'negative'."
|
||||||
)
|
)
|
||||||
project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).")
|
project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).")
|
||||||
|
|
||||||
|
|
||||||
|
class StepProgression(int, Enum):
|
||||||
|
START = auto()
|
||||||
|
STREAM_RECEIVED = auto()
|
||||||
|
RESPONSE_RECEIVED = auto()
|
||||||
|
STEP_LOGGED = auto()
|
||||||
|
LOGGED_TRACE = auto()
|
||||||
|
FINISHED = auto()
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ import importlib.util
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import sys
|
import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -34,32 +36,25 @@ from letta.server.db import db_registry
|
|||||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||||
from letta.server.rest_api.interface import StreamingServerInterface
|
from letta.server.rest_api.interface import StreamingServerInterface
|
||||||
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
|
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
|
||||||
|
|
||||||
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
|
||||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||||
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
||||||
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
|
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
|
||||||
from letta.server.rest_api.static_files import mount_static_files
|
from letta.server.rest_api.static_files import mount_static_files
|
||||||
|
from letta.server.rest_api.utils import SENTRY_ENABLED
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
|
|
||||||
# TODO(ethan)
|
if SENTRY_ENABLED:
|
||||||
|
import sentry_sdk
|
||||||
|
|
||||||
|
IS_WINDOWS = platform.system() == "Windows"
|
||||||
|
|
||||||
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
|
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
|
||||||
interface: StreamingServerInterface = StreamingServerInterface
|
interface: type = StreamingServerInterface
|
||||||
server = SyncServer(default_interface_factory=lambda: interface())
|
server = SyncServer(default_interface_factory=lambda: interface())
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import platform
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
is_windows = platform.system() == "Windows"
|
|
||||||
|
|
||||||
log = logging.getLogger("uvicorn")
|
|
||||||
|
|
||||||
|
|
||||||
def generate_openapi_schema(app: FastAPI):
|
def generate_openapi_schema(app: FastAPI):
|
||||||
# Update the OpenAPI schema
|
# Update the OpenAPI schema
|
||||||
if not app.openapi_schema:
|
if not app.openapi_schema:
|
||||||
@@ -177,9 +172,7 @@ def create_application() -> "FastAPI":
|
|||||||
# server = SyncServer(default_interface_factory=lambda: interface())
|
# server = SyncServer(default_interface_factory=lambda: interface())
|
||||||
print(f"\n[[ Letta server // v{letta_version} ]]")
|
print(f"\n[[ Letta server // v{letta_version} ]]")
|
||||||
|
|
||||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
if SENTRY_ENABLED:
|
||||||
import sentry_sdk
|
|
||||||
|
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
dsn=os.getenv("SENTRY_DSN"),
|
dsn=os.getenv("SENTRY_DSN"),
|
||||||
traces_sample_rate=1.0,
|
traces_sample_rate=1.0,
|
||||||
@@ -187,6 +180,7 @@ def create_application() -> "FastAPI":
|
|||||||
"continuous_profiling_auto_start": True,
|
"continuous_profiling_auto_start": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
logger.info("Sentry enabled.")
|
||||||
|
|
||||||
debug_mode = "--debug" in sys.argv
|
debug_mode = "--debug" in sys.argv
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -199,31 +193,13 @@ def create_application() -> "FastAPI":
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(IncompatibleAgentType)
|
# === Exception Handlers ===
|
||||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
# TODO (cliandy): move to separate file
|
||||||
return JSONResponse(
|
|
||||||
status_code=400,
|
|
||||||
content={
|
|
||||||
"detail": str(exc),
|
|
||||||
"expected_type": exc.expected_type,
|
|
||||||
"actual_type": exc.actual_type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def generic_error_handler(request: Request, exc: Exception):
|
async def generic_error_handler(request: Request, exc: Exception):
|
||||||
# Log the actual error for debugging
|
logger.error(f"Unhandled error: {str(exc)}", exc_info=True)
|
||||||
log.error(f"Unhandled error: {str(exc)}", exc_info=True)
|
if SENTRY_ENABLED:
|
||||||
print(f"Unhandled error: {str(exc)}")
|
|
||||||
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
# Print the stack trace
|
|
||||||
print(f"Stack trace: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
|
||||||
import sentry_sdk
|
|
||||||
|
|
||||||
sentry_sdk.capture_exception(exc)
|
sentry_sdk.capture_exception(exc)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@@ -235,62 +211,70 @@ def create_application() -> "FastAPI":
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(NoResultFound)
|
async def error_handler_with_code(request: Request, exc: Exception, code: int, detail: str | None = None):
|
||||||
async def no_result_found_handler(request: Request, exc: NoResultFound):
|
logger.error(f"{type(exc).__name__}", exc_info=exc)
|
||||||
logger.error(f"NoResultFound: {exc}")
|
if SENTRY_ENABLED:
|
||||||
|
sentry_sdk.capture_exception(exc)
|
||||||
|
|
||||||
|
if not detail:
|
||||||
|
detail = str(exc)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=404,
|
status_code=code,
|
||||||
content={"detail": str(exc)},
|
content={"detail": detail},
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(ForeignKeyConstraintViolationError)
|
_error_handler_400 = partial(error_handler_with_code, code=400)
|
||||||
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
|
_error_handler_404 = partial(error_handler_with_code, code=404)
|
||||||
logger.error(f"ForeignKeyConstraintViolationError: {exc}")
|
_error_handler_404_agent = partial(_error_handler_404, detail="Agent not found")
|
||||||
|
_error_handler_404_user = partial(_error_handler_404, detail="User not found")
|
||||||
|
_error_handler_409 = partial(error_handler_with_code, code=409)
|
||||||
|
|
||||||
|
app.add_exception_handler(ValueError, _error_handler_400)
|
||||||
|
app.add_exception_handler(NoResultFound, _error_handler_404)
|
||||||
|
app.add_exception_handler(LettaAgentNotFoundError, _error_handler_404_agent)
|
||||||
|
app.add_exception_handler(LettaUserNotFoundError, _error_handler_404_user)
|
||||||
|
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
|
||||||
|
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
|
||||||
|
|
||||||
|
@app.exception_handler(IncompatibleAgentType)
|
||||||
|
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||||
|
logger.error("Incompatible agent types. Expected: %s, Actual: %s", exc.expected_type, exc.actual_type)
|
||||||
|
if SENTRY_ENABLED:
|
||||||
|
sentry_sdk.capture_exception(exc)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=409,
|
status_code=400,
|
||||||
content={"detail": str(exc)},
|
content={
|
||||||
)
|
"detail": str(exc),
|
||||||
|
"expected_type": exc.expected_type,
|
||||||
@app.exception_handler(UniqueConstraintViolationError)
|
"actual_type": exc.actual_type,
|
||||||
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
|
},
|
||||||
logger.error(f"UniqueConstraintViolationError: {exc}")
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=409,
|
|
||||||
content={"detail": str(exc)},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(DatabaseTimeoutError)
|
@app.exception_handler(DatabaseTimeoutError)
|
||||||
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
|
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
|
||||||
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
|
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
|
||||||
|
if SENTRY_ENABLED:
|
||||||
|
sentry_sdk.capture_exception(exc)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=503,
|
status_code=503,
|
||||||
content={"detail": "The database is temporarily unavailable. Please try again later."},
|
content={"detail": "The database is temporarily unavailable. Please try again later."},
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(ValueError)
|
|
||||||
async def value_error_handler(request: Request, exc: ValueError):
|
|
||||||
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
|
||||||
|
|
||||||
@app.exception_handler(LettaAgentNotFoundError)
|
|
||||||
async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError):
|
|
||||||
return JSONResponse(status_code=404, content={"detail": "Agent not found"})
|
|
||||||
|
|
||||||
@app.exception_handler(LettaUserNotFoundError)
|
|
||||||
async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
|
|
||||||
return JSONResponse(status_code=404, content={"detail": "User not found"})
|
|
||||||
|
|
||||||
@app.exception_handler(BedrockPermissionError)
|
@app.exception_handler(BedrockPermissionError)
|
||||||
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
|
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
|
||||||
|
logger.error(f"Bedrock permission denied.")
|
||||||
|
if SENTRY_ENABLED:
|
||||||
|
sentry_sdk.capture_exception(exc)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
content={
|
content={
|
||||||
"error": {
|
"error": {
|
||||||
"type": "bedrock_permission_denied",
|
"type": "bedrock_permission_denied",
|
||||||
"message": "Unable to access the required AI model. Please check your Bedrock permissions or contact support.",
|
"message": "Unable to access the required AI model. Please check your Bedrock permissions or contact support.",
|
||||||
"details": {"model_arn": exc.model_arn, "reason": str(exc)},
|
"detail": {str(exc)},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -301,6 +285,9 @@ def create_application() -> "FastAPI":
|
|||||||
print(f"▶ Using secure mode with password: {random_password}")
|
print(f"▶ Using secure mode with password: {random_password}")
|
||||||
app.add_middleware(CheckPasswordMiddleware)
|
app.add_middleware(CheckPasswordMiddleware)
|
||||||
|
|
||||||
|
# Add reverse proxy middleware to handle X-Forwarded-* headers
|
||||||
|
# app.add_middleware(ReverseProxyMiddleware, base_path=settings.server_base_path)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.cors_origins,
|
allow_origins=settings.cors_origins,
|
||||||
@@ -442,7 +429,7 @@ def start_server(
|
|||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if is_windows:
|
if IS_WINDOWS:
|
||||||
# Windows doesn't those the fancy unicode characters
|
# Windows doesn't those the fancy unicode characters
|
||||||
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||||
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||||
|
|||||||
@@ -636,6 +636,9 @@ async def list_messages(
|
|||||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||||
|
include_err: bool | None = Query(
|
||||||
|
None, description="Whether to include error messages and error statuses. For debugging purposes only."
|
||||||
|
),
|
||||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -654,6 +657,7 @@ async def list_messages(
|
|||||||
use_assistant_message=use_assistant_message,
|
use_assistant_message=use_assistant_message,
|
||||||
assistant_message_tool_name=assistant_message_tool_name,
|
assistant_message_tool_name=assistant_message_tool_name,
|
||||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||||
|
include_err=include_err,
|
||||||
actor=actor,
|
actor=actor,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1156,7 +1160,7 @@ async def list_agent_groups(
|
|||||||
):
|
):
|
||||||
"""Lists the groups for an agent"""
|
"""Lists the groups for an agent"""
|
||||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||||
print("in list agents with manager_type", manager_type)
|
logger.info("in list agents with manager_type", manager_type)
|
||||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from starlette.types import Send
|
|||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.enums import JobStatus
|
from letta.schemas.enums import JobStatus
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
|
from letta.server.rest_api.utils import capture_sentry_exception
|
||||||
from letta.services.job_manager import JobManager
|
from letta.services.job_manager import JobManager
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -92,6 +93,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
more_body = True
|
more_body = True
|
||||||
try:
|
try:
|
||||||
first_chunk = await self.body_iterator.__anext__()
|
first_chunk = await self.body_iterator.__anext__()
|
||||||
|
logger.debug("stream_response first chunk:", first_chunk)
|
||||||
if isinstance(first_chunk, tuple):
|
if isinstance(first_chunk, tuple):
|
||||||
first_chunk_content, self.status_code = first_chunk
|
first_chunk_content, self.status_code = first_chunk
|
||||||
else:
|
else:
|
||||||
@@ -130,7 +132,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
"more_body": more_body,
|
"more_body": more_body,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return
|
raise Exception(f"An exception occurred mid-stream with status code {status_code}", detail={"content": content})
|
||||||
else:
|
else:
|
||||||
content = chunk
|
content = chunk
|
||||||
|
|
||||||
@@ -146,8 +148,8 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# This should be handled properly upstream?
|
# This should be handled properly upstream?
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError as exc:
|
||||||
logger.info("Stream was cancelled by client or job cancellation")
|
logger.warning("Stream was cancelled by client or job cancellation")
|
||||||
# Handle cancellation gracefully
|
# Handle cancellation gracefully
|
||||||
more_body = False
|
more_body = False
|
||||||
cancellation_resp = {"error": {"message": "Stream cancelled"}}
|
cancellation_resp = {"error": {"message": "Stream cancelled"}}
|
||||||
@@ -160,6 +162,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
"headers": self.raw_headers,
|
"headers": self.raw_headers,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
await send(
|
await send(
|
||||||
{
|
{
|
||||||
"type": "http.response.body",
|
"type": "http.response.body",
|
||||||
@@ -167,13 +170,15 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
"more_body": more_body,
|
"more_body": more_body,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
capture_sentry_exception(exc)
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
logger.exception("unhandled_streaming_error")
|
logger.exception("Unhandled Streaming Error")
|
||||||
more_body = False
|
more_body = False
|
||||||
error_resp = {"error": {"message": "Internal Server Error"}}
|
error_resp = {"error": {"message": "Internal Server Error"}}
|
||||||
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
|
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
|
||||||
|
logger.debug("response_started:", self.response_started)
|
||||||
if not self.response_started:
|
if not self.response_started:
|
||||||
await send(
|
await send(
|
||||||
{
|
{
|
||||||
@@ -182,6 +187,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
"headers": self.raw_headers,
|
"headers": self.raw_headers,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
await send(
|
await send(
|
||||||
{
|
{
|
||||||
"type": "http.response.body",
|
"type": "http.response.body",
|
||||||
@@ -189,5 +195,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
"more_body": more_body,
|
"more_body": more_body,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
capture_sentry_exception(exc)
|
||||||
|
return
|
||||||
if more_body:
|
if more_body:
|
||||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
|
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
|
||||||
|
|
||||||
@@ -34,12 +33,15 @@ from letta.schemas.message import Message, MessageCreate, ToolReturn
|
|||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.server.rest_api.interface import StreamingServerInterface
|
|
||||||
from letta.system import get_heartbeat, package_function_response
|
from letta.system import get_heartbeat, package_function_response
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
|
|
||||||
|
SENTRY_ENABLED = bool(os.getenv("SENTRY_DSN"))
|
||||||
|
|
||||||
|
if SENTRY_ENABLED:
|
||||||
|
import sentry_sdk
|
||||||
|
|
||||||
SSE_PREFIX = "data: "
|
SSE_PREFIX = "data: "
|
||||||
SSE_SUFFIX = "\n\n"
|
SSE_SUFFIX = "\n\n"
|
||||||
@@ -157,21 +159,9 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
|
|||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
def get_current_interface() -> StreamingServerInterface:
|
def capture_sentry_exception(e: BaseException):
|
||||||
return StreamingServerInterface
|
"""This will capture the exception in sentry, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response"""
|
||||||
|
if SENTRY_ENABLED:
|
||||||
|
|
||||||
def log_error_to_sentry(e):
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
warnings.warn(f"SSE stream generator failed: {e}")
|
|
||||||
|
|
||||||
# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
|
|
||||||
# Print the stack trace
|
|
||||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
|
||||||
import sentry_sdk
|
|
||||||
|
|
||||||
sentry_sdk.capture_exception(e)
|
sentry_sdk.capture_exception(e)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
|
|||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
from letta.services.user_manager import UserManager
|
from letta.services.user_manager import UserManager
|
||||||
from letta.settings import model_settings, settings, tool_settings
|
from letta.settings import model_settings, settings, tool_settings
|
||||||
|
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||||
from letta.utils import get_friendly_error_msg, get_persona_text, make_key
|
from letta.utils import get_friendly_error_msg, get_persona_text, make_key
|
||||||
|
|
||||||
config = LettaConfig.load()
|
config = LettaConfig.load()
|
||||||
@@ -176,7 +177,7 @@ class SyncServer(Server):
|
|||||||
self,
|
self,
|
||||||
chaining: bool = True,
|
chaining: bool = True,
|
||||||
max_chaining_steps: Optional[int] = 100,
|
max_chaining_steps: Optional[int] = 100,
|
||||||
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
|
default_interface_factory: Callable[[], AgentChunkStreamingInterface] = lambda: CLIInterface(),
|
||||||
init_with_default_org_and_user: bool = True,
|
init_with_default_org_and_user: bool = True,
|
||||||
# default_interface: AgentInterface = CLIInterface(),
|
# default_interface: AgentInterface = CLIInterface(),
|
||||||
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
|
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
|
||||||
@@ -1244,6 +1245,7 @@ class SyncServer(Server):
|
|||||||
use_assistant_message: bool = True,
|
use_assistant_message: bool = True,
|
||||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||||
|
include_err: Optional[bool] = None,
|
||||||
) -> Union[List[Message], List[LettaMessage]]:
|
) -> Union[List[Message], List[LettaMessage]]:
|
||||||
records = await self.message_manager.list_messages_for_agent_async(
|
records = await self.message_manager.list_messages_for_agent_async(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
@@ -1253,6 +1255,7 @@ class SyncServer(Server):
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
ascending=not reverse,
|
ascending=not reverse,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
|
include_err=include_err,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not return_message_object:
|
if not return_message_object:
|
||||||
@@ -1262,6 +1265,7 @@ class SyncServer(Server):
|
|||||||
assistant_message_tool_name=assistant_message_tool_name,
|
assistant_message_tool_name=assistant_message_tool_name,
|
||||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||||
reverse=reverse,
|
reverse=reverse,
|
||||||
|
include_err=include_err,
|
||||||
)
|
)
|
||||||
|
|
||||||
if reverse:
|
if reverse:
|
||||||
|
|||||||
@@ -520,6 +520,7 @@ class MessageManager:
|
|||||||
limit: Optional[int] = 50,
|
limit: Optional[int] = 50,
|
||||||
ascending: bool = True,
|
ascending: bool = True,
|
||||||
group_id: Optional[str] = None,
|
group_id: Optional[str] = None,
|
||||||
|
include_err: Optional[bool] = None,
|
||||||
) -> List[PydanticMessage]:
|
) -> List[PydanticMessage]:
|
||||||
"""
|
"""
|
||||||
Most performant query to list messages for an agent by directly querying the Message table.
|
Most performant query to list messages for an agent by directly querying the Message table.
|
||||||
@@ -539,6 +540,7 @@ class MessageManager:
|
|||||||
limit: Maximum number of messages to return.
|
limit: Maximum number of messages to return.
|
||||||
ascending: If True, sort by sequence_id ascending; if False, sort descending.
|
ascending: If True, sort by sequence_id ascending; if False, sort descending.
|
||||||
group_id: Optional group ID to filter messages by group_id.
|
group_id: Optional group ID to filter messages by group_id.
|
||||||
|
include_err: Optional boolean to include errors and error statuses. Used for debugging only.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||||
@@ -558,6 +560,9 @@ class MessageManager:
|
|||||||
if group_id:
|
if group_id:
|
||||||
query = query.where(MessageModel.group_id == group_id)
|
query = query.where(MessageModel.group_id == group_id)
|
||||||
|
|
||||||
|
if not include_err:
|
||||||
|
query = query.where((MessageModel.is_err == False) | (MessageModel.is_err.is_(None)))
|
||||||
|
|
||||||
# If query_text is provided, filter messages using database-specific JSON search.
|
# If query_text is provided, filter messages using database-specific JSON search.
|
||||||
if query_text:
|
if query_text:
|
||||||
if settings.letta_pg_uri_no_default:
|
if settings.letta_pg_uri_no_default:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from letta.orm.job import Job as JobModel
|
|||||||
from letta.orm.sqlalchemy_base import AccessType
|
from letta.orm.sqlalchemy_base import AccessType
|
||||||
from letta.orm.step import Step as StepModel
|
from letta.orm.step import Step as StepModel
|
||||||
from letta.otel.tracing import get_trace_id, trace_method
|
from letta.otel.tracing import get_trace_id, trace_method
|
||||||
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
from letta.schemas.step import Step as PydanticStep
|
from letta.schemas.step import Step as PydanticStep
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
@@ -131,6 +132,7 @@ class StepManager:
|
|||||||
job_id: Optional[str] = None,
|
job_id: Optional[str] = None,
|
||||||
step_id: Optional[str] = None,
|
step_id: Optional[str] = None,
|
||||||
project_id: Optional[str] = None,
|
project_id: Optional[str] = None,
|
||||||
|
stop_reason: Optional[LettaStopReason] = None,
|
||||||
) -> PydanticStep:
|
) -> PydanticStep:
|
||||||
step_data = {
|
step_data = {
|
||||||
"origin": None,
|
"origin": None,
|
||||||
@@ -153,6 +155,8 @@ class StepManager:
|
|||||||
}
|
}
|
||||||
if step_id:
|
if step_id:
|
||||||
step_data["id"] = step_id
|
step_data["id"] = step_id
|
||||||
|
if stop_reason:
|
||||||
|
step_data["stop_reason"] = stop_reason.stop_reason
|
||||||
async with db_registry.async_session() as session:
|
async with db_registry.async_session() as session:
|
||||||
if job_id:
|
if job_id:
|
||||||
await self._verify_job_access_async(session, job_id, actor, access=["write"])
|
await self._verify_job_access_async(session, job_id, actor, access=["write"])
|
||||||
@@ -207,6 +211,33 @@ class StepManager:
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
return step.to_pydantic()
|
return step.to_pydantic()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
@trace_method
|
||||||
|
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
|
||||||
|
"""Update the stop reason for a step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor: The user making the request
|
||||||
|
step_id: The ID of the step to update
|
||||||
|
stop_reason: The stop reason to set
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated step
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoResultFound: If the step does not exist
|
||||||
|
"""
|
||||||
|
async with db_registry.async_session() as session:
|
||||||
|
step = await session.get(StepModel, step_id)
|
||||||
|
if not step:
|
||||||
|
raise NoResultFound(f"Step with id {step_id} does not exist")
|
||||||
|
if step.organization_id != actor.organization_id:
|
||||||
|
raise Exception("Unauthorized")
|
||||||
|
|
||||||
|
step.stop_reason = stop_reason
|
||||||
|
await session.commit()
|
||||||
|
return step
|
||||||
|
|
||||||
def _verify_job_access(
|
def _verify_job_access(
|
||||||
self,
|
self,
|
||||||
session: Session,
|
session: Session,
|
||||||
@@ -309,5 +340,6 @@ class NoopStepManager(StepManager):
|
|||||||
job_id: Optional[str] = None,
|
job_id: Optional[str] = None,
|
||||||
step_id: Optional[str] = None,
|
step_id: Optional[str] = None,
|
||||||
project_id: Optional[str] = None,
|
project_id: Optional[str] = None,
|
||||||
|
stop_reason: Optional[LettaStopReason] = None,
|
||||||
) -> PydanticStep:
|
) -> PydanticStep:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -220,13 +220,15 @@ class Settings(BaseSettings):
|
|||||||
multi_agent_concurrent_sends: int = 50
|
multi_agent_concurrent_sends: int = 50
|
||||||
|
|
||||||
# telemetry logging
|
# telemetry logging
|
||||||
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
|
otel_exporter_otlp_endpoint: str | None = None # otel default: "http://localhost:4317"
|
||||||
otel_preferred_temporality: Optional[int] = Field(
|
otel_preferred_temporality: int | None = Field(
|
||||||
default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}"
|
default=1, ge=0, le=2, description="Exported metric temporality. {0: UNSPECIFIED, 1: DELTA, 2: CUMULATIVE}"
|
||||||
)
|
)
|
||||||
disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing")
|
disable_tracing: bool = Field(default=False, description="Disable OTEL Tracing")
|
||||||
llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step")
|
llm_api_logging: bool = Field(default=True, description="Enable LLM API logging at each step")
|
||||||
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
|
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
|
||||||
|
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
|
||||||
|
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
|
||||||
|
|
||||||
# uvicorn settings
|
# uvicorn settings
|
||||||
uvicorn_workers: int = 1
|
uvicorn_workers: int = 1
|
||||||
|
|||||||
@@ -752,14 +752,15 @@ def test_step_stream_agent_loop_error(
|
|||||||
"""
|
"""
|
||||||
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
|
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
|
||||||
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
|
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
|
||||||
|
response = client.agents.messages.create_stream(
|
||||||
|
agent_id=agent_state_no_tools.id,
|
||||||
|
messages=USER_MESSAGE_FORCE_REPLY,
|
||||||
|
)
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
response = client.agents.messages.create_stream(
|
for chunk in response:
|
||||||
agent_id=agent_state_no_tools.id,
|
print(chunk)
|
||||||
messages=USER_MESSAGE_FORCE_REPLY,
|
print("error info:", exc_info)
|
||||||
)
|
|
||||||
list(response)
|
|
||||||
assert type(exc_info.value) in (ApiError, ValueError)
|
assert type(exc_info.value) in (ApiError, ValueError)
|
||||||
print(exc_info.value)
|
|
||||||
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
|
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
|
||||||
assert len(messages_from_db) == 0
|
assert len(messages_from_db) == 0
|
||||||
|
|
||||||
|
|||||||
@@ -138,7 +138,9 @@ def test_max_count_per_step_tool_rule():
|
|||||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"
|
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"
|
||||||
|
|
||||||
solver.register_tool_call(START_TOOL)
|
solver.register_tool_call(START_TOOL)
|
||||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed"
|
assert (
|
||||||
|
solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == []
|
||||||
|
), "After reaching max count, 'start_tool' should no longer be allowed"
|
||||||
|
|
||||||
|
|
||||||
def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
|
def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
|
||||||
@@ -155,7 +157,7 @@ def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
|
|||||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"
|
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"
|
||||||
|
|
||||||
solver.register_tool_call(START_TOOL)
|
solver.register_tool_call(START_TOOL)
|
||||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses"
|
assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should no longer allow 'start_tool' after 3 uses"
|
||||||
|
|
||||||
|
|
||||||
def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
|
def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
|
||||||
@@ -180,7 +182,7 @@ def test_max_count_per_step_tool_rule_resets_on_clear():
|
|||||||
solver.register_tool_call(START_TOOL)
|
solver.register_tool_call(START_TOOL)
|
||||||
solver.register_tool_call(START_TOOL)
|
solver.register_tool_call(START_TOOL)
|
||||||
|
|
||||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit"
|
assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should not allow 'start_tool' after reaching limit"
|
||||||
|
|
||||||
solver.clear_tool_history()
|
solver.clear_tool_history()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user