feat: set request heartbeat for max steps (#2739)

This commit is contained in:
cthomas
2025-06-10 15:26:07 -07:00
committed by GitHub
parent 484a6f1d37
commit 883050e761
10 changed files with 53 additions and 36 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, AsyncGenerator, List, Optional, Union
import openai
from letta.constants import DEFAULT_MAX_STEPS
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.agent import AgentState
@@ -43,7 +44,7 @@ class BaseAgent(ABC):
self.logger = get_logger(agent_id)
@abstractmethod
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
"""
Main execution loop for the agent.
"""
@@ -51,7 +52,7 @@ class BaseAgent(ABC):
@abstractmethod
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""
Main streaming execution loop for the agent.

View File

@@ -4,6 +4,7 @@ from typing import AsyncGenerator, Dict, List
from openai import AsyncOpenAI
from letta.agents.base_agent import BaseAgent
from letta.constants import DEFAULT_MAX_STEPS
from letta.orm.errors import NoResultFound
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.enums import MessageRole
@@ -42,7 +43,7 @@ class EphemeralSummaryAgent(BaseAgent):
self.target_block_label = target_block_label
self.block_manager = block_manager
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> List[Message]:
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> List[Message]:
if len(input_messages) > 1:
raise ValueError("Can only invoke EphemeralSummaryAgent with a single summarization message.")
@@ -100,5 +101,5 @@ class EphemeralSummaryAgent(BaseAgent):
)
return openai_request
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]:
raise NotImplementedError("EphemeralAgent does not support async step.")

View File

@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id
from letta.constants import DEFAULT_MAX_STEPS
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_timestamp_ns, ns_to_ms
@@ -114,7 +115,7 @@ class LettaAgent(BaseAgent):
async def step(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -139,7 +140,7 @@ class LettaAgent(BaseAgent):
async def step_stream_no_tokens(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -163,7 +164,7 @@ class LettaAgent(BaseAgent):
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for _ in range(max_steps):
for i in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
@@ -225,6 +226,7 @@ class LettaAgent(BaseAgent):
reasoning_content=reasoning,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -289,7 +291,7 @@ class LettaAgent(BaseAgent):
self,
agent_state: AgentState,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
request_start_timestamp_ns: Optional[int] = None,
) -> Tuple[List[Message], List[Message], LettaUsageStatistics]:
"""
@@ -315,7 +317,7 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
usage = LettaUsageStatistics()
for _ in range(max_steps):
for i in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
@@ -368,6 +370,7 @@ class LettaAgent(BaseAgent):
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -417,7 +420,7 @@ class LettaAgent(BaseAgent):
async def step_stream(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -451,7 +454,7 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
provider_request_start_timestamp_ns = None
for _ in range(max_steps):
for i in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
@@ -532,6 +535,7 @@ class LettaAgent(BaseAgent):
step_id=step_id,
initial_messages=initial_messages,
agent_step_span=agent_step_span,
is_final_step=(i == max_steps - 1),
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -838,6 +842,7 @@ class LettaAgent(BaseAgent):
step_id: str | None = None,
initial_messages: Optional[List[Message]] = None,
agent_step_span: Optional["Span"] = None,
is_final_step: Optional[bool] = None,
) -> Tuple[List[Message], bool]:
"""
Now that streaming is done, handle the final AI response.
@@ -858,17 +863,21 @@ class LettaAgent(BaseAgent):
except AssertionError:
tool_args = json.loads(tool_args)
# Get request heartbeats and coerce to bool
request_heartbeat = tool_args.pop("request_heartbeat", False)
# Pre-emptively pop out inner_thoughts
tool_args.pop(INNER_THOUGHTS_KWARG, "")
if is_final_step:
logger.info("Agent has reached max steps.")
request_heartbeat = False
else:
# Get request heartbeats and coerce to bool
request_heartbeat = tool_args.pop("request_heartbeat", False)
# Pre-emptively pop out inner_thoughts
tool_args.pop(INNER_THOUGHTS_KWARG, "")
# So this is necessary, because sometimes non-structured outputs makes mistakes
if not isinstance(request_heartbeat, bool):
if isinstance(request_heartbeat, str):
request_heartbeat = request_heartbeat.lower() == "true"
else:
request_heartbeat = bool(request_heartbeat)
# This is necessary because non-structured outputs sometimes make mistakes
if not isinstance(request_heartbeat, bool):
if isinstance(request_heartbeat, str):
request_heartbeat = request_heartbeat.lower() == "true"
else:
request_heartbeat = bool(request_heartbeat)
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"

View File

@@ -8,6 +8,7 @@ from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMe
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _prepare_in_context_messages_async
from letta.constants import DEFAULT_MAX_STEPS
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import enable_strict_mode
@@ -110,7 +111,7 @@ class LettaAgentBatch(BaseAgent):
sandbox_config_manager: SandboxConfigManager,
job_manager: JobManager,
actor: User,
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
):
self.message_manager = message_manager
self.agent_manager = agent_manager
@@ -619,10 +620,10 @@ class LettaAgentBatch(BaseAgent):
return in_context_messages
# Not used in batch.
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
raise NotImplementedError
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
raise NotImplementedError

View File

@@ -9,7 +9,7 @@ import openai
from letta.agents.base_agent import BaseAgent
from letta.agents.exceptions import IncompatibleAgentType
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import NON_USER_MSG_PREFIX
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import (
add_pre_execution_message,
@@ -111,10 +111,10 @@ class VoiceAgent(BaseAgent):
return summarizer
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
raise NotImplementedError("VoiceAgent does not have a synchronous step implemented currently.")
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.

View File

@@ -2,6 +2,7 @@ from typing import AsyncGenerator, List, Optional, Tuple, Union
from letta.agents.helpers import _create_letta_response, serialize_message_history
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MAX_STEPS
from letta.orm.enums import ToolType
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
@@ -62,7 +63,7 @@ class VoiceSleeptimeAgent(LettaAgent):
async def step(
self,
input_messages: List[MessageCreate],
max_steps: int = 20,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
include_return_message_types: Optional[List[MessageType]] = None,
) -> LettaResponse:
@@ -170,7 +171,7 @@ class VoiceSleeptimeAgent(LettaAgent):
return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""
This agent is synchronous-only. If called in an async context, raise an error.

View File

@@ -46,6 +46,9 @@ IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY"
# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead.
TOOL_CALL_ID_MAX_LEN = 29
# Default maximum number of steps for the agent loop (default value of `max_steps`)
DEFAULT_MAX_STEPS = 50
# minimum context window size
MIN_CONTEXT_WINDOW = 4096

View File

@@ -4,6 +4,7 @@ from typing import AsyncGenerator, List, Optional
from letta.agents.base_agent import BaseAgent
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MAX_STEPS
from letta.groups.helpers import stringify_message
from letta.otel.tracing import trace_method
from letta.schemas.enums import JobStatus
@@ -61,7 +62,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
async def step(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -131,7 +132,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
async def step_stream_no_tokens(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,
@@ -149,7 +150,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
async def step_stream(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
include_return_message_types: Optional[List[MessageType]] = None,

View File

@@ -2,7 +2,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field, HttpUrl
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.letta_message import MessageType
from letta.schemas.message import MessageCreate
@@ -10,7 +10,7 @@ from letta.schemas.message import MessageCreate
class LettaRequest(BaseModel):
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
max_steps: int = Field(
default=10,
default=DEFAULT_MAX_STEPS,
description="Maximum number of steps the agent should take to process the request.",
)
use_assistant_message: bool = Field(

View File

@@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
@@ -843,7 +843,7 @@ async def process_message_background(
use_assistant_message: bool,
assistant_message_tool_name: str,
assistant_message_tool_kwarg: str,
max_steps: int = 10,
max_steps: int = DEFAULT_MAX_STEPS,
include_return_message_types: Optional[List[MessageType]] = None,
) -> None:
"""Background task to process the message and update job status."""