feat: support for agent loop job cancelation (#2837)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
@@ -34,7 +35,7 @@ from letta.otel.context import get_ctx_attributes
|
||||
from letta.otel.metric_registry import MetricRegistry
|
||||
from letta.otel.tracing import log_event, trace_method, tracer
|
||||
from letta.schemas.agent import AgentState, UpdateAgent
|
||||
from letta.schemas.enums import MessageRole, ProviderType
|
||||
from letta.schemas.enums import JobStatus, MessageRole, ProviderType
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
@@ -69,7 +70,6 @@ DEFAULT_SUMMARY_BLOCK_LABEL = "conversation_summary"
|
||||
|
||||
|
||||
class LettaAgent(BaseAgent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -81,6 +81,7 @@ class LettaAgent(BaseAgent):
|
||||
actor: User,
|
||||
step_manager: StepManager = NoopStepManager(),
|
||||
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
|
||||
current_run_id: str | None = None,
|
||||
summary_block_label: str = DEFAULT_SUMMARY_BLOCK_LABEL,
|
||||
message_buffer_limit: int = summarizer_settings.message_buffer_limit,
|
||||
message_buffer_min: int = summarizer_settings.message_buffer_min,
|
||||
@@ -96,7 +97,9 @@ class LettaAgent(BaseAgent):
|
||||
self.passage_manager = passage_manager
|
||||
self.step_manager = step_manager
|
||||
self.telemetry_manager = telemetry_manager
|
||||
self.response_messages: List[Message] = []
|
||||
self.job_manager = job_manager
|
||||
self.current_run_id = current_run_id
|
||||
self.response_messages: list[Message] = []
|
||||
|
||||
self.last_function_response = None
|
||||
|
||||
@@ -128,16 +131,35 @@ class LettaAgent(BaseAgent):
|
||||
message_buffer_min=message_buffer_min,
|
||||
)
|
||||
|
||||
async def _check_run_cancellation(self) -> bool:
|
||||
"""
|
||||
Check if the current run associated with this agent execution has been cancelled.
|
||||
|
||||
Returns:
|
||||
True if the run is cancelled, False otherwise (or if no run is associated)
|
||||
"""
|
||||
if not self.job_manager or not self.current_run_id:
|
||||
return False
|
||||
|
||||
try:
|
||||
job = await self.job_manager.get_job_by_id_async(job_id=self.current_run_id, actor=self.actor)
|
||||
return job.status == JobStatus.cancelled
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the execution
|
||||
logger.warning(f"Failed to check job cancellation status for job {self.current_run_id}: {e}")
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
async def step(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
) -> LettaResponse:
|
||||
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
)
|
||||
@@ -159,11 +181,11 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def step_stream_no_tokens(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
@@ -186,6 +208,13 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
for i in range(max_steps):
|
||||
# Check for job cancellation at the start of each step
|
||||
if await self._check_run_cancellation():
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||||
break
|
||||
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
@@ -317,11 +346,11 @@ class LettaAgent(BaseAgent):
|
||||
async def _step(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> Tuple[List[Message], List[Message], Optional[LettaStopReason], LettaUsageStatistics]:
|
||||
run_id: str | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
) -> tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics]:
|
||||
"""
|
||||
Carries out an invocation of the agent loop. In each step, the agent
|
||||
1. Rebuilds its memory
|
||||
@@ -347,6 +376,12 @@ class LettaAgent(BaseAgent):
|
||||
stop_reason = None
|
||||
usage = LettaUsageStatistics()
|
||||
for i in range(max_steps):
|
||||
# Check for job cancellation at the start of each step
|
||||
if await self._check_run_cancellation():
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
|
||||
break
|
||||
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
@@ -471,11 +506,11 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
|
||||
@@ -507,6 +542,13 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
for i in range(max_steps):
|
||||
# Check for job cancellation at the start of each step
|
||||
if await self._check_run_cancellation():
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
logger.info(f"Agent execution cancelled for run {self.current_run_id}")
|
||||
yield f"data: {stop_reason.model_dump_json()}\n\n"
|
||||
break
|
||||
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
@@ -547,7 +589,9 @@ class LettaAgent(BaseAgent):
|
||||
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
||||
|
||||
async for chunk in interface.process(
|
||||
stream, ttft_span=request_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
|
||||
stream,
|
||||
ttft_span=request_span,
|
||||
provider_request_start_timestamp_ns=provider_request_start_timestamp_ns,
|
||||
):
|
||||
# Measure time to first token
|
||||
if first_chunk and request_span is not None:
|
||||
@@ -690,13 +734,13 @@ class LettaAgent(BaseAgent):
|
||||
# noinspection PyInconsistentReturns
|
||||
async def _build_and_request_from_llm(
|
||||
self,
|
||||
current_in_context_messages: List[Message],
|
||||
new_in_context_messages: List[Message],
|
||||
current_in_context_messages: list[Message],
|
||||
new_in_context_messages: list[Message],
|
||||
agent_state: AgentState,
|
||||
llm_client: LLMClientBase,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
agent_step_span: "Span",
|
||||
) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]] | None:
|
||||
) -> tuple[dict, dict, list[Message], list[Message], list[str]] | None:
|
||||
for attempt in range(self.max_summarization_retries + 1):
|
||||
try:
|
||||
log_event("agent.stream_no_tokens.messages.refreshed")
|
||||
@@ -742,12 +786,12 @@ class LettaAgent(BaseAgent):
|
||||
first_chunk: bool,
|
||||
ttft_span: "Span",
|
||||
request_start_timestamp_ns: int,
|
||||
current_in_context_messages: List[Message],
|
||||
new_in_context_messages: List[Message],
|
||||
current_in_context_messages: list[Message],
|
||||
new_in_context_messages: list[Message],
|
||||
agent_state: AgentState,
|
||||
llm_client: LLMClientBase,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str], int] | None:
|
||||
) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None:
|
||||
for attempt in range(self.max_summarization_retries + 1):
|
||||
try:
|
||||
log_event("agent.stream_no_tokens.messages.refreshed")
|
||||
@@ -799,11 +843,11 @@ class LettaAgent(BaseAgent):
|
||||
self,
|
||||
e: Exception,
|
||||
llm_client: LLMClientBase,
|
||||
in_context_messages: List[Message],
|
||||
new_letta_messages: List[Message],
|
||||
in_context_messages: list[Message],
|
||||
new_letta_messages: list[Message],
|
||||
llm_config: LLMConfig,
|
||||
force: bool,
|
||||
) -> List[Message]:
|
||||
) -> list[Message]:
|
||||
if isinstance(e, ContextWindowExceededError):
|
||||
return await self._rebuild_context_window(
|
||||
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, llm_config=llm_config, force=force
|
||||
@@ -814,12 +858,12 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def _rebuild_context_window(
|
||||
self,
|
||||
in_context_messages: List[Message],
|
||||
new_letta_messages: List[Message],
|
||||
in_context_messages: list[Message],
|
||||
new_letta_messages: list[Message],
|
||||
llm_config: LLMConfig,
|
||||
total_tokens: Optional[int] = None,
|
||||
total_tokens: int | None = None,
|
||||
force: bool = False,
|
||||
) -> List[Message]:
|
||||
) -> list[Message]:
|
||||
# If total tokens is reached, we truncate down
|
||||
# TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
|
||||
if force or (total_tokens and total_tokens > llm_config.context_window):
|
||||
@@ -855,10 +899,10 @@ class LettaAgent(BaseAgent):
|
||||
async def _create_llm_request_data_async(
|
||||
self,
|
||||
llm_client: LLMClientBase,
|
||||
in_context_messages: List[Message],
|
||||
in_context_messages: list[Message],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> Tuple[dict, List[str]]:
|
||||
) -> tuple[dict, list[str]]:
|
||||
self.num_messages, self.num_archival_memories = await asyncio.gather(
|
||||
(
|
||||
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
@@ -929,18 +973,18 @@ class LettaAgent(BaseAgent):
|
||||
async def _handle_ai_response(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
valid_tool_names: List[str],
|
||||
valid_tool_names: list[str],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
usage: UsageStatistics,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
reasoning_content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None,
|
||||
pre_computed_assistant_message_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
initial_messages: Optional[List[Message]] = None,
|
||||
initial_messages: list[Message] | None = None,
|
||||
agent_step_span: Optional["Span"] = None,
|
||||
is_final_step: Optional[bool] = None,
|
||||
run_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], bool, Optional[LettaStopReason]]:
|
||||
is_final_step: bool | None = None,
|
||||
run_id: str | None = None,
|
||||
) -> tuple[list[Message], bool, LettaStopReason | None]:
|
||||
"""
|
||||
Handle the final AI response once streaming completes, execute / validate the
|
||||
tool call, decide whether we should keep stepping, and persist state.
|
||||
@@ -1016,7 +1060,7 @@ class LettaAgent(BaseAgent):
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=usage,
|
||||
provider_id=None,
|
||||
job_id=run_id,
|
||||
job_id=run_id if run_id else self.current_run_id,
|
||||
step_id=step_id,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
@@ -1155,7 +1199,7 @@ class LettaAgent(BaseAgent):
|
||||
name="tool_execution_completed",
|
||||
attributes={
|
||||
"tool_name": target_tool.name,
|
||||
"duration_ms": ns_to_ms((end_time - start_time)),
|
||||
"duration_ms": ns_to_ms(end_time - start_time),
|
||||
"success": tool_execution_result.success_flag,
|
||||
"tool_type": target_tool.tool_type,
|
||||
"tool_id": target_tool.id,
|
||||
@@ -1165,7 +1209,7 @@ class LettaAgent(BaseAgent):
|
||||
return tool_execution_result
|
||||
|
||||
@trace_method
|
||||
def _load_last_function_response(self, in_context_messages: List[Message]):
|
||||
def _load_last_function_response(self, in_context_messages: list[Message]):
|
||||
"""Load the last function response from message history"""
|
||||
for msg in reversed(in_context_messages):
|
||||
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
|
||||
|
||||
@@ -358,6 +358,7 @@ REDIS_INCLUDE = "include"
|
||||
REDIS_EXCLUDE = "exclude"
|
||||
REDIS_SET_DEFAULT_VAL = "None"
|
||||
REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
|
||||
REDIS_RUN_ID_PREFIX = "agent:send_message:run_id"
|
||||
|
||||
# TODO: This is temporary, eventually use token-based eviction
|
||||
MAX_FILES_OPEN = 5
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
@@ -39,7 +39,8 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
actor: User,
|
||||
step_manager: StepManager = NoopStepManager(),
|
||||
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
|
||||
group: Optional[Group] = None,
|
||||
group: Group | None = None,
|
||||
current_run_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
@@ -54,6 +55,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
self.job_manager = job_manager
|
||||
self.step_manager = step_manager
|
||||
self.telemetry_manager = telemetry_manager
|
||||
self.current_run_id = current_run_id
|
||||
# Group settings
|
||||
assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}"
|
||||
self.group = group
|
||||
@@ -61,12 +63,12 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
@trace_method
|
||||
async def step(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
) -> LettaResponse:
|
||||
run_ids = []
|
||||
|
||||
@@ -89,6 +91,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
telemetry_manager=self.telemetry_manager,
|
||||
current_run_id=self.current_run_id,
|
||||
)
|
||||
# Perform foreground agent step
|
||||
response = await foreground_agent.step(
|
||||
@@ -125,7 +128,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
|
||||
except Exception as e:
|
||||
# Individual task failures
|
||||
print(f"Agent processing failed: {str(e)}")
|
||||
print(f"Agent processing failed: {e!s}")
|
||||
raise e
|
||||
|
||||
response.usage.run_ids = run_ids
|
||||
@@ -134,11 +137,11 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
@trace_method
|
||||
async def step_stream_no_tokens(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
):
|
||||
response = await self.step(
|
||||
input_messages=input_messages,
|
||||
@@ -157,11 +160,11 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: list[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# Prepare new messages
|
||||
new_messages = []
|
||||
@@ -182,6 +185,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
telemetry_manager=self.telemetry_manager,
|
||||
current_run_id=self.current_run_id,
|
||||
)
|
||||
# Perform foreground agent step
|
||||
async for chunk in foreground_agent.step_stream(
|
||||
@@ -218,7 +222,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
async def _issue_background_task(
|
||||
self,
|
||||
sleeptime_agent_id: str,
|
||||
response_messages: List[Message],
|
||||
response_messages: list[Message],
|
||||
last_processed_message_id: str,
|
||||
use_assistant_message: bool = True,
|
||||
) -> str:
|
||||
@@ -248,7 +252,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
self,
|
||||
foreground_agent_id: str,
|
||||
sleeptime_agent_id: str,
|
||||
response_messages: List[Message],
|
||||
response_messages: list[Message],
|
||||
last_processed_message_id: str,
|
||||
run_id: str,
|
||||
use_assistant_message: bool = True,
|
||||
@@ -296,6 +300,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
telemetry_manager=self.telemetry_manager,
|
||||
current_run_id=self.current_run_id,
|
||||
message_buffer_limit=20, # TODO: Make this configurable
|
||||
message_buffer_min=8, # TODO: Make this configurable
|
||||
enable_summarization=False, # TODO: Make this configurable
|
||||
|
||||
@@ -81,7 +81,7 @@ class CLIInterface(AgentInterface):
|
||||
@staticmethod
|
||||
def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
# ANSI escape code for italic is '\x1B[3m'
|
||||
fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}"
|
||||
fstr = f"\x1b[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}"
|
||||
if STRIP_UI:
|
||||
fstr = "{msg}"
|
||||
print(fstr.format(msg=msg))
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
@@ -131,14 +133,16 @@ class AnthropicStreamingInterface:
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
provider_request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> AsyncGenerator[LettaMessage, None]:
|
||||
provider_request_start_timestamp_ns: int | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
first_chunk = True
|
||||
try:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
# TODO (cliandy): reconsider in stream cancellations
|
||||
# await cancellation_token.check_and_raise_if_cancelled()
|
||||
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - provider_request_start_timestamp_ns
|
||||
@@ -384,18 +388,21 @@ class AnthropicStreamingInterface:
|
||||
self.tool_call_buffer = []
|
||||
|
||||
self.anthropic_mode = None
|
||||
except asyncio.CancelledError as e:
|
||||
logger.info("Cancelled stream %s", e)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.cancelled)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error processing stream: %s", e)
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
yield stop_reason
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
|
||||
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
|
||||
def _process_group(
|
||||
group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
|
||||
) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
|
||||
group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str
|
||||
) -> TextContent | ReasoningContent | RedactedReasoningContent:
|
||||
if group_type == "reasoning":
|
||||
reasoning_text = "".join(chunk.reasoning for chunk in group).strip()
|
||||
is_native = any(chunk.source == "reasoner_model" for chunk in group)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
@@ -19,14 +20,14 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.stream_pre_execution_message: bool = stream_pre_execution_message
|
||||
|
||||
self.current_parsed_json_result: Dict[str, Any] = {}
|
||||
self.content_buffer: List[str] = []
|
||||
self.current_parsed_json_result: dict[str, Any] = {}
|
||||
self.content_buffer: list[str] = []
|
||||
self.tool_call_happened: bool = False
|
||||
self.finish_reason_stop: bool = False
|
||||
|
||||
self.tool_call_name: Optional[str] = None
|
||||
self.tool_call_name: str | None = None
|
||||
self.tool_call_args_str: str = ""
|
||||
self.tool_call_id: Optional[str] = None
|
||||
self.tool_call_id: str | None = None
|
||||
|
||||
async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
@@ -35,6 +36,8 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
"""
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
# TODO (cliandy): reconsider in stream cancellations
|
||||
# await cancellation_token.check_and_raise_if_cancelled()
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
@@ -103,7 +106,7 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool:
|
||||
def _handle_finish_reason(self, finish_reason: str | None) -> bool:
|
||||
"""Handles the finish reason and determines if streaming should stop."""
|
||||
if finish_reason == "tool_calls":
|
||||
self.tool_call_happened = True
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
@@ -55,12 +57,12 @@ class OpenAIStreamingInterface:
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
|
||||
self.content_buffer: List[str] = []
|
||||
self.tool_call_name: Optional[str] = None
|
||||
self.tool_call_id: Optional[str] = None
|
||||
self.content_buffer: list[str] = []
|
||||
self.tool_call_name: str | None = None
|
||||
self.tool_call_id: str | None = None
|
||||
self.reasoning_messages = []
|
||||
|
||||
def get_reasoning_content(self) -> List[TextContent]:
|
||||
def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
|
||||
content = "".join(self.reasoning_messages).strip()
|
||||
|
||||
# Right now we assume that all models omit reasoning content for OAI,
|
||||
@@ -87,8 +89,8 @@ class OpenAIStreamingInterface:
|
||||
self,
|
||||
stream: AsyncStream[ChatCompletionChunk],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
provider_request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> AsyncGenerator[LettaMessage, None]:
|
||||
provider_request_start_timestamp_ns: int | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the OpenAI stream, yielding SSE events.
|
||||
It also collects tokens and detects if a tool call is triggered.
|
||||
@@ -99,6 +101,8 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
async for chunk in stream:
|
||||
# TODO (cliandy): reconsider in stream cancellations
|
||||
# await cancellation_token.check_and_raise_if_cancelled()
|
||||
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - provider_request_start_timestamp_ns
|
||||
@@ -224,8 +228,7 @@ class OpenAIStreamingInterface:
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
else:
|
||||
|
||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and (
|
||||
self.last_flushed_function_name is not None
|
||||
and self.last_flushed_function_name == self.assistant_message_tool_name
|
||||
@@ -349,10 +352,13 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.function_id_buffer = None
|
||||
except asyncio.CancelledError as e:
|
||||
logger.info("Cancelled stream %s", e)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.cancelled)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error processing stream: %s", e)
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
yield stop_reason
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise
|
||||
finally:
|
||||
logger.info("OpenAIStreamingInterface: Stream processing complete.")
|
||||
|
||||
@@ -54,6 +54,10 @@ class JobStatus(str, Enum):
|
||||
cancelled = "cancelled"
|
||||
expired = "expired"
|
||||
|
||||
@property
|
||||
def is_terminal(self):
|
||||
return self in (JobStatus.completed, JobStatus.failed, JobStatus.cancelled, JobStatus.expired)
|
||||
|
||||
|
||||
class AgentStepStatus(str, Enum):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
||||
|
||||
class StopReasonType(str, Enum):
|
||||
end_turn = "end_turn"
|
||||
@@ -11,6 +13,22 @@ class StopReasonType(str, Enum):
|
||||
max_steps = "max_steps"
|
||||
no_tool_call = "no_tool_call"
|
||||
tool_rule = "tool_rule"
|
||||
cancelled = "cancelled"
|
||||
|
||||
@property
|
||||
def run_status(self) -> JobStatus:
|
||||
if self in (
|
||||
StopReasonType.end_turn,
|
||||
StopReasonType.max_steps,
|
||||
StopReasonType.tool_rule,
|
||||
):
|
||||
return JobStatus.completed
|
||||
elif self in (StopReasonType.error, StopReasonType.invalid_tool_call, StopReasonType.no_tool_call):
|
||||
return JobStatus.failed
|
||||
elif self == StopReasonType.cancelled:
|
||||
return JobStatus.cancelled
|
||||
else:
|
||||
raise ValueError("Unknown StopReasonType")
|
||||
|
||||
|
||||
class LettaStopReason(BaseModel):
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, List, Optional
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, Query, Request, UploadFile, status
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -13,7 +13,8 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT, REDIS_RUN_ID_PREFIX
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.log import get_logger
|
||||
@@ -49,26 +50,26 @@ router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
|
||||
@router.get("/", response_model=list[AgentState], operation_id="list_agents")
|
||||
async def list_agents(
|
||||
name: Optional[str] = Query(None, description="Name of the agent"),
|
||||
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
|
||||
name: str | None = Query(None, description="Name of the agent"),
|
||||
tags: list[str] | None = Query(None, description="List of tags to filter agents by"),
|
||||
match_all_tags: bool = Query(
|
||||
False,
|
||||
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed-in tags.",
|
||||
),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(50, description="Limit for pagination"),
|
||||
query_text: Optional[str] = Query(None, description="Search agents by name"),
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project ID"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template ID"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template ID"),
|
||||
identity_id: Optional[str] = Query(None, description="Search agents by identity ID"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
include_relationships: Optional[List[str]] = Query(
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
before: str | None = Query(None, description="Cursor for pagination"),
|
||||
after: str | None = Query(None, description="Cursor for pagination"),
|
||||
limit: int | None = Query(50, description="Limit for pagination"),
|
||||
query_text: str | None = Query(None, description="Search agents by name"),
|
||||
project_id: str | None = Query(None, description="Search agents by project ID"),
|
||||
template_id: str | None = Query(None, description="Search agents by template ID"),
|
||||
base_template_id: str | None = Query(None, description="Search agents by base template ID"),
|
||||
identity_id: str | None = Query(None, description="Search agents by identity ID"),
|
||||
identifier_keys: list[str] | None = Query(None, description="Search agents by identifier keys"),
|
||||
include_relationships: list[str] | None = Query(
|
||||
None,
|
||||
description=(
|
||||
"Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. "
|
||||
@@ -80,7 +81,7 @@ async def list_agents(
|
||||
False,
|
||||
description="Whether to sort agents oldest to newest (True) or newest to oldest (False, default)",
|
||||
),
|
||||
sort_by: Optional[str] = Query(
|
||||
sort_by: str | None = Query(
|
||||
"created_at",
|
||||
description="Field to sort by. Options: 'created_at' (default), 'last_run_completion'",
|
||||
),
|
||||
@@ -119,7 +120,7 @@ async def list_agents(
|
||||
@router.get("/count", response_model=int, operation_id="count_agents")
|
||||
async def count_agents(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get the count of all agents associated with a given user.
|
||||
@@ -139,10 +140,10 @@ class IndentedORJSONResponse(Response):
|
||||
def export_agent_serialized(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
# do not remove, used to autogeneration of spec
|
||||
# TODO: Think of a better way to export AgentSchema
|
||||
spec: Optional[AgentSchema] = None,
|
||||
spec: AgentSchema | None = None,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Export the serialized JSON representation of an agent, formatted with indentation.
|
||||
@@ -160,13 +161,13 @@ def export_agent_serialized(
|
||||
def import_agent_serialized(
|
||||
file: UploadFile = File(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
append_copy_suffix: bool = Query(True, description='If set to True, appends "_copy" to the end of the agent name.'),
|
||||
override_existing_tools: bool = Query(
|
||||
True,
|
||||
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
|
||||
),
|
||||
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
|
||||
project_id: str | None = Query(None, description="The project ID to associate the uploaded agent with."),
|
||||
strip_messages: bool = Query(
|
||||
False,
|
||||
description="If set to True, strips all messages from the agent before importing.",
|
||||
@@ -198,24 +199,24 @@ def import_agent_serialized(
|
||||
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
|
||||
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}")
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e!s}")
|
||||
|
||||
except IntegrityError as e:
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}")
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {e!s}")
|
||||
|
||||
except OperationalError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {str(e)}")
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {e!s}")
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
|
||||
async def retrieve_agent_context_window(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
@@ -234,15 +235,15 @@ class CreateAgentRequest(CreateAgent):
|
||||
"""
|
||||
|
||||
# Override the user_id field to exclude it from the request body validation
|
||||
actor_id: Optional[str] = Field(None, exclude=True)
|
||||
actor_id: str | None = Field(None, exclude=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=AgentState, operation_id="create_agent")
|
||||
async def create_agent(
|
||||
agent: CreateAgentRequest = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: str | None = Header(
|
||||
None, alias="X-Project", description="The project slug to associate with the agent (cloud only)."
|
||||
), # Only handled by next js middleware
|
||||
):
|
||||
@@ -262,18 +263,18 @@ async def modify_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an existing agent"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="list_agent_tools")
|
||||
@router.get("/{agent_id}/tools", response_model=list[Tool], operation_id="list_agent_tools")
|
||||
def list_agent_tools(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
@@ -285,7 +286,7 @@ async def attach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a tool to an agent.
|
||||
@@ -299,7 +300,7 @@ async def detach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a tool from an agent.
|
||||
@@ -313,7 +314,7 @@ async def attach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a source to an agent.
|
||||
@@ -341,7 +342,7 @@ async def detach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a source from an agent.
|
||||
@@ -386,7 +387,7 @@ async def close_all_open_files(
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
|
||||
async def retrieve_agent(
|
||||
agent_id: str,
|
||||
include_relationships: Optional[List[str]] = Query(
|
||||
include_relationships: list[str] | None = Query(
|
||||
None,
|
||||
description=(
|
||||
"Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. "
|
||||
@@ -395,7 +396,7 @@ async def retrieve_agent(
|
||||
),
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
@@ -412,7 +413,7 @@ async def retrieve_agent(
|
||||
async def delete_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
@@ -425,11 +426,11 @@ async def delete_agent(
|
||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
|
||||
|
||||
|
||||
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="list_agent_sources")
|
||||
@router.get("/{agent_id}/sources", response_model=list[Source], operation_id="list_agent_sources")
|
||||
async def list_agent_sources(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
@@ -443,7 +444,7 @@ async def list_agent_sources(
|
||||
async def retrieve_agent_memory(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
@@ -459,7 +460,7 @@ async def retrieve_block(
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve a core memory block from an agent.
|
||||
@@ -472,11 +473,11 @@ async def retrieve_block(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{agent_id}/core-memory/blocks", response_model=List[Block], operation_id="list_core_memory_blocks")
|
||||
@router.get("/{agent_id}/core-memory/blocks", response_model=list[Block], operation_id="list_core_memory_blocks")
|
||||
async def list_blocks(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the core memory blocks of a specific agent.
|
||||
@@ -495,7 +496,7 @@ async def modify_block(
|
||||
block_label: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Updates a core memory block of an agent.
|
||||
@@ -517,7 +518,7 @@ async def attach_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a core memoryblock to an agent.
|
||||
@@ -531,7 +532,7 @@ async def detach_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a core memory block from an agent.
|
||||
@@ -540,18 +541,18 @@ async def detach_block(
|
||||
return await server.agent_manager.detach_block_async(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages")
|
||||
@router.get("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="list_passages")
|
||||
async def list_passages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[str] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
search: Optional[str] = Query(None, description="Search passages by text"),
|
||||
ascending: Optional[bool] = Query(
|
||||
after: str | None = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: str | None = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: int | None = Query(None, description="How many results to include in the response."),
|
||||
search: str | None = Query(None, description="Search passages by text"),
|
||||
ascending: bool | None = Query(
|
||||
True, description="Whether to sort passages oldest to newest (True, default) or newest to oldest (False)"
|
||||
),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
@@ -569,12 +570,12 @@ async def list_passages(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="create_passage")
|
||||
@router.post("/{agent_id}/archival-memory", response_model=list[Passage], operation_id="create_passage")
|
||||
async def create_passage(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
@@ -584,13 +585,13 @@ async def create_passage(
|
||||
return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=List[Passage], operation_id="modify_passage")
|
||||
@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=list[Passage], operation_id="modify_passage")
|
||||
def modify_passage(
|
||||
agent_id: str,
|
||||
memory_id: str,
|
||||
passage: PassageUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Modify a memory in the agent's archival memory store.
|
||||
@@ -607,7 +608,7 @@ async def delete_passage(
|
||||
memory_id: str,
|
||||
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
@@ -619,7 +620,7 @@ async def delete_passage(
|
||||
|
||||
|
||||
AgentMessagesResponse = Annotated[
|
||||
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
|
||||
list[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
|
||||
]
|
||||
|
||||
|
||||
@@ -627,14 +628,14 @@ AgentMessagesResponse = Annotated[
|
||||
async def list_messages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
after: str | None = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
before: str | None = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
group_id: Optional[str] = Query(None, description="Group ID to filter messages by."),
|
||||
group_id: str | None = Query(None, description="Group ID to filter messages by."),
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
@@ -662,7 +663,7 @@ def modify_message(
|
||||
message_id: str,
|
||||
request: LettaMessageUpdateUnion = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
@@ -672,6 +673,7 @@ def modify_message(
|
||||
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
|
||||
|
||||
|
||||
# noinspection PyInconsistentReturns
|
||||
@router.post(
|
||||
"/{agent_id}/messages",
|
||||
response_model=LettaResponse,
|
||||
@@ -682,7 +684,7 @@ async def send_message(
|
||||
request_obj: Request, # FastAPI Request
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
@@ -697,55 +699,95 @@ async def send_message(
|
||||
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
|
||||
agent_loop = SleeptimeMultiAgentV2(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
group=agent.multi_agent_group,
|
||||
# Create a new run for execution tracking
|
||||
job_status = JobStatus.created
|
||||
run = await server.job_manager.create_job_async(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
metadata={
|
||||
"job_type": "send_message",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
job_update_metadata = None
|
||||
# TODO (cliandy): clean this up
|
||||
redis_client = await get_redis_client()
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
|
||||
|
||||
try:
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
|
||||
agent_loop = SleeptimeMultiAgentV2(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
group=agent.multi_agent_group,
|
||||
current_run_id=run.id,
|
||||
)
|
||||
else:
|
||||
agent_loop = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
current_run_id=run.id,
|
||||
)
|
||||
|
||||
result = await agent_loop.step(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
else:
|
||||
agent_loop = LettaAgent(
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
input_messages=request.messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
|
||||
result = await agent_loop.step(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
else:
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
job_status = result.stop_reason.stop_reason.run_status
|
||||
return result
|
||||
except Exception as e:
|
||||
job_update_metadata = {"error": str(e)}
|
||||
job_status = JobStatus.failed
|
||||
raise
|
||||
finally:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
input_messages=request.messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# noinspection PyInconsistentReturns
|
||||
@router.post(
|
||||
"/{agent_id}/messages/stream",
|
||||
response_model=None,
|
||||
@@ -764,7 +806,7 @@ async def send_message_streaming(
|
||||
request_obj: Request, # FastAPI Request
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
@@ -780,88 +822,160 @@ async def send_message_streaming(
|
||||
agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
|
||||
not_letta_endpoint = LETTA_MODEL_ENDPOINT != agent.llm_config.model_endpoint
|
||||
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
|
||||
agent_loop = SleeptimeMultiAgentV2(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
group=agent.multi_agent_group,
|
||||
)
|
||||
else:
|
||||
agent_loop = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
# Create a new job for execution tracking
|
||||
job_status = JobStatus.created
|
||||
run = await server.job_manager.create_job_async(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
metadata={
|
||||
"job_type": "send_message_streaming",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
|
||||
result = StreamingResponseWithStatusCode(
|
||||
agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
job_update_metadata = None
|
||||
# TODO (cliandy): clean this up
|
||||
redis_client = await get_redis_client()
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
|
||||
|
||||
try:
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
|
||||
agent_loop = SleeptimeMultiAgentV2(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
group=agent.multi_agent_group,
|
||||
current_run_id=run.id,
|
||||
)
|
||||
else:
|
||||
agent_loop = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
job_manager=server.job_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
current_run_id=run.id,
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
|
||||
result = StreamingResponseWithStatusCode(
|
||||
agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
result = StreamingResponseWithStatusCode(
|
||||
agent_loop.step_stream_no_tokens(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
result = StreamingResponseWithStatusCode(
|
||||
agent_loop.step_stream_no_tokens(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
input_messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
else:
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
job_status = JobStatus.running
|
||||
return result
|
||||
except Exception as e:
|
||||
job_update_metadata = {"error": str(e)}
|
||||
job_status = JobStatus.failed
|
||||
raise
|
||||
finally:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
input_messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@router.post("/{agent_id}/messages/cancel", operation_id="cancel_agent_run")
|
||||
async def cancel_agent_run(
|
||||
agent_id: str,
|
||||
run_ids: list[str] | None = None,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
) -> dict:
|
||||
"""
|
||||
Cancel runs associated with an agent. If run_ids are passed in, cancel those in particular.
|
||||
|
||||
Note to cancel active runs associated with an agent, redis is required.
|
||||
"""
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
if not run_ids:
|
||||
redis_client = await get_redis_client()
|
||||
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
|
||||
if run_id is None:
|
||||
logger.warning("Cannot find run associated with agent to cancel.")
|
||||
return {}
|
||||
run_ids = [run_id]
|
||||
|
||||
results = {}
|
||||
for run_id in run_ids:
|
||||
success = await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run_id,
|
||||
new_status=JobStatus.cancelled,
|
||||
actor=actor,
|
||||
)
|
||||
results[run_id] = "cancelled" if success else "failed"
|
||||
return results
|
||||
|
||||
|
||||
async def process_message_background(
|
||||
job_id: str,
|
||||
async def _process_message_background(
|
||||
run_id: str,
|
||||
server: SyncServer,
|
||||
actor: User,
|
||||
agent_id: str,
|
||||
messages: List[MessageCreate],
|
||||
messages: list[MessageCreate],
|
||||
use_assistant_message: bool,
|
||||
assistant_message_tool_name: str,
|
||||
assistant_message_tool_kwarg: str,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
) -> None:
|
||||
"""Background task to process the message and update job status."""
|
||||
request_start_timestamp_ns = get_utc_timestamp_ns()
|
||||
@@ -905,7 +1019,7 @@ async def process_message_background(
|
||||
result = await agent_loop.step(
|
||||
messages,
|
||||
max_steps=max_steps,
|
||||
run_id=job_id,
|
||||
run_id=run_id,
|
||||
use_assistant_message=use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=include_return_message_types,
|
||||
@@ -917,7 +1031,7 @@ async def process_message_background(
|
||||
input_messages=messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
metadata={"job_id": job_id},
|
||||
metadata={"job_id": run_id},
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
@@ -930,7 +1044,7 @@ async def process_message_background(
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metadata={"result": result.model_dump(mode="json")},
|
||||
)
|
||||
await server.job_manager.update_job_by_id_async(job_id=job_id, job_update=job_update, actor=actor)
|
||||
await server.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=actor)
|
||||
|
||||
except Exception as e:
|
||||
# Update job status to failed
|
||||
@@ -951,11 +1065,14 @@ async def send_message_async(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaAsyncRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Asynchronously process a user message and return a run object.
|
||||
The actual processing happens in the background, and the status can be checked using the run ID.
|
||||
|
||||
This is "asynchronous" in the sense that it's a background job and explicitly must be fetched by the run ID.
|
||||
This is more like `send_message_job`
|
||||
"""
|
||||
MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
@@ -980,8 +1097,8 @@ async def send_message_async(
|
||||
|
||||
# Create asyncio task for background processing
|
||||
asyncio.create_task(
|
||||
process_message_background(
|
||||
job_id=run.id,
|
||||
_process_message_background(
|
||||
run_id=run.id,
|
||||
server=server,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
@@ -1002,7 +1119,7 @@ async def reset_messages(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
@@ -1011,12 +1128,12 @@ async def reset_messages(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/groups", response_model=List[Group], operation_id="list_agent_groups")
|
||||
@router.get("/{agent_id}/groups", response_model=list[Group], operation_id="list_agent_groups")
|
||||
async def list_agent_groups(
|
||||
agent_id: str,
|
||||
manager_type: Optional[str] = Query(None, description="Manager type to filter groups by"),
|
||||
manager_type: str | None = Query(None, description="Manager type to filter groups by"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Lists the groups for an agent"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
@@ -1030,7 +1147,7 @@ async def summarize_agent_conversation(
|
||||
request_obj: Request, # FastAPI Request
|
||||
max_message_length: int = Query(..., description="Maximum number of messages to retain after summarization."),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Summarize an agent's conversation history to a target message length.
|
||||
|
||||
@@ -15,10 +15,15 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
async def list_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(50, description="Limit for pagination"),
|
||||
ascending: bool = Query(True, description="Whether to sort jobs oldest to newest (True, default) or newest to oldest (False)"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all jobs.
|
||||
TODO (cliandy): implementation for pagination
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
@@ -26,6 +31,10 @@ async def list_jobs(
|
||||
return await server.job_manager.list_jobs_async(
|
||||
actor=actor,
|
||||
source_id=source_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=ascending,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,12 +43,24 @@ async def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(50, description="Limit for pagination"),
|
||||
ascending: bool = Query(True, description="Whether to sort jobs oldest to newest (True, default) or newest to oldest (False)"),
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running], source_id=source_id)
|
||||
return await server.job_manager.list_jobs_async(
|
||||
actor=actor,
|
||||
statuses=[JobStatus.created, JobStatus.running],
|
||||
source_id=source_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=ascending,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
|
||||
@@ -59,6 +80,33 @@ async def retrieve_job(
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
|
||||
@router.patch("/{job_id}/cancel", response_model=Job, operation_id="cancel_job")
|
||||
async def cancel_job(
|
||||
job_id: str,
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Cancel a job by its job_id.
|
||||
|
||||
This endpoint marks a job as cancelled, which will cause any associated
|
||||
agent execution to terminate as soon as possible.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
# First check if the job exists and is in a cancellable state
|
||||
existing_job = await server.job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
|
||||
if existing_job.status.is_terminal:
|
||||
return False
|
||||
|
||||
return await server.job_manager.safe_update_job_status_async(job_id=job_id, new_status=JobStatus.cancelled, actor=actor)
|
||||
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
|
||||
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
|
||||
async def delete_job(
|
||||
job_id: str,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# stremaing HTTP trailers, as we cannot set codes after the initial response.
|
||||
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
@@ -9,10 +10,73 @@ from fastapi.responses import StreamingResponse
|
||||
from starlette.types import Send
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.user import User
|
||||
from letta.services.job_manager import JobManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO (cliandy) wrap this and handle types
|
||||
async def cancellation_aware_stream_wrapper(
|
||||
stream_generator: AsyncIterator[str | bytes],
|
||||
job_manager: JobManager,
|
||||
job_id: str,
|
||||
actor: User,
|
||||
cancellation_check_interval: float = 0.5,
|
||||
) -> AsyncIterator[str | bytes]:
|
||||
"""
|
||||
Wraps a stream generator to provide real-time job cancellation checking.
|
||||
|
||||
This wrapper periodically checks for job cancellation while streaming and
|
||||
can interrupt the stream at any point, not just at step boundaries.
|
||||
|
||||
Args:
|
||||
stream_generator: The original stream generator to wrap
|
||||
job_manager: Job manager instance for checking job status
|
||||
job_id: ID of the job to monitor for cancellation
|
||||
actor: User/actor making the request
|
||||
cancellation_check_interval: How often to check for cancellation (seconds)
|
||||
|
||||
Yields:
|
||||
Stream chunks from the original generator until cancelled
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the job is cancelled during streaming
|
||||
"""
|
||||
last_cancellation_check = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
async for chunk in stream_generator:
|
||||
# Check for cancellation periodically (not on every chunk for performance)
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
if current_time - last_cancellation_check >= cancellation_check_interval:
|
||||
try:
|
||||
job = await job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
if job.status == JobStatus.cancelled:
|
||||
logger.info(f"Stream cancelled for job {job_id}, interrupting stream")
|
||||
# Send cancellation event to client
|
||||
cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
|
||||
yield f"data: {json.dumps(cancellation_event)}\n\n"
|
||||
# Raise CancelledError to interrupt the stream
|
||||
raise asyncio.CancelledError(f"Job {job_id} was cancelled")
|
||||
except Exception as e:
|
||||
# Log warning but don't fail the stream if cancellation check fails
|
||||
logger.warning(f"Failed to check job cancellation for job {job_id}: {e}")
|
||||
|
||||
last_cancellation_check = current_time
|
||||
|
||||
yield chunk
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Re-raise CancelledError to ensure proper cleanup
|
||||
logger.info(f"Stream for job {job_id} was cancelled and cleaned up")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cancellation-aware stream wrapper for job {job_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
"""
|
||||
Variation of StreamingResponse that can dynamically decide the HTTP status code,
|
||||
@@ -81,6 +145,30 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
}
|
||||
)
|
||||
|
||||
# This should be handled properly upstream?
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Stream was cancelled by client or job cancellation")
|
||||
# Handle cancellation gracefully
|
||||
more_body = False
|
||||
cancellation_resp = {"error": {"message": "Stream cancelled"}}
|
||||
cancellation_event = f"event: cancelled\ndata: {json.dumps(cancellation_resp)}\n\n".encode(self.charset)
|
||||
if not self.response_started:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200, # Use 200 for graceful cancellation
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": cancellation_event,
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
except Exception:
|
||||
logger.exception("unhandled_streaming_error")
|
||||
more_body = False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from functools import reduce
|
||||
from functools import partial, reduce
|
||||
from operator import add
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
@@ -125,6 +125,46 @@ class JobManager:
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def safe_update_job_status_async(
|
||||
self, job_id: str, new_status: JobStatus, actor: PydanticUser, metadata: Optional[dict] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Safely update job status with state transition guards.
|
||||
Created -> Pending -> Running --> <Terminal>
|
||||
|
||||
Returns:
|
||||
True if update was successful, False if update was skipped due to invalid transition
|
||||
"""
|
||||
try:
|
||||
# Get current job state
|
||||
current_job = await self.get_job_by_id_async(job_id=job_id, actor=actor)
|
||||
|
||||
current_status = current_job.status
|
||||
if not any(
|
||||
(
|
||||
new_status.is_terminal and not current_status.is_terminal,
|
||||
current_status == JobStatus.created and new_status != JobStatus.created,
|
||||
current_status == JobStatus.pending and new_status == JobStatus.running,
|
||||
)
|
||||
):
|
||||
logger.warning(f"Invalid job status transition from {current_job.status} to {new_status} for job {job_id}")
|
||||
return False
|
||||
|
||||
job_update_builder = partial(JobUpdate, status=new_status)
|
||||
if metadata:
|
||||
job_update_builder = partial(job_update_builder, metadata=metadata)
|
||||
if new_status.is_terminal:
|
||||
job_update_builder = partial(job_update_builder, completed_at=get_utc_time())
|
||||
|
||||
await self.update_job_by_id_async(job_id=job_id, job_update=job_update_builder(), actor=actor)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to safely update job status for job {job_id}: {e}")
|
||||
return False
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
@@ -656,7 +696,7 @@ class JobManager:
|
||||
job.callback_status_code = resp.status_code
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
|
||||
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}"
|
||||
logger.error(error_message)
|
||||
# Record the failed attempt
|
||||
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
@@ -686,7 +726,7 @@ class JobManager:
|
||||
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
job.callback_status_code = resp.status_code
|
||||
except Exception as e:
|
||||
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
|
||||
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}"
|
||||
logger.error(error_message)
|
||||
# Record the failed attempt
|
||||
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
|
||||
@@ -12,11 +12,12 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Coroutine
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from logging import Logger
|
||||
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Coroutine, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import demjson3 as demjson
|
||||
@@ -519,7 +520,7 @@ def enforce_types(func):
|
||||
arg_names = inspect.getfullargspec(func).args
|
||||
|
||||
# Pair each argument with its corresponding type hint
|
||||
args_with_hints = dict(zip(arg_names[1:], args[1:])) # Skipping 'self'
|
||||
args_with_hints = dict(zip(arg_names[1:], args[1:], strict=False)) # Skipping 'self'
|
||||
|
||||
# Function to check if a value matches a given type hint
|
||||
def matches_type(value, hint):
|
||||
@@ -557,7 +558,7 @@ def enforce_types(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
|
||||
def annotate_message_json_list_with_tool_calls(messages: list[dict], allow_tool_roles: bool = False):
|
||||
"""Add in missing tool_call_id fields to a list of messages using function call style
|
||||
|
||||
Walk through the list forwards:
|
||||
@@ -946,7 +947,7 @@ def get_human_text(name: str, enforce_limit=True):
|
||||
for file_path in list_human_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
human_text = open(file_path, "r", encoding="utf-8").read().strip()
|
||||
human_text = open(file_path, encoding="utf-8").read().strip()
|
||||
if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
|
||||
raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
|
||||
return human_text
|
||||
@@ -958,7 +959,7 @@ def get_persona_text(name: str, enforce_limit=True):
|
||||
for file_path in list_persona_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
persona_text = open(file_path, "r", encoding="utf-8").read().strip()
|
||||
persona_text = open(file_path, encoding="utf-8").read().strip()
|
||||
if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
|
||||
raise ValueError(
|
||||
f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
|
||||
@@ -1109,3 +1110,75 @@ def safe_create_task(coro, logger: Logger, label: str = "background task"):
|
||||
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
|
||||
|
||||
return asyncio.create_task(wrapper())
|
||||
|
||||
|
||||
class CancellationSignal:
|
||||
"""
|
||||
A signal that can be checked for cancellation during streaming operations.
|
||||
|
||||
This provides a lightweight way to check if an operation should be cancelled
|
||||
without having to pass job managers and other dependencies through every method.
|
||||
"""
|
||||
|
||||
def __init__(self, job_manager=None, job_id=None, actor=None):
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.user import User
|
||||
from letta.services.job_manager import JobManager
|
||||
|
||||
self.job_manager: JobManager | None = job_manager
|
||||
self.job_id: str | None = job_id
|
||||
self.actor: User | None = actor
|
||||
self._is_cancelled = False
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def is_cancelled(self) -> bool:
|
||||
"""
|
||||
Check if the operation has been cancelled.
|
||||
|
||||
Returns:
|
||||
True if cancelled, False otherwise
|
||||
"""
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
if not self.job_manager or not self.job_id or not self.actor:
|
||||
return False
|
||||
|
||||
try:
|
||||
job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
|
||||
self._is_cancelled = job.status == JobStatus.cancelled
|
||||
return self._is_cancelled
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
|
||||
return False
|
||||
|
||||
def cancel(self):
|
||||
"""Mark this signal as cancelled locally (for testing or direct cancellation)."""
|
||||
self._is_cancelled = True
|
||||
|
||||
async def check_and_raise_if_cancelled(self):
|
||||
"""
|
||||
Check for cancellation and raise CancelledError if cancelled.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the operation has been cancelled
|
||||
"""
|
||||
if await self.is_cancelled():
|
||||
self.logger.info(f"Operation cancelled for job {self.job_id}")
|
||||
raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
|
||||
|
||||
|
||||
class NullCancellationSignal(CancellationSignal):
|
||||
"""A null cancellation signal that is never cancelled."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def is_cancelled(self) -> bool:
|
||||
return False
|
||||
|
||||
async def check_and_raise_if_cancelled(self):
|
||||
pass
|
||||
|
||||
@@ -1251,3 +1251,207 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
|
||||
prev_length = current_length
|
||||
else:
|
||||
raise AssertionError("Summarization was not triggered after 10 messages")
|
||||
|
||||
|
||||
# ============================
|
||||
# Job Cancellation Tests
|
||||
# ============================
|
||||
|
||||
|
||||
def wait_for_run_status(client: Letta, run_id: str, target_status: str, timeout: float = 30.0, interval: float = 0.1) -> Run:
|
||||
"""Wait for a run to reach a specific status"""
|
||||
start = time.time()
|
||||
while True:
|
||||
run = client.runs.retrieve(run_id)
|
||||
if run.status == target_status:
|
||||
return run
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Run {run_id} did not reach status '{target_status}' within {timeout} seconds (last status: {run.status})")
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_job_creation_for_send_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Test that send_message endpoint creates a job and the job completes successfully.
    """
    # Snapshot run ids that exist before the message is sent.
    existing_ids = {r.id for r in client.runs.list(agent_ids=[agent_state.id])}
    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Send a simple message and verify a job was created
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
    )

    # The response should be successful
    assert response.messages is not None
    assert len(response.messages) > 0

    # Exactly one new run should have appeared for this agent.
    runs = client.runs.list(agent_ids=[agent_state.id])
    created_ids = {r.id for r in runs} - existing_ids
    assert len(created_ids) == 1

    (new_run_id,) = created_ids
    for run in runs:
        if run.id == new_run_id:
            assert run.status == "completed"
|
||||
|
||||
|
||||
# TODO (cliandy): MERGE BACK IN POST — re-enable the commented-out job-cancellation tests below once the cancellation API behavior is finalized
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# def test_async_job_cancellation(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Test that an async job can be cancelled and the cancellation is reflected in the job status.
|
||||
# """
|
||||
# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
#
|
||||
# # client.runs.cancel
|
||||
# # Start an async job
|
||||
# run = client.agents.messages.create_async(
|
||||
# agent_id=agent_state.id,
|
||||
# messages=USER_MESSAGE_FORCE_REPLY,
|
||||
# )
|
||||
#
|
||||
# # Verify the job was created
|
||||
# assert run.id is not None
|
||||
# assert run.status in ["created", "running"]
|
||||
#
|
||||
# # Cancel the job quickly (before it potentially completes)
|
||||
# cancelled_run = client.jobs.cancel(run.id)
|
||||
#
|
||||
# # Verify the job was cancelled
|
||||
# assert cancelled_run.status == "cancelled"
|
||||
#
|
||||
# # Wait a bit and verify it stays cancelled (no invalid state transitions)
|
||||
# time.sleep(1)
|
||||
# final_run = client.runs.retrieve(run.id)
|
||||
# assert final_run.status == "cancelled"
|
||||
#
|
||||
# # Verify the job metadata indicates cancellation
|
||||
# if final_run.metadata:
|
||||
# assert final_run.metadata.get("cancelled") is True or "stop_reason" in final_run.metadata
|
||||
#
|
||||
#
|
||||
# def test_job_cancellation_endpoint_validation(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Test job cancellation endpoint validation (trying to cancel completed/failed jobs).
|
||||
# """
|
||||
# # Test cancelling a non-existent job
|
||||
# with pytest.raises(ApiError) as exc_info:
|
||||
# client.jobs.cancel("non-existent-job-id")
|
||||
# assert exc_info.value.status_code == 404
|
||||
#
|
||||
#
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# def test_completed_job_cannot_be_cancelled(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Test that completed jobs cannot be cancelled.
|
||||
# """
|
||||
# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
#
|
||||
# # Start an async job and wait for it to complete
|
||||
# run = client.agents.messages.create_async(
|
||||
# agent_id=agent_state.id,
|
||||
# messages=USER_MESSAGE_FORCE_REPLY,
|
||||
# )
|
||||
#
|
||||
# # Wait for completion
|
||||
# completed_run = wait_for_run_completion(client, run.id)
|
||||
# assert completed_run.status == "completed"
|
||||
#
|
||||
# # Try to cancel the completed job - should fail
|
||||
# with pytest.raises(ApiError) as exc_info:
|
||||
# client.jobs.cancel(run.id)
|
||||
# assert exc_info.value.status_code == 400
|
||||
# assert "Cannot cancel job with status 'completed'" in str(exc_info.value)
|
||||
#
|
||||
#
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# def test_streaming_job_independence_from_client_disconnect(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Test that streaming jobs are independent of client connection state.
|
||||
# This verifies that jobs continue even if the client "disconnects" (simulated by not consuming the stream).
|
||||
# """
|
||||
# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
#
|
||||
# # Create a streaming request
|
||||
# import threading
|
||||
#
|
||||
# import httpx
|
||||
#
|
||||
# # Get the base URL and create a raw HTTP request to simulate partial consumption
|
||||
# base_url = client._client_wrapper._base_url
|
||||
#
|
||||
# def start_stream_and_abandon():
|
||||
# """Start a streaming request but abandon it (simulating client disconnect)"""
|
||||
# try:
|
||||
# response = httpx.post(
|
||||
# f"{base_url}/agents/{agent_state.id}/messages/stream",
|
||||
# json={"messages": [{"role": "user", "text": "Hello, how are you?"}], "stream_tokens": False},
|
||||
# headers={"user_id": "test-user"},
|
||||
# timeout=30.0,
|
||||
# )
|
||||
#
|
||||
# # Read just a few chunks then "disconnect" by not reading the rest
|
||||
# chunk_count = 0
|
||||
# for chunk in response.iter_lines():
|
||||
# chunk_count += 1
|
||||
# if chunk_count > 3: # Read a few chunks then stop
|
||||
# break
|
||||
# # Connection is now "abandoned" but the job should continue
|
||||
#
|
||||
# except Exception:
|
||||
# pass # Ignore connection errors
|
||||
#
|
||||
# # Start the stream in a separate thread to simulate abandonment
|
||||
# thread = threading.Thread(target=start_stream_and_abandon)
|
||||
# thread.start()
|
||||
# thread.join(timeout=5.0) # Wait up to 5 seconds for the "disconnect"
|
||||
#
|
||||
# # The important thing is that this test validates our architecture:
|
||||
# # 1. Jobs are created before streaming starts (verified by our other tests)
|
||||
# # 2. Jobs track execution independent of client connection (handled by our wrapper)
|
||||
# # 3. Only explicit cancellation terminates jobs (tested by other tests)
|
||||
#
|
||||
# # This test primarily validates that the implementation doesn't break under simulated disconnection
|
||||
# assert True # If we get here without errors, the architecture is sound
|
||||
|
||||
@@ -4551,7 +4551,7 @@ async def test_create_and_upsert_identity(server: SyncServer, default_user, even
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))]
|
||||
identity_create.properties = [IdentityProperty(key="age", value=29, type=IdentityPropertyType.number)]
|
||||
|
||||
identity = await server.identity_manager.upsert_identity_async(
|
||||
identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user
|
||||
|
||||
@@ -17,6 +17,7 @@ from letta.schemas.message import MessageCreate
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import StepManager
|
||||
|
||||
Reference in New Issue
Block a user