feat: set request heartbeat for max steps (#2739)
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
|
||||
import openai
|
||||
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -43,7 +44,7 @@ class BaseAgent(ABC):
|
||||
self.logger = get_logger(agent_id)
|
||||
|
||||
@abstractmethod
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
|
||||
"""
|
||||
Main execution loop for the agent.
|
||||
"""
|
||||
@@ -51,7 +52,7 @@ class BaseAgent(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def step_stream(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10
|
||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS
|
||||
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
|
||||
"""
|
||||
Main streaming execution loop for the agent.
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import AsyncGenerator, Dict, List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -42,7 +43,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
self.target_block_label = target_block_label
|
||||
self.block_manager = block_manager
|
||||
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> List[Message]:
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> List[Message]:
|
||||
if len(input_messages) > 1:
|
||||
raise ValueError("Can only invoke EphemeralSummaryAgent with a single summarization message.")
|
||||
|
||||
@@ -100,5 +101,5 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
)
|
||||
return openai_request
|
||||
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]:
|
||||
raise NotImplementedError("EphemeralAgent does not support async step.")
|
||||
|
||||
@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletionChunk
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_timestamp_ns, ns_to_ms
|
||||
@@ -114,7 +115,7 @@ class LettaAgent(BaseAgent):
|
||||
async def step(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -139,7 +140,7 @@ class LettaAgent(BaseAgent):
|
||||
async def step_stream_no_tokens(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -163,7 +164,7 @@ class LettaAgent(BaseAgent):
|
||||
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
for _ in range(max_steps):
|
||||
for i in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
@@ -225,6 +226,7 @@ class LettaAgent(BaseAgent):
|
||||
reasoning_content=reasoning,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -289,7 +291,7 @@ class LettaAgent(BaseAgent):
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> Tuple[List[Message], List[Message], LettaUsageStatistics]:
|
||||
"""
|
||||
@@ -315,7 +317,7 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
usage = LettaUsageStatistics()
|
||||
for _ in range(max_steps):
|
||||
for i in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
@@ -368,6 +370,7 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -417,7 +420,7 @@ class LettaAgent(BaseAgent):
|
||||
async def step_stream(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -451,7 +454,7 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
provider_request_start_timestamp_ns = None
|
||||
for _ in range(max_steps):
|
||||
for i in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
step_start = get_utc_timestamp_ns()
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
@@ -532,6 +535,7 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -838,6 +842,7 @@ class LettaAgent(BaseAgent):
|
||||
step_id: str | None = None,
|
||||
initial_messages: Optional[List[Message]] = None,
|
||||
agent_step_span: Optional["Span"] = None,
|
||||
is_final_step: Optional[bool] = None,
|
||||
) -> Tuple[List[Message], bool]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
@@ -858,17 +863,21 @@ class LettaAgent(BaseAgent):
|
||||
except AssertionError:
|
||||
tool_args = json.loads(tool_args)
|
||||
|
||||
# Get request heartbeats and coerce to bool
|
||||
request_heartbeat = tool_args.pop("request_heartbeat", False)
|
||||
# Pre-emptively pop out inner_thoughts
|
||||
tool_args.pop(INNER_THOUGHTS_KWARG, "")
|
||||
if is_final_step:
|
||||
logger.info("Agent has reached max steps.")
|
||||
request_heartbeat = False
|
||||
else:
|
||||
# Get request heartbeats and coerce to bool
|
||||
request_heartbeat = tool_args.pop("request_heartbeat", False)
|
||||
# Pre-emptively pop out inner_thoughts
|
||||
tool_args.pop(INNER_THOUGHTS_KWARG, "")
|
||||
|
||||
# This is necessary because non-structured outputs sometimes make mistakes
|
||||
if not isinstance(request_heartbeat, bool):
|
||||
if isinstance(request_heartbeat, str):
|
||||
request_heartbeat = request_heartbeat.lower() == "true"
|
||||
else:
|
||||
request_heartbeat = bool(request_heartbeat)
|
||||
# This is necessary because non-structured outputs sometimes make mistakes
|
||||
if not isinstance(request_heartbeat, bool):
|
||||
if isinstance(request_heartbeat, str):
|
||||
request_heartbeat = request_heartbeat.lower() == "true"
|
||||
else:
|
||||
request_heartbeat = bool(request_heartbeat)
|
||||
|
||||
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMe
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.helpers import _prepare_in_context_messages_async
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
@@ -110,7 +111,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
sandbox_config_manager: SandboxConfigManager,
|
||||
job_manager: JobManager,
|
||||
actor: User,
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
):
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
@@ -619,10 +620,10 @@ class LettaAgentBatch(BaseAgent):
|
||||
return in_context_messages
|
||||
|
||||
# Not used in batch.
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def step_stream(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10
|
||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS
|
||||
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -9,7 +9,7 @@ import openai
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import (
|
||||
add_pre_execution_message,
|
||||
@@ -111,10 +111,10 @@ class VoiceAgent(BaseAgent):
|
||||
|
||||
return summarizer
|
||||
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
|
||||
raise NotImplementedError("VoiceAgent does not have a synchronous step implemented currently.")
|
||||
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Main streaming loop that yields partial tokens.
|
||||
Whenever we detect a tool call, we yield from _handle_ai_response as well.
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
|
||||
from letta.agents.helpers import _create_letta_response, serialize_message_history
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -62,7 +63,7 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
async def step(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 20,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
) -> LettaResponse:
|
||||
@@ -170,7 +171,7 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False
|
||||
|
||||
async def step_stream(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
|
||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True
|
||||
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
|
||||
"""
|
||||
This agent is synchronous-only. If called in an async context, raise an error.
|
||||
|
||||
@@ -46,6 +46,9 @@ IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY"
|
||||
# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead.
|
||||
TOOL_CALL_ID_MAX_LEN = 29
|
||||
|
||||
# Max steps for agent loop
|
||||
DEFAULT_MAX_STEPS = 50
|
||||
|
||||
# minimum context window size
|
||||
MIN_CONTEXT_WINDOW = 4096
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import JobStatus
|
||||
@@ -61,7 +62,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
async def step(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -131,7 +132,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
async def step_stream_no_tokens(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
@@ -149,7 +150,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
async def step_stream(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.message import MessageCreate
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.schemas.message import MessageCreate
|
||||
class LettaRequest(BaseModel):
|
||||
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
||||
max_steps: int = Field(
|
||||
default=10,
|
||||
default=DEFAULT_MAX_STEPS,
|
||||
description="Maximum number of steps the agent should take to process the request.",
|
||||
)
|
||||
use_assistant_message: bool = Field(
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.log import get_logger
|
||||
@@ -843,7 +843,7 @@ async def process_message_background(
|
||||
use_assistant_message: bool,
|
||||
assistant_message_tool_name: str,
|
||||
assistant_message_tool_kwarg: str,
|
||||
max_steps: int = 10,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
include_return_message_types: Optional[List[MessageType]] = None,
|
||||
) -> None:
|
||||
"""Background task to process the message and update job status."""
|
||||
|
||||
Reference in New Issue
Block a user