feat: support for agent loop job cancellation (#2837)

This commit is contained in:
Andy Li
2025-07-02 14:31:16 -07:00
committed by GitHub
parent e9f7601892
commit f9bb757a98
17 changed files with 940 additions and 281 deletions

View File

@@ -1,8 +1,9 @@
import asyncio
import json
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import Optional
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
@@ -34,7 +35,7 @@ from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import log_event, trace_method, tracer
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.enums import JobStatus, MessageRole, ProviderType
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
@@ -69,7 +70,6 @@ DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary"
class LettaAgent(BaseAgent):
def __init__(
self,
agent_id: str,
@@ -81,6 +81,7 @@ class LettaAgent(BaseAgent):
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
current_run_id: str | None = None,
summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
message_buffer_limit: int = summarizer_settings.message_buffer_limit,
message_buffer_min: int = summarizer_settings.message_buffer_min,
@@ -96,7 +97,9 @@ class LettaAgent(BaseAgent):
self.passage_manager = passage_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
self.response_messages: List[Message] = []
self.job_manager = job_manager
self.current_run_id = current_run_id
self.response_messages: list[Message] = []
self.last_function_response = None
@@ -128,16 +131,35 @@ class LettaAgent(BaseAgent):
message_buffer_min=message_buffer_min,
)
async def _check_run_cancellation(self) -> bool:
"""
Check if the current run associated with this agent execution has been cancelled.
Returns:
True if the run is cancelled, False otherwise (or if no run is associated)
"""
if not self.job_manager or not self.current_run_id:
return False
try:
job = await self.job_manager.get_job_by_id_async(job_id=self.current_run_id, actor=self.actor)
return job.status == JobStatus.cancelled
except Exception as e:
# Log the error but don't fail the execution
logger.warning(f"Failed to check job cancellation status for job {self.current_run_id}: {e}")
return False
@trace_method
async def step(
self,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
run_id: str | None = None,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
) -> LettaResponse:
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
)
@@ -159,11 +181,11 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream_no_tokens(
self,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
):
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
@@ -186,6 +208,13 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for i in range(max_steps):
# Check for job cancellation at the start of each step
if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
yield f"data: {stop_reason.model_dump_json()}\n\n"
break
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
@@ -317,11 +346,11 @@ class LettaAgent(BaseAgent):
async def _step(
self,
agent_state: AgentState,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
request_start_timestamp_ns: Optional[int] = None,
) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]:
run_id: str | None = None,
request_start_timestamp_ns: int | None = None,
) -> tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics]:
"""
Carries out an invocation of the agent loop. In each step, the agent
1. Rebuilds its memory
@@ -347,6 +376,12 @@ class LettaAgent(BaseAgent):
stop_reason = None
usage = LettaUsageStatistics()
for i in range(max_steps):
# Check for job cancellation at the start of each step
if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
break
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
@@ -471,11 +506,11 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream(
self,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
) -> AsyncGenerator[str, None]:
"""
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
@@ -507,6 +542,13 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for i in range(max_steps):
# Check for job cancellation at the start of each step
if await self._check_run_cancellation():
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
yield f"data: {stop_reason.model_dump_json()}\n\n"
break
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
@@ -547,7 +589,9 @@ class LettaAgent(BaseAgent):
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
async for chunk in interface.process(
stream, ttft_span=request_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
stream,
ttft_span=request_span,
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
):
# Measure time to first token
if first_chunk and request_span is not None:
@@ -690,13 +734,13 @@ class LettaAgent(BaseAgent):
# noinspection PyInconsistentReturns
async def _build_and_request_from_llm(
self,
current_in_context_messages: List[Message],
new_in_context_messages: List[Message],
current_in_context_messages: list[Message],
new_in_context_messages: list[Message],
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
agent_step_span: "Span",
) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]] | None:
) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
@@ -742,12 +786,12 @@ class LettaAgent(BaseAgent):
first_chunk: bool,
ttft_span: "Span",
request_start_timestamp_ns: int,
current_in_context_messages: List[Message],
new_in_context_messages: List[Message],
current_in_context_messages: list[Message],
new_in_context_messages: list[Message],
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str], int] | None:
) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
@@ -799,11 +843,11 @@ class LettaAgent(BaseAgent):
self,
e: Exception,
llm_client: LLMClientBase,
in_context_messages: List[Message],
new_letta_messages: List[Message],
in_context_messages: list[Message],
new_letta_messages: list[Message],
llm_config: LLMConfig,
force: bool,
) -> List[Message]:
) -> list[Message]:
if isinstance(e, ContextWindowExceededError):
return await self._rebuild_context_window(
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, llm_config=llm_config, force=force
@@ -814,12 +858,12 @@ class LettaAgent(BaseAgent):
@trace_method
async def _rebuild_context_window(
self,
in_context_messages: List[Message],
new_letta_messages: List[Message],
in_context_messages: list[Message],
new_letta_messages: list[Message],
llm_config: LLMConfig,
total_tokens: Optional[int] = None,
total_tokens: int | None = None,
force: bool = False,
) -> List[Message]:
) -> list[Message]:
# If total tokens is reached, we truncate down
# TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
if force or (total_tokens and total_tokens > llm_config.context_window):
@@ -855,10 +899,10 @@ class LettaAgent(BaseAgent):
async def _create_llm_request_data_async(
self,
llm_client: LLMClientBase,
in_context_messages: List[Message],
in_context_messages: list[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
) -> Tuple[dict, List[str]]:
) -> tuple[dict, list[str]]:
self.num_messages, self.num_archival_memories = await asyncio.gather(
(
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
@@ -929,18 +973,18 @@ class LettaAgent(BaseAgent):
async def _handle_ai_response(
self,
tool_call: ToolCall,
valid_tool_names: List[str],
valid_tool_names: list[str],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
usage: UsageStatistics,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
reasoning_content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None,
pre_computed_assistant_message_id: str | None = None,
step_id: str | None = None,
initial_messages: Optional[List[Message]] = None,
initial_messages: list[Message] | None = None,
agent_step_span: Optional["Span"] = None,
is_final_step: Optional[bool] = None,
run_id: Optional[str] = None,
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
is_final_step: bool | None = None,
run_id: str | None = None,
) -> tuple[list[Message], bool, LettaStopReason | None]:
"""
Handle the final AI response once streaming completes, execute / validate the
tool call, decide whether we should keep stepping, and persist state.
@@ -1016,7 +1060,7 @@ class LettaAgent(BaseAgent):
context_window_limit=agent_state.llm_config.context_window,
usage=usage,
provider_id=None,
job_id=run_id,
job_id=run_id if run_id else self.current_run_id,
step_id=step_id,
project_id=agent_state.project_id,
)
@@ -1155,7 +1199,7 @@ class LettaAgent(BaseAgent):
name="tool_execution_completed",
attributes={
"tool_name": target_tool.name,
"duration_ms": ns_to_ms((end_time - start_time)),
"duration_ms": ns_to_ms(end_time - start_time),
"success": tool_execution_result.success_flag,
"tool_type": target_tool.tool_type,
"tool_id": target_tool.id,
@@ -1165,7 +1209,7 @@ class LettaAgent(BaseAgent):
return tool_execution_result
@trace_method
def _load_last_function_response(self, in_context_messages: List[Message]):
def _load_last_function_response(self, in_context_messages: list[Message]):
"""Load the last function response from message history"""
for msg in reversed(in_context_messages):
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):

View File

@@ -358,6 +358,7 @@ REDIS_INCLUDE = "include"
REDIS_EXCLUDE = "exclude"
REDIS_SET_DEFAULT_VAL = "None"
REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
# TODO: This is temporary, eventually use token-based eviction
MAX_FILES_OPEN = 5

View File

@@ -1,6 +1,6 @@
import asyncio
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional
from letta.agents.base_agent import BaseAgent
from letta.agents.letta_agent import LettaAgent
@@ -39,7 +39,8 @@ class SleeptimeMultiAgentV2(BaseAgent):
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
group: Optional[Group] = None,
group: Group | None = None,
current_run_id: str | None = None,
):
super().__init__(
agent_id=agent_id,
@@ -54,6 +55,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
self.job_manager = job_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
self.current_run_id = current_run_id
# Group settings
assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}"
self.group = group
@@ -61,12 +63,12 @@ class SleeptimeMultiAgentV2(BaseAgent):
@trace_method
async def step(
self,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
run_id: str | None = None,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
) -> LettaResponse:
run_ids = []
@@ -89,6 +91,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
current_run_id=self.current_run_id,
)
# Perform foreground agent step
response = await foreground_agent.step(
@@ -125,7 +128,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
except Exception as e:
# Individual task failures
print(f"Agent processing failed: {str(e)}")
print(f"Agent processing failed: {e!s}")
raise e
response.usage.run_ids = run_ids
@@ -134,11 +137,11 @@ class SleeptimeMultiAgentV2(BaseAgent):
@trace_method
async def step_stream_no_tokens(
self,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
):
response = await self.step(
input_messages=input_messages,
@@ -157,11 +160,11 @@ class SleeptimeMultiAgentV2(BaseAgent):
@trace_method
async def step_stream(
self,
input_messages: List[MessageCreate],
input_messages: list[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
) -> AsyncGenerator[str, None]:
# Prepare new messages
new_messages = []
@@ -182,6 +185,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
current_run_id=self.current_run_id,
)
# Perform foreground agent step
async for chunk in foreground_agent.step_stream(
@@ -218,7 +222,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
async def _issue_background_task(
self,
sleeptime_agent_id: str,
response_messages: List[Message],
response_messages: list[Message],
last_processed_message_id: str,
use_assistant_message: bool = True,
) -> str:
@@ -248,7 +252,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
self,
foreground_agent_id: str,
sleeptime_agent_id: str,
response_messages: List[Message],
response_messages: list[Message],
last_processed_message_id: str,
run_id: str,
use_assistant_message: bool = True,
@@ -296,6 +300,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
current_run_id=self.current_run_id,
message_buffer_limit=20, # TODO: Make this configurable
message_buffer_min=8, # TODO: Make this configurable
enable_summarization=False, # TODO: Make this configurable

View File

@@ -81,7 +81,7 @@ class CLIInterface(AgentInterface):
@staticmethod
def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
# ANSI escape code for italic is '\x1B[3m'
fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}"
fstr = f"\x1b[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}"
if STRIP_UI:
fstr = "{msg}"
print(fstr.format(msg=msg))

View File

@@ -1,7 +1,9 @@
import asyncio
import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from enum import Enum
from typing import AsyncGenerator, List, Optional, Union
from typing import Optional
from anthropic import AsyncStream
from anthropic.types.beta import (
@@ -131,14 +133,16 @@ class AnthropicStreamingInterface:
self,
stream: AsyncStream[BetaRawMessageStreamEvent],
ttft_span: Optional["Span"] = None,
provider_request_start_timestamp_ns: Optional[int] = None,
) -> AsyncGenerator[LettaMessage, None]:
provider_request_start_timestamp_ns: int | None = None,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
prev_message_type = None
message_index = 0
first_chunk = True
try:
async with stream:
async for event in stream:
# TODO (cliandy): reconsider in stream cancellations
# await cancellation_token.check_and_raise_if_cancelled()
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - provider_request_start_timestamp_ns
@@ -384,18 +388,21 @@ class AnthropicStreamingInterface:
self.tool_call_buffer = []
self.anthropic_mode = None
except asyncio.CancelledError as e:
logger.info("Cancelled stream %s", e)
yield LettaStopReason(stop_reason=StopReasonType.cancelled)
raise
except Exception as e:
logger.error("Error processing stream: %s", e)
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
yield stop_reason
yield LettaStopReason(stop_reason=StopReasonType.error)
raise
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
def _process_group(
group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str
) -> TextContent | ReasoningContent | RedactedReasoningContent:
if group_type == "reasoning":
reasoning_text = "".join(chunk.reasoning for chunk in group).strip()
is_native = any(chunk.source == "reasoner_model" for chunk in group)

View File

@@ -1,4 +1,5 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from collections.abc import AsyncGenerator
from typing import Any
from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
@@ -19,14 +20,14 @@ class OpenAIChatCompletionsStreamingInterface:
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.stream_pre_execution_message: bool = stream_pre_execution_message
self.current_parsed_json_result: Dict[str, Any] = {}
self.content_buffer: List[str] = []
self.current_parsed_json_result: dict[str, Any] = {}
self.content_buffer: list[str] = []
self.tool_call_happened: bool = False
self.finish_reason_stop: bool = False
self.tool_call_name: Optional[str] = None
self.tool_call_name: str | None = None
self.tool_call_args_str: str = ""
self.tool_call_id: Optional[str] = None
self.tool_call_id: str | None = None
async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]:
"""
@@ -35,6 +36,8 @@ class OpenAIChatCompletionsStreamingInterface:
"""
async with stream:
async for chunk in stream:
# TODO (cliandy): reconsider in stream cancellations
# await cancellation_token.check_and_raise_if_cancelled()
if chunk.choices:
choice = chunk.choices[0]
delta = choice.delta
@@ -103,7 +106,7 @@ class OpenAIChatCompletionsStreamingInterface:
)
)
def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool:
def _handle_finish_reason(self, finish_reason: str | None) -> bool:
"""Handles the finish reason and determines if streaming should stop."""
if finish_reason == "tool_calls":
self.tool_call_happened = True

View File

@@ -1,5 +1,7 @@
import asyncio
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional
from typing import Optional
from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -55,12 +57,12 @@ class OpenAIStreamingInterface:
self.input_tokens = 0
self.output_tokens = 0
self.content_buffer: List[str] = []
self.tool_call_name: Optional[str] = None
self.tool_call_id: Optional[str] = None
self.content_buffer: list[str] = []
self.tool_call_name: str | None = None
self.tool_call_id: str | None = None
self.reasoning_messages = []
def get_reasoning_content(self) -> List[TextContent]:
def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
content = "".join(self.reasoning_messages).strip()
# Right now we assume that all models omit reasoning content for OAI,
@@ -87,8 +89,8 @@ class OpenAIStreamingInterface:
self,
stream: AsyncStream[ChatCompletionChunk],
ttft_span: Optional["Span"] = None,
provider_request_start_timestamp_ns: Optional[int] = None,
) -> AsyncGenerator[LettaMessage, None]:
provider_request_start_timestamp_ns: int | None = None,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
"""
Iterates over the OpenAI stream, yielding SSE events.
It also collects tokens and detects if a tool call is triggered.
@@ -99,6 +101,8 @@ class OpenAIStreamingInterface:
prev_message_type = None
message_index = 0
async for chunk in stream:
# TODO (cliandy): reconsider in stream cancellations
# await cancellation_token.check_and_raise_if_cancelled()
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - provider_request_start_timestamp_ns
@@ -224,8 +228,7 @@ class OpenAIStreamingInterface:
# If there was nothing in the name buffer, we can proceed to
# output the arguments chunk as a ToolCallMessage
else:
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
if self.use_assistant_message and (
self.last_flushed_function_name is not None
and self.last_flushed_function_name == self.assistant_message_tool_name
@@ -349,10 +352,13 @@ class OpenAIStreamingInterface:
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
self.function_id_buffer = None
except asyncio.CancelledError as e:
logger.info("Cancelled stream %s", e)
yield LettaStopReason(stop_reason=StopReasonType.cancelled)
raise
except Exception as e:
logger.error("Error processing stream: %s", e)
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
yield stop_reason
yield LettaStopReason(stop_reason=StopReasonType.error)
raise
finally:
logger.info("OpenAIStreamingInterface: Stream processing complete.")

View File

@@ -54,6 +54,10 @@ class JobStatus(str, Enum):
cancelled = "cancelled"
expired = "expired"
@property
def is_terminal(self):
    """Whether this status is final — terminal jobs never transition again."""
    terminal_states = {JobStatus.completed, JobStatus.failed, JobStatus.cancelled, JobStatus.expired}
    return self in terminal_states
class AgentStepStatus(str, Enum):
"""

View File

@@ -3,6 +3,8 @@ from typing import Literal
from pydantic import BaseModel, Field
from letta.schemas.enums import JobStatus
class StopReasonType(str, Enum):
end_turn = "end_turn"
@@ -11,6 +13,22 @@ class StopReasonType(str, Enum):
max_steps = "max_steps"
no_tool_call = "no_tool_call"
tool_rule = "tool_rule"
cancelled = "cancelled"
@property
def run_status(self) -> JobStatus:
    """Map this stop reason onto the terminal JobStatus for the run.

    Raises:
        ValueError: if the stop reason has no defined run-status mapping.
    """
    completed_reasons = (StopReasonType.end_turn, StopReasonType.max_steps, StopReasonType.tool_rule)
    failed_reasons = (StopReasonType.error, StopReasonType.invalid_tool_call, StopReasonType.no_tool_call)
    if self in completed_reasons:
        return JobStatus.completed
    if self in failed_reasons:
        return JobStatus.failed
    if self == StopReasonType.cancelled:
        return JobStatus.cancelled
    raise ValueError("Unknown StopReasonType")
class LettaStopReason(BaseModel):

View File

@@ -2,7 +2,7 @@ import asyncio
import json
import traceback
from datetime import datetime, timezone
from typing import Annotated, Any, List, Optional
from typing import Annotated, Any
from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, Query, Request, UploadFile, status
from fastapi.responses import JSONResponse
@@ -13,7 +13,8 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT, REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import get_redis_client
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
@@ -49,26 +50,26 @@ router = APIRouter(prefix="/agents", tags=["agents"])
logger = get_logger(__name__)
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
@router.get("/", response_model=list[AgentState], operation_id="list_agents")
async def list_agents(
name: Optional[str] = Query(None, description="Name of the agent"),
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
name: str | None = Query(None, description="Name of the agent"),
tags: list[str] | None = Query(None, description="List of tags to filter agents by"),
match_all_tags: bool = Query(
False,
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed-in tags.",
),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
before: Optional[str] = Query(None, description="Cursor for pagination"),
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(50, description="Limit for pagination"),
query_text: Optional[str] = Query(None, description="Search agents by name"),
project_id: Optional[str] = Query(None, description="Search agents by project ID"),
template_id: Optional[str] = Query(None, description="Search agents by template ID"),
base_template_id: Optional[str] = Query(None, description="Search agents by base template ID"),
identity_id: Optional[str] = Query(None, description="Search agents by identity ID"),
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
include_relationships: Optional[List[str]] = Query(
actor_id: str | None = Header(None, alias="user_id"),
before: str | None = Query(None, description="Cursor for pagination"),
after: str | None = Query(None, description="Cursor for pagination"),
limit: int | None = Query(50, description="Limit for pagination"),
query_text: str | None = Query(None, description="Search agents by name"),
project_id: str | None = Query(None, description="Search agents by project ID"),
template_id: str | None = Query(None, description="Search agents by template ID"),
base_template_id: str | None = Query(None, description="Search agents by base template ID"),
identity_id: str | None = Query(None, description="Search agents by identity ID"),
identifier_keys: list[str] | None = Query(None, description="Search agents by identifier keys"),
include_relationships: list[str] | None = Query(
None,
description=(
"Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. "
@@ -80,7 +81,7 @@ async def list_agents(
False,
description="Whether to sort agents oldest to newest (True) or newest to oldest (False, default)",
),
sort_by: Optional[str] = Query(
sort_by: str | None = Query(
"created_at",
description="Field to sort by. Options: 'created_at' (default), 'last_run_completion'",
),
@@ -119,7 +120,7 @@ async def list_agents(
@router.get("/count", response_model=int, operation_id="count_agents")
async def count_agents(
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Get the count of all agents associated with a given user.
@@ -139,10 +140,10 @@ class IndentedORJSONResponse(Response):
def export_agent_serialized(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
# do not remove, used to autogeneration of spec
# TODO: Think of a better way to export AgentSchema
spec: Optional[AgentSchema] = None,
spec: AgentSchema | None = None,
) -> JSONResponse:
"""
Export the serialized JSON representation of an agent, formatted with indentation.
@@ -160,13 +161,13 @@ def export_agent_serialized(
def import_agent_serialized(
file: UploadFile = File(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
append_copy_suffix: bool = Query(True, description='If set to True, appends "_copy" to the end of the agent name.'),
override_existing_tools: bool = Query(
True,
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
),
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
project_id: str | None = Query(None, description="The project ID to associate the uploaded agent with."),
strip_messages: bool = Query(
False,
description="If set to True, strips all messages from the agent before importing.",
@@ -198,24 +199,24 @@ def import_agent_serialized(
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
except ValidationError as e:
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}")
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}")
except IntegrityError as e:
raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}")
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
except OperationalError as e:
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {str(e)}")
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {str(e)}")
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}")
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
async def retrieve_agent_context_window(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the context window of a specific agent.
@@ -234,15 +235,15 @@ class CreateAgentRequest(CreateAgent):
"""
# Override the user_id field to exclude it from the request body validation
actor_id: Optional[str] = Field(None, exclude=True)
actor_id: str | None = Field(None, exclude=True)
@router.post("/", response_model=AgentState, operation_id="create_agent")
async def create_agent(
agent: CreateAgentRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: str | None = Header(
None, alias="X-Project", description="The project slug to associate with the agent (cloud only)."
), # Only handled by next js middleware
):
@@ -262,18 +263,18 @@ async def modify_agent(
agent_id: str,
update_agent: UpdateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Update an existing agent"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor)
@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="list_agent_tools")
@router.get("/{agent_id}/tools", response_model=list[Tool], operation_id="list_agent_tools")
def list_agent_tools(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Get tools from an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
@@ -285,7 +286,7 @@ async def attach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Attach a tool to an agent.
@@ -299,7 +300,7 @@ async def detach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Detach a tool from an agent.
@@ -313,7 +314,7 @@ async def attach_source(
agent_id: str,
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Attach a source to an agent.
@@ -341,7 +342,7 @@ async def detach_source(
agent_id: str,
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Detach a source from an agent.
@@ -386,7 +387,7 @@ async def close_all_open_files(
@router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
async def retrieve_agent(
agent_id: str,
include_relationships: Optional[List[str]] = Query(
include_relationships: list[str] | None = Query(
None,
description=(
"Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. "
@@ -395,7 +396,7 @@ async def retrieve_agent(
),
),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
@@ -412,7 +413,7 @@ async def retrieve_agent(
async def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
@@ -425,11 +426,11 @@ async def delete_agent(
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="list_agent_sources")
@router.get("/{agent_id}/sources", response_model=list[Source], operation_id="list_agent_sources")
async def list_agent_sources(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the sources associated with an agent.
@@ -443,7 +444,7 @@ async def list_agent_sources(
async def retrieve_agent_memory(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memory state of a specific agent.
@@ -459,7 +460,7 @@ async def retrieve_block(
agent_id: str,
block_label: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve a core memory block from an agent.
@@ -472,11 +473,11 @@ async def retrieve_block(
raise HTTPException(status_code=404, detail=str(e))
@router.get("/{agent_id}/core-memory/blocks", response_model=List[Block], operation_id="list_core_memory_blocks")
@router.get("/{agent_id}/core-memory/blocks", response_model=list[Block], operation_id="list_core_memory_blocks")
async def list_blocks(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the core memory blocks of a specific agent.
@@ -495,7 +496,7 @@ async def modify_block(
block_label: str,
block_update: BlockUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Updates a core memory block of an agent.
@@ -517,7 +518,7 @@ async def attach_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Attach a core memory block to an agent.
@@ -531,7 +532,7 @@ async def detach_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Detach a core memory block from an agent.
@@ -540,18 +541,18 @@ async def detach_block(
return await server.agent_manager.detach_block_async(agent_id=agent_id, block_id=block_id, actor=actor)
@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages")
@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages")
async def list_passages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[str] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
search: Optional[str] = Query(None, description="Search passages by text"),
ascending: Optional[bool] = Query(
after: str | None = Query(None, description="Unique ID of the memory to start the query range at."),
before: str | None = Query(None, description="Unique ID of the memory to end the query range at."),
limit: int | None = Query(None, description="How many results to include in the response."),
search: str | None = Query(None, description="Search passages by text"),
ascending: bool | None = Query(
True, description="Whether to sort passages oldest to newest (True, default) or newest to oldest (False)"
),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
@@ -569,12 +570,12 @@ async def list_passages(
)
@router.post("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="create_passage")
@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage")
async def create_passage(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Insert a memory into an agent's archival memory store.
@@ -584,13 +585,13 @@ async def create_passage(
return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=List[Passage], operation_id="modify_passage")
@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=list[Passage], operation_id="modify_passage")
def modify_passage(
agent_id: str,
memory_id: str,
passage: PassageUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Modify a memory in the agent's archival memory store.
@@ -607,7 +608,7 @@ async def delete_passage(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
@@ -619,7 +620,7 @@ async def delete_passage(
AgentMessagesResponse = Annotated[
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
list[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]
@@ -627,14 +628,14 @@ AgentMessagesResponse = Annotated[
async def list_messages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
after: str | None = Query(None, description="Message after which to retrieve the returned messages."),
before: str | None = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
group_id: Optional[str] = Query(None, description="Group ID to filter messages by."),
group_id: str | None = Query(None, description="Group ID to filter messages by."),
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
@@ -662,7 +663,7 @@ def modify_message(
message_id: str,
request: LettaMessageUpdateUnion = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the details of a message associated with an agent.
@@ -672,6 +673,7 @@ def modify_message(
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
# noinspection PyInconsistentReturns
@router.post(
"/{agent_id}/messages",
response_model=LettaResponse,
@@ -682,7 +684,7 @@ async def send_message(
request_obj: Request, # FastAPI Request
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
@@ -697,55 +699,95 @@ async def send_message(
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
if agent_eligible and model_compatible:
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
agent_loop = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
group=agent.multi_agent_group,
# Create a new run for execution tracking
job_status = JobStatus.created
run = await server.job_manager.create_job_async(
pydantic_job=Run(
user_id=actor.id,
status=job_status,
metadata={
"job_type": "send_message",
"agent_id": agent_id,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
),
),
actor=actor,
)
job_update_metadata = None
# TODO (cliandy): clean this up
redis_client = await get_redis_client()
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
try:
if agent_eligible and model_compatible:
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
agent_loop = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
group=agent.multi_agent_group,
current_run_id=run.id,
)
else:
agent_loop = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
current_run_id=run.id,
)
result = await agent_loop.step(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
else:
agent_loop = LettaAgent(
result = await server.send_message_to_agent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
input_messages=request.messages,
stream_steps=False,
stream_tokens=False,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
)
result = await agent_loop.step(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,
job_status = result.stop_reason.stop_reason.run_status
return result
except Exception as e:
job_update_metadata = {"error": str(e)}
job_status = JobStatus.failed
raise
finally:
await server.job_manager.safe_update_job_status_async(
job_id=run.id,
new_status=job_status,
actor=actor,
input_messages=request.messages,
stream_steps=False,
stream_tokens=False,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
metadata=job_update_metadata,
)
return result
# noinspection PyInconsistentReturns
@router.post(
"/{agent_id}/messages/stream",
response_model=None,
@@ -764,7 +806,7 @@ async def send_message_streaming(
request_obj: Request, # FastAPI Request
server: SyncServer = Depends(get_letta_server),
request: LettaStreamingRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
) -> StreamingResponse | LettaResponse:
"""
Process a user message and return the agent's response.
@@ -780,88 +822,160 @@ async def send_message_streaming(
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
not_letta_endpoint = LETTA_MODEL_ENDPOINT != agent.llm_config.model_endpoint
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
if agent_eligible and model_compatible:
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
agent_loop = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
group=agent.multi_agent_group,
)
else:
agent_loop = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
# Create a new job for execution tracking
job_status = JobStatus.created
run = await server.job_manager.create_job_async(
pydantic_job=Run(
user_id=actor.id,
status=job_status,
metadata={
"job_type": "send_message_streaming",
"agent_id": agent_id,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
),
),
actor=actor,
)
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
result = StreamingResponseWithStatusCode(
agent_loop.step_stream(
input_messages=request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
),
media_type="text/event-stream",
)
job_update_metadata = None
# TODO (cliandy): clean this up
redis_client = await get_redis_client()
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
try:
if agent_eligible and model_compatible:
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
agent_loop = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
group=agent.multi_agent_group,
current_run_id=run.id,
)
else:
agent_loop = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
job_manager=server.job_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
current_run_id=run.id,
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
result = StreamingResponseWithStatusCode(
agent_loop.step_stream(
input_messages=request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
),
media_type="text/event-stream",
)
else:
result = StreamingResponseWithStatusCode(
agent_loop.step_stream_no_tokens(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
),
media_type="text/event-stream",
)
else:
result = StreamingResponseWithStatusCode(
agent_loop.step_stream_no_tokens(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
),
media_type="text/event-stream",
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
input_messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,
job_status = JobStatus.running
return result
except Exception as e:
job_update_metadata = {"error": str(e)}
job_status = JobStatus.failed
raise
finally:
await server.job_manager.safe_update_job_status_async(
job_id=run.id,
new_status=job_status,
actor=actor,
input_messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
metadata=job_update_metadata,
)
return result
@router.post("/{agent_id}/messages/cancel", operation_id="cancel_agent_run")
async def cancel_agent_run(
agent_id: str,
run_ids: list[str] | None = None,
server: SyncServer = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"),
) -> dict:
"""
Cancel runs associated with an agent. If run_ids are passed in, cancel those in particular.
Note to cancel active runs associated with an agent, redis is required.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
if not run_ids:
redis_client = await get_redis_client()
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
if run_id is None:
logger.warning("Cannot find run associated with agent to cancel.")
return {}
run_ids = [run_id]
results = {}
for run_id in run_ids:
success = await server.job_manager.safe_update_job_status_async(
job_id=run_id,
new_status=JobStatus.cancelled,
actor=actor,
)
results[run_id] = "cancelled" if success else "failed"
return results
async def process_message_background(
job_id: str,
async def _process_message_background(
run_id: str,
server: SyncServer,
actor: User,
agent_id: str,
messages: List[MessageCreate],
messages: list[MessageCreate],
use_assistant_message: bool,
assistant_message_tool_name: str,
assistant_message_tool_kwarg: str,
max_steps: int = DEFAULT_MAX_STEPS,
include_return_message_types: Optional[List[MessageType]] = None,
include_return_message_types: list[MessageType] | None = None,
) -> None:
"""Background task to process the message and update job status."""
request_start_timestamp_ns = get_utc_timestamp_ns()
@@ -905,7 +1019,7 @@ async def process_message_background(
result = await agent_loop.step(
messages,
max_steps=max_steps,
run_id=job_id,
run_id=run_id,
use_assistant_message=use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=include_return_message_types,
@@ -917,7 +1031,7 @@ async def process_message_background(
input_messages=messages,
stream_steps=False,
stream_tokens=False,
metadata={"job_id": job_id},
metadata={"job_id": run_id},
# Support for AssistantMessage
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
@@ -930,7 +1044,7 @@ async def process_message_background(
completed_at=datetime.now(timezone.utc),
metadata={"result": result.model_dump(mode="json")},
)
await server.job_manager.update_job_by_id_async(job_id=job_id, job_update=job_update, actor=actor)
await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor)
except Exception as e:
# Update job status to failed
@@ -951,11 +1065,14 @@ async def send_message_async(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaAsyncRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Asynchronously process a user message and return a run object.
The actual processing happens in the background, and the status can be checked using the run ID.
This is "asynchronous" in the sense that it's a background job and explicitly must be fetched by the run ID.
This is more like `send_message_job`
"""
MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
@@ -980,8 +1097,8 @@ async def send_message_async(
# Create asyncio task for background processing
asyncio.create_task(
process_message_background(
job_id=run.id,
_process_message_background(
run_id=run.id,
server=server,
actor=actor,
agent_id=agent_id,
@@ -1002,7 +1119,7 @@ async def reset_messages(
agent_id: str,
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Resets the messages for an agent"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
@@ -1011,12 +1128,12 @@ async def reset_messages(
)
@router.get("/{agent_id}/groups", response_model=List[Group], operation_id="list_agent_groups")
@router.get("/{agent_id}/groups", response_model=list[Group], operation_id="list_agent_groups")
async def list_agent_groups(
agent_id: str,
manager_type: Optional[str] = Query(None, description="Manager type to filter groups by"),
manager_type: str | None = Query(None, description="Manager type to filter groups by"),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Lists the groups for an agent"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
@@ -1030,7 +1147,7 @@ async def summarize_agent_conversation(
request_obj: Request, # FastAPI Request
max_message_length: int = Query(..., description="Maximum number of messages to retain after summarization."),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Summarize an agent's conversation history to a target message length.

View File

@@ -15,10 +15,15 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
async def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
before: Optional[str] = Query(None, description="Cursor for pagination"),
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(50, description="Limit for pagination"),
ascending: bool = Query(True, description="Whether to sort jobs oldest to newest (True, default) or newest to oldest (False)"),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all jobs.
TODO (cliandy): implementation for pagination
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
@@ -26,6 +31,10 @@ async def list_jobs(
return await server.job_manager.list_jobs_async(
actor=actor,
source_id=source_id,
before=before,
after=after,
limit=limit,
ascending=ascending,
)
@@ -34,12 +43,24 @@ async def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
before: Optional[str] = Query(None, description="Cursor for pagination"),
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(50, description="Limit for pagination"),
ascending: bool = Query(True, description="Whether to sort jobs oldest to newest (True, default) or newest to oldest (False)"),
):
"""
List all active jobs.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running], source_id=source_id)
return await server.job_manager.list_jobs_async(
actor=actor,
statuses=[JobStatus.created, JobStatus.running],
source_id=source_id,
before=before,
after=after,
limit=limit,
ascending=ascending,
)
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
@@ -59,6 +80,33 @@ async def retrieve_job(
raise HTTPException(status_code=404, detail="Job not found")
@router.patch("/{job_id}/cancel", response_model=Job, operation_id="cancel_job")
async def cancel_job(
    job_id: str,
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Cancel a job by its job_id.

    This endpoint marks a job as cancelled, which will cause any associated
    agent execution to terminate as soon as possible.

    Raises:
        HTTPException(400): if the job is already in a terminal state
            (completed / failed / cancelled) and can no longer transition.
        HTTPException(404): if no job with `job_id` exists for this actor.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    try:
        # First check if the job exists and is in a cancellable state.
        # NOTE: the endpoint declares response_model=Job, so we must never
        # return a bare bool here — reject terminal jobs with a 400 instead.
        existing_job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
        if existing_job.status.is_terminal:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status '{existing_job.status.value}'")
        await server.job_manager.safe_update_job_status_async(job_id=job_id, new_status=JobStatus.cancelled, actor=actor)
        # Re-fetch so the response body reflects the post-cancellation state.
        return await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Job not found")
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
async def delete_job(
job_id: str,

View File

@@ -2,6 +2,7 @@
# streaming HTTP trailers, as we cannot set codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
import asyncio
import json
from collections.abc import AsyncIterator
@@ -9,10 +10,73 @@ from fastapi.responses import StreamingResponse
from starlette.types import Send
from letta.log import get_logger
from letta.schemas.enums import JobStatus
from letta.schemas.user import User
from letta.services.job_manager import JobManager
logger = get_logger(__name__)
# TODO (cliandy) wrap this and handle types
async def cancellation_aware_stream_wrapper(
    stream_generator: AsyncIterator[str | bytes],
    job_manager: JobManager,
    job_id: str,
    actor: User,
    cancellation_check_interval: float = 0.5,
) -> AsyncIterator[str | bytes]:
    """
    Wraps a stream generator to provide real-time job cancellation checking.

    This wrapper periodically checks for job cancellation while streaming and
    can interrupt the stream at any point, not just at step boundaries.

    Args:
        stream_generator: The original stream generator to wrap
        job_manager: Job manager instance for checking job status
        job_id: ID of the job to monitor for cancellation
        actor: User/actor making the request
        cancellation_check_interval: How often to check for cancellation (seconds)

    Yields:
        Stream chunks from the original generator until cancelled

    Raises:
        asyncio.CancelledError: If the job is cancelled during streaming
    """
    # get_event_loop() is deprecated inside coroutines; use the running loop's
    # monotonic clock for the throttling timestamps.
    loop = asyncio.get_running_loop()
    last_cancellation_check = loop.time()
    try:
        async for chunk in stream_generator:
            # Check for cancellation periodically (not on every chunk for performance)
            current_time = loop.time()
            if current_time - last_cancellation_check >= cancellation_check_interval:
                cancelled = False
                try:
                    job = await job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
                    cancelled = job.status == JobStatus.cancelled
                except Exception as e:
                    # Log warning but don't fail the stream if cancellation check fails
                    logger.warning(f"Failed to check job cancellation for job {job_id}: {e}")
                if cancelled:
                    # Raise OUTSIDE the status-check try block so the
                    # CancelledError can never be swallowed by its except clause.
                    logger.info(f"Stream cancelled for job {job_id}, interrupting stream")
                    # Send cancellation event to client
                    cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
                    yield f"data: {json.dumps(cancellation_event)}\n\n"
                    # Raise CancelledError to interrupt the stream
                    raise asyncio.CancelledError(f"Job {job_id} was cancelled")
                last_cancellation_check = current_time
            yield chunk
    except asyncio.CancelledError:
        # Re-raise CancelledError to ensure proper cleanup
        logger.info(f"Stream for job {job_id} was cancelled and cleaned up")
        raise
    except Exception as e:
        logger.error(f"Error in cancellation-aware stream wrapper for job {job_id}: {e}")
        raise
class StreamingResponseWithStatusCode(StreamingResponse):
"""
Variation of StreamingResponse that can dynamically decide the HTTP status code,
@@ -81,6 +145,30 @@ class StreamingResponseWithStatusCode(StreamingResponse):
}
)
# This should be handled properly upstream?
except asyncio.CancelledError:
logger.info("Stream was cancelled by client or job cancellation")
# Handle cancellation gracefully
more_body = False
cancellation_resp = {"error": {"message": "Stream cancelled"}}
cancellation_event = f"event: cancelled\ndata: {json.dumps(cancellation_resp)}\n\n".encode(self.charset)
if not self.response_started:
await send(
{
"type": "http.response.start",
"status": 200, # Use 200 for graceful cancellation
"headers": self.raw_headers,
}
)
await send(
{
"type": "http.response.body",
"body": cancellation_event,
"more_body": more_body,
}
)
return
except Exception:
logger.exception("unhandled_streaming_error")
more_body = False

View File

@@ -1,4 +1,4 @@
from functools import reduce
from functools import partial, reduce
from operator import add
from typing import List, Literal, Optional, Union
@@ -125,6 +125,46 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
async def safe_update_job_status_async(
    self, job_id: str, new_status: JobStatus, actor: PydanticUser, metadata: Optional[dict] = None
) -> bool:
    """
    Safely update job status with state transition guards.

    Created -> Pending -> Running --> <Terminal>

    Returns:
        True if update was successful, False if update was skipped due to invalid transition
    """
    try:
        # Fetch the job's current state so we can validate the transition.
        current_job = await self.get_job_by_id_async(job_id=job_id, actor=actor)
        current_status = current_job.status

        # A transition is legal when it (a) moves a live job to a terminal
        # state, (b) advances a freshly-created job, or (c) promotes a
        # pending job to running.
        transition_is_valid = (
            (new_status.is_terminal and not current_status.is_terminal)
            or (current_status == JobStatus.created and new_status != JobStatus.created)
            or (current_status == JobStatus.pending and new_status == JobStatus.running)
        )
        if not transition_is_valid:
            logger.warning(f"Invalid job status transition from {current_job.status} to {new_status} for job {job_id}")
            return False

        # Assemble the update payload; terminal transitions also stamp completion time.
        update_kwargs: dict = {"status": new_status}
        if metadata:
            update_kwargs["metadata"] = metadata
        if new_status.is_terminal:
            update_kwargs["completed_at"] = get_utc_time()

        await self.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**update_kwargs), actor=actor)
        return True
    except Exception as e:
        logger.error(f"Failed to safely update job status for job {job_id}: {e}")
        return False
@enforce_types
@trace_method
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
@@ -656,7 +696,7 @@ class JobManager:
job.callback_status_code = resp.status_code
except Exception as e:
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}"
logger.error(error_message)
# Record the failed attempt
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
@@ -686,7 +726,7 @@ class JobManager:
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
job.callback_status_code = resp.status_code
except Exception as e:
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}"
logger.error(error_message)
# Record the failed attempt
job.callback_sent_at = get_utc_time().replace(tzinfo=None)

View File

@@ -12,11 +12,12 @@ import re
import subprocess
import sys
import uuid
from collections.abc import Coroutine
from contextlib import contextmanager
from datetime import datetime, timezone
from functools import wraps
from logging import Logger
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
from typing import Any, Coroutine, Union, _GenericAlias, get_args, get_origin, get_type_hints
from urllib.parse import urljoin, urlparse
import demjson3 as demjson
@@ -519,7 +520,7 @@ def enforce_types(func):
arg_names = inspect.getfullargspec(func).args
# Pair each argument with its corresponding type hint
args_with_hints = dict(zip(arg_names[1:], args[1:])) # Skipping 'self'
args_with_hints = dict(zip(arg_names[1:], args[1:], strict=False)) # Skipping 'self'
# Function to check if a value matches a given type hint
def matches_type(value, hint):
@@ -557,7 +558,7 @@ def enforce_types(func):
return wrapper
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
def annotate_message_json_list_with_tool_calls(messages: list[dict], allow_tool_roles: bool = False):
"""Add in missing tool_call_id fields to a list of messages using function call style
Walk through the list forwards:
@@ -946,7 +947,7 @@ def get_human_text(name: str, enforce_limit=True):
for file_path in list_human_files():
file = os.path.basename(file_path)
if f"{name}.txt" == file or name == file:
human_text = open(file_path, "r", encoding="utf-8").read().strip()
human_text = open(file_path, encoding="utf-8").read().strip()
if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
return human_text
@@ -958,7 +959,7 @@ def get_persona_text(name: str, enforce_limit=True):
for file_path in list_persona_files():
file = os.path.basename(file_path)
if f"{name}.txt" == file or name == file:
persona_text = open(file_path, "r", encoding="utf-8").read().strip()
persona_text = open(file_path, encoding="utf-8").read().strip()
if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
raise ValueError(
f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
@@ -1109,3 +1110,75 @@ def safe_create_task(coro, logger: Logger, label: str = "background task"):
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
return asyncio.create_task(wrapper())
class CancellationSignal:
    """
    A signal that can be checked for cancellation during streaming operations.

    This provides a lightweight way to check if an operation should be cancelled
    without having to pass job managers and other dependencies through every method.
    """

    def __init__(self, job_manager=None, job_id=None, actor=None):
        # Imported lazily inside the constructor to avoid circular imports.
        from letta.log import get_logger
        from letta.schemas.user import User
        from letta.services.job_manager import JobManager

        self.job_manager: JobManager | None = job_manager
        self.job_id: str | None = job_id
        self.actor: User | None = actor
        self._is_cancelled = False
        self.logger = get_logger(__name__)

    async def is_cancelled(self) -> bool:
        """
        Check if the operation has been cancelled.

        Returns:
            True if cancelled, False otherwise
        """
        from letta.schemas.enums import JobStatus

        # Once cancelled, stay cancelled — no need to poll again.
        if self._is_cancelled:
            return True
        # Without a backing job to poll, the signal can never fire.
        if not (self.job_manager and self.job_id and self.actor):
            return False
        try:
            job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
            self._is_cancelled = job.status == JobStatus.cancelled
            return self._is_cancelled
        except Exception as e:
            # A failed status lookup is treated as "not cancelled" so the
            # caller's work is not aborted by transient errors.
            self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
            return False

    def cancel(self):
        """Mark this signal as cancelled locally (for testing or direct cancellation)."""
        self._is_cancelled = True

    async def check_and_raise_if_cancelled(self):
        """
        Check for cancellation and raise CancelledError if cancelled.

        Raises:
            asyncio.CancelledError: If the operation has been cancelled
        """
        if not await self.is_cancelled():
            return
        self.logger.info(f"Operation cancelled for job {self.job_id}")
        raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
class NullCancellationSignal(CancellationSignal):
    """A null cancellation signal that is never cancelled.

    Used where a CancellationSignal is required but no job is being tracked.
    """

    def __init__(self):
        super().__init__()

    async def is_cancelled(self) -> bool:
        # The null signal never reports cancellation.
        return False

    async def check_and_raise_if_cancelled(self):
        # Nothing to check and nothing to raise.
        return None

View File

@@ -1251,3 +1251,207 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
prev_length = current_length
else:
raise AssertionError("Summarization was not triggered after 10 messages")
# ============================
# Job Cancellation Tests
# ============================
def wait_for_run_status(client: Letta, run_id: str, target_status: str, timeout: float = 30.0, interval: float = 0.1) -> Run:
    """Wait for a run to reach a specific status.

    Polls the run every `interval` seconds until it reports `target_status`.

    Raises:
        TimeoutError: if the run does not reach `target_status` within `timeout` seconds.
    """
    # Use a monotonic deadline: time.time() can jump (NTP, clock changes),
    # which would make the timeout unreliable.
    deadline = time.monotonic() + timeout
    while True:
        run = client.runs.retrieve(run_id)
        if run.status == target_status:
            return run
        if time.monotonic() > deadline:
            raise TimeoutError(f"Run {run_id} did not reach status '{target_status}' within {timeout} seconds (last status: {run.status})")
        time.sleep(interval)
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_job_creation_for_send_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Test that send_message endpoint creates a job and the job completes successfully.
    """
    # Snapshot the run IDs that exist before the message is sent.
    run_ids_before = {r.id for r in client.runs.list(agent_ids=[agent_state.id])}
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Send a simple message; this should implicitly create one run.
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
    )

    # The response should be successful
    assert response.messages is not None
    assert len(response.messages) > 0

    # Exactly one new run should have appeared, and it should be completed.
    runs_after = client.runs.list(agent_ids=[agent_state.id])
    new_run_ids = {r.id for r in runs_after} - run_ids_before
    assert len(new_run_ids) == 1
    new_run = next(r for r in runs_after if r.id == new_run_ids.pop())
    assert new_run.status == "completed"
# TODO (cliandy): MERGE BACK IN POST
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# def test_async_job_cancellation(
# disable_e2b_api_key: Any,
# client: Letta,
# agent_state: AgentState,
# llm_config: LLMConfig,
# ) -> None:
# """
# Test that an async job can be cancelled and the cancellation is reflected in the job status.
# """
# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
#
# # client.runs.cancel
# # Start an async job
# run = client.agents.messages.create_async(
# agent_id=agent_state.id,
# messages=USER_MESSAGE_FORCE_REPLY,
# )
#
# # Verify the job was created
# assert run.id is not None
# assert run.status in ["created", "running"]
#
# # Cancel the job quickly (before it potentially completes)
# cancelled_run = client.jobs.cancel(run.id)
#
# # Verify the job was cancelled
# assert cancelled_run.status == "cancelled"
#
# # Wait a bit and verify it stays cancelled (no invalid state transitions)
# time.sleep(1)
# final_run = client.runs.retrieve(run.id)
# assert final_run.status == "cancelled"
#
# # Verify the job metadata indicates cancellation
# if final_run.metadata:
# assert final_run.metadata.get("cancelled") is True or "stop_reason" in final_run.metadata
#
#
# def test_job_cancellation_endpoint_validation(
# disable_e2b_api_key: Any,
# client: Letta,
# agent_state: AgentState,
# ) -> None:
# """
# Test job cancellation endpoint validation (trying to cancel completed/failed jobs).
# """
# # Test cancelling a non-existent job
# with pytest.raises(ApiError) as exc_info:
# client.jobs.cancel("non-existent-job-id")
# assert exc_info.value.status_code == 404
#
#
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# def test_completed_job_cannot_be_cancelled(
# disable_e2b_api_key: Any,
# client: Letta,
# agent_state: AgentState,
# llm_config: LLMConfig,
# ) -> None:
# """
# Test that completed jobs cannot be cancelled.
# """
# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
#
# # Start an async job and wait for it to complete
# run = client.agents.messages.create_async(
# agent_id=agent_state.id,
# messages=USER_MESSAGE_FORCE_REPLY,
# )
#
# # Wait for completion
# completed_run = wait_for_run_completion(client, run.id)
# assert completed_run.status == "completed"
#
# # Try to cancel the completed job - should fail
# with pytest.raises(ApiError) as exc_info:
# client.jobs.cancel(run.id)
# assert exc_info.value.status_code == 400
# assert "Cannot cancel job with status 'completed'" in str(exc_info.value)
#
#
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# def test_streaming_job_independence_from_client_disconnect(
# disable_e2b_api_key: Any,
# client: Letta,
# agent_state: AgentState,
# llm_config: LLMConfig,
# ) -> None:
# """
# Test that streaming jobs are independent of client connection state.
# This verifies that jobs continue even if the client "disconnects" (simulated by not consuming the stream).
# """
# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
#
# # Create a streaming request
# import threading
#
# import httpx
#
# # Get the base URL and create a raw HTTP request to simulate partial consumption
# base_url = client._client_wrapper._base_url
#
# def start_stream_and_abandon():
# """Start a streaming request but abandon it (simulating client disconnect)"""
# try:
# response = httpx.post(
# f"{base_url}/agents/{agent_state.id}/messages/stream",
# json={"messages": [{"role": "user", "text": "Hello, how are you?"}], "stream_tokens": False},
# headers={"user_id": "test-user"},
# timeout=30.0,
# )
#
# # Read just a few chunks then "disconnect" by not reading the rest
# chunk_count = 0
# for chunk in response.iter_lines():
# chunk_count += 1
# if chunk_count > 3: # Read a few chunks then stop
# break
# # Connection is now "abandoned" but the job should continue
#
# except Exception:
# pass # Ignore connection errors
#
# # Start the stream in a separate thread to simulate abandonment
# thread = threading.Thread(target=start_stream_and_abandon)
# thread.start()
# thread.join(timeout=5.0) # Wait up to 5 seconds for the "disconnect"
#
# # The important thing is that this test validates our architecture:
# # 1. Jobs are created before streaming starts (verified by our other tests)
# # 2. Jobs track execution independent of client connection (handled by our wrapper)
# # 3. Only explicit cancellation terminates jobs (tested by other tests)
#
# # This test primarily validates that the implementation doesn't break under simulated disconnection
# assert True # If we get here without errors, the architecture is sound

View File

@@ -4551,7 +4551,7 @@ async def test_create_and_upsert_identity(server: SyncServer, default_user, even
actor=default_user,
)
identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))]
identity_create.properties = [IdentityProperty(key="age", value=29, type=IdentityPropertyType.number)]
identity = await server.identity_manager.upsert_identity_async(
identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user

View File

@@ -17,6 +17,7 @@ from letta.schemas.message import MessageCreate
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import StepManager