From 883050e761b14863ed5e4516557e7a7f9a3f673c Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 10 Jun 2025 15:26:07 -0700 Subject: [PATCH] feat: set request heartbeat for max steps (#2739) --- letta/agents/base_agent.py | 5 ++- letta/agents/ephemeral_summary_agent.py | 5 ++- letta/agents/letta_agent.py | 43 +++++++++++++--------- letta/agents/letta_agent_batch.py | 7 ++-- letta/agents/voice_agent.py | 6 +-- letta/agents/voice_sleeptime_agent.py | 5 ++- letta/constants.py | 3 ++ letta/groups/sleeptime_multi_agent_v2.py | 7 ++-- letta/schemas/letta_request.py | 4 +- letta/server/rest_api/routers/v1/agents.py | 4 +- 10 files changed, 53 insertions(+), 36 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 95cb00df..e275f7b6 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -3,6 +3,7 @@ from typing import Any, AsyncGenerator, List, Optional, Union import openai +from letta.constants import DEFAULT_MAX_STEPS from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.schemas.agent import AgentState @@ -43,7 +44,7 @@ class BaseAgent(ABC): self.logger = get_logger(agent_id) @abstractmethod - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse: """ Main execution loop for the agent. """ @@ -51,7 +52,7 @@ class BaseAgent(ABC): @abstractmethod async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10 + self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: """ Main streaming execution loop for the agent. diff --git a/letta/agents/ephemeral_summary_agent.py b/letta/agents/ephemeral_summary_agent.py index 17adaf3c..572e3c78 100644 --- a/letta/agents/ephemeral_summary_agent.py +++ b/letta/agents/ephemeral_summary_agent.py @@ -4,6 +4,7 @@ from typing import AsyncGenerator, Dict, List from openai import AsyncOpenAI from letta.agents.base_agent import BaseAgent +from letta.constants import DEFAULT_MAX_STEPS from letta.orm.errors import NoResultFound from letta.schemas.block import Block, BlockUpdate from letta.schemas.enums import MessageRole @@ -42,7 +43,7 @@ class EphemeralSummaryAgent(BaseAgent): self.target_block_label = target_block_label self.block_manager = block_manager - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> List[Message]: + async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> List[Message]: if len(input_messages) > 1: raise ValueError("Can only invoke EphemeralSummaryAgent with a single summarization message.") @@ -100,5 +101,5 @@ class EphemeralSummaryAgent(BaseAgent): ) return openai_request - async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]: + async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]: raise NotImplementedError("EphemeralAgent does not support async step.") diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index ddcc6596..7a0e0336 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletionChunk from letta.agents.base_agent import BaseAgent from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id +from letta.constants import DEFAULT_MAX_STEPS from letta.errors import ContextWindowExceededError from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import AsyncTimer, get_utc_timestamp_ns, ns_to_ms @@ -114,7 +115,7 @@ class LettaAgent(BaseAgent): async def step( self, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -139,7 +140,7 @@ class LettaAgent(BaseAgent): async def step_stream_no_tokens( self, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -163,7 +164,7 @@ class LettaAgent(BaseAgent): request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns) request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) - for _ in range(max_steps): + for i in range(max_steps): step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) @@ -225,6 +226,7 @@ class LettaAgent(BaseAgent): reasoning_content=reasoning, initial_messages=initial_messages, agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -289,7 +291,7 @@ class LettaAgent(BaseAgent): self, agent_state: AgentState, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, request_start_timestamp_ns: Optional[int] = None, ) -> Tuple[List[Message], List[Message], LettaUsageStatistics]: """ @@ -315,7 +317,7 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) usage = LettaUsageStatistics() - for _ in range(max_steps): + for i in range(max_steps): step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) @@ -368,6 +370,7 @@ class LettaAgent(BaseAgent): step_id=step_id, initial_messages=initial_messages, agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -417,7 +420,7 @@ class LettaAgent(BaseAgent): async def step_stream( self, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -451,7 +454,7 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) provider_request_start_timestamp_ns = None - for _ in range(max_steps): + for i in range(max_steps): step_id = generate_step_id() step_start = get_utc_timestamp_ns() agent_step_span = tracer.start_span("agent_step", start_time=step_start) @@ -532,6 +535,7 @@ class LettaAgent(BaseAgent): step_id=step_id, initial_messages=initial_messages, agent_step_span=agent_step_span, + is_final_step=(i == max_steps - 1), ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) @@ -838,6 +842,7 @@ class LettaAgent(BaseAgent): step_id: str | None = None, initial_messages: Optional[List[Message]] = None, agent_step_span: Optional["Span"] = None, + is_final_step: Optional[bool] = None, ) -> Tuple[List[Message], bool]: """ Now that streaming is done, handle the final AI response. @@ -858,17 +863,21 @@ class LettaAgent(BaseAgent): except AssertionError: tool_args = json.loads(tool_args) - # Get request heartbeats and coerce to bool - request_heartbeat = tool_args.pop("request_heartbeat", False) - # Pre-emptively pop out inner_thoughts - tool_args.pop(INNER_THOUGHTS_KWARG, "") + if is_final_step: + logger.info("Agent has reached max steps.") + request_heartbeat = False + else: + # Get request heartbeats and coerce to bool + request_heartbeat = tool_args.pop("request_heartbeat", False) + # Pre-emptively pop out inner_thoughts + tool_args.pop(INNER_THOUGHTS_KWARG, "") - # So this is necessary, because sometimes non-structured outputs makes mistakes - if not isinstance(request_heartbeat, bool): - if isinstance(request_heartbeat, str): - request_heartbeat = request_heartbeat.lower() == "true" - else: - request_heartbeat = bool(request_heartbeat) + # So this is necessary, because sometimes non-structured outputs makes mistakes + if not isinstance(request_heartbeat, bool): + if isinstance(request_heartbeat, str): + request_heartbeat = request_heartbeat.lower() == "true" + else: + request_heartbeat = bool(request_heartbeat) tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}" diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index f0810533..ce4ee643 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -8,6 +8,7 @@ from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMe from letta.agents.base_agent import BaseAgent from letta.agents.helpers import _prepare_in_context_messages_async +from letta.constants import DEFAULT_MAX_STEPS from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.tool_execution_helper import enable_strict_mode @@ -110,7 +111,7 @@ class LettaAgentBatch(BaseAgent): sandbox_config_manager: SandboxConfigManager, job_manager: JobManager, actor: User, - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, ): self.message_manager = message_manager self.agent_manager = agent_manager @@ -619,10 +620,10 @@ class LettaAgentBatch(BaseAgent): return in_context_messages # Not used in batch. - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse: raise NotImplementedError async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10 + self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: raise NotImplementedError diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 4f904350..e35e5385 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -9,7 +9,7 @@ import openai from letta.agents.base_agent import BaseAgent from letta.agents.exceptions import IncompatibleAgentType from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent -from letta.constants import NON_USER_MSG_PREFIX +from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.tool_execution_helper import ( add_pre_execution_message, @@ -111,10 +111,10 @@ class VoiceAgent(BaseAgent): return summarizer - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse: raise NotImplementedError("VoiceAgent does not have a synchronous step implemented currently.") - async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]: + async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]: """ Main streaming loop that yields partial tokens. Whenever we detect a tool call, we yield from _handle_ai_response as well. diff --git a/letta/agents/voice_sleeptime_agent.py b/letta/agents/voice_sleeptime_agent.py index 1d5abfde..8a9c61c6 100644 --- a/letta/agents/voice_sleeptime_agent.py +++ b/letta/agents/voice_sleeptime_agent.py @@ -2,6 +2,7 @@ from typing import AsyncGenerator, List, Optional, Tuple, Union from letta.agents.helpers import _create_letta_response, serialize_message_history from letta.agents.letta_agent import LettaAgent +from letta.constants import DEFAULT_MAX_STEPS from letta.orm.enums import ToolType from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState @@ -62,7 +63,7 @@ class VoiceSleeptimeAgent(LettaAgent): async def step( self, input_messages: List[MessageCreate], - max_steps: int = 20, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, include_return_message_types: Optional[List[MessageType]] = None, ) -> LettaResponse: @@ -170,7 +171,7 @@ class VoiceSleeptimeAgent(LettaAgent): return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True + self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: """ This agent is synchronous-only. If called in an async context, raise an error. diff --git a/letta/constants.py b/letta/constants.py index fa0993e6..8c4ce4ac 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -46,6 +46,9 @@ IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY" # OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead. TOOL_CALL_ID_MAX_LEN = 29 +# Max steps for agent loop +DEFAULT_MAX_STEPS = 50 + # minimum context window size MIN_CONTEXT_WINDOW = 4096 diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index c88a9977..f0875941 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -4,6 +4,7 @@ from typing import AsyncGenerator, List, Optional from letta.agents.base_agent import BaseAgent from letta.agents.letta_agent import LettaAgent +from letta.constants import DEFAULT_MAX_STEPS from letta.groups.helpers import stringify_message from letta.otel.tracing import trace_method from letta.schemas.enums import JobStatus @@ -61,7 +62,7 @@ class SleeptimeMultiAgentV2(BaseAgent): async def step( self, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -131,7 +132,7 @@ class SleeptimeMultiAgentV2(BaseAgent): async def step_stream_no_tokens( self, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, @@ -149,7 +150,7 @@ class SleeptimeMultiAgentV2(BaseAgent): async def step_stream( self, input_messages: List[MessageCreate], - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: Optional[int] = None, include_return_message_types: Optional[List[MessageType]] = None, diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index ec0bba94..222de433 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -2,7 +2,7 @@ from typing import List, Optional from pydantic import BaseModel, Field, HttpUrl -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.letta_message import MessageType from letta.schemas.message import MessageCreate @@ -10,7 +10,7 @@ from letta.schemas.message import MessageCreate class LettaRequest(BaseModel): messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.") max_steps: int = Field( - default=10, + default=DEFAULT_MAX_STEPS, description="Maximum number of steps the agent should take to process the request.", ) use_assistant_message: bool = Field( diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index dc3ff068..0399266d 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.log import get_logger @@ -843,7 +843,7 @@ async def process_message_background( use_assistant_message: bool, assistant_message_tool_name: str, assistant_message_tool_kwarg: str, - max_steps: int = 10, + max_steps: int = DEFAULT_MAX_STEPS, include_return_message_types: Optional[List[MessageType]] = None, ) -> None: """Background task to process the message and update job status."""