feat: consolidate stream methods for new agent loop (#4468)
@@ -13,8 +13,8 @@ from letta.schemas.user import User
 class BaseAgentV2(ABC):
     """
-    Abstract base class for the letta gent loop, handling message management,
-    llm api request, tool execution, and context tracking.
+    Abstract base class for the main agent execution loop for letta agents, handling
+    message management, llm api request, tool execution, and context tracking.
     """
 
     def __init__(self, agent_state: AgentState, actor: User):
@@ -28,8 +28,8 @@ class BaseAgentV2(ABC):
         input_messages: list[MessageCreate],
     ) -> dict:
         """
-        Main execution loop for the agent. This method only returns once the agent completes
-        execution, returning all messages at once.
+        Execute the agent loop in dry_run mode, returning just the generated request
+        payload sent to the underlying llm provider.
         """
         raise NotImplementedError
 
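For context, a minimal sketch of how the dry_run path described above might be exercised. The method's name is not visible in this hunk, so build_request is assumed purely for illustration, as is the message content:

    from letta.schemas.message import MessageCreate

    async def inspect_request(agent):
        # agent: a concrete BaseAgentV2 implementation
        # build_request is a hypothetical name; the hunk only shows the
        # signature fragment and docstring of the dry-run method.
        payload = await agent.build_request(
            input_messages=[MessageCreate(role="user", content="hello")],
        )
        # The returned dict is the raw request payload for the underlying llm provider.
        print(payload)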
@@ -40,33 +40,21 @@ class BaseAgentV2(ABC):
         max_steps: int = DEFAULT_MAX_STEPS,
     ) -> LettaResponse:
         """
-        Main execution loop for the agent. This method only returns once the agent completes
-        execution, returning all messages at once.
+        Execute the agent loop in blocking mode, returning all messages at once.
         """
         raise NotImplementedError
 
     @abstractmethod
-    async def stream_steps(
+    async def stream(
         self,
         input_messages: list[MessageCreate],
         max_steps: int = DEFAULT_MAX_STEPS,
+        stream_tokens: bool = True,
     ) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
         """
-        Main execution loop for the agent. This method returns an async generator, streaming
-        each step as it completes on the server side.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    async def stream_tokens(
-        self,
-        input_messages: list[MessageCreate],
-        max_steps: int = DEFAULT_MAX_STEPS,
-    ) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
-        """
-        Main execution loop for the agent. This method returns an async generator, streaming
-        each token as it is returned from the underlying llm api. Not all llm providers offer
-        native token streaming functionality; in these cases, this api streams back steps
-        rather than individual tokens.
+        Execute the agent loop in streaming mode, yielding chunks as they become available.
+        If stream_tokens is True, individual tokens are streamed as they arrive from the LLM,
+        providing the lowest latency experience, otherwise each complete step (reasoning +
+        tool call + tool return) is yielded as it completes.
         """
         raise NotImplementedError
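A hedged usage sketch of the consolidated interface above; the agent object and message contents are illustrative and not taken from this diff:

    from letta.schemas.message import MessageCreate

    async def drive_agent(agent):
        # agent: any BaseAgentV2 implementation
        messages = [MessageCreate(role="user", content="Summarize my last session.")]

        # Token-level streaming (lowest latency) where the provider supports it.
        async for item in agent.stream(input_messages=messages, stream_tokens=True):
            print(item)

        # Step-level streaming: each completed step (reasoning + tool call +
        # tool return) arrives as a whole.
        async for item in agent.stream(input_messages=messages, stream_tokens=False):
            print(item)

Even with stream_tokens=True, a provider without native token streaming falls back to yielding whole steps, as the docstring notes.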
@@ -192,24 +192,28 @@ class LettaAgentV2(BaseAgentV2):
             self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
         return LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
 
-    async def stream_steps(
+    async def stream(
         self,
         input_messages: list[MessageCreate],
         max_steps: int = DEFAULT_MAX_STEPS,
+        stream_tokens: bool = True,
         run_id: str | None = None,
         use_assistant_message: bool = True,
         include_return_message_types: list[MessageType] | None = None,
         request_start_timestamp_ns: int | None = None,
     ) -> AsyncGenerator[str]:
         """
-        Execute the agent loop with step-level streaming.
-
-        Each complete step (reasoning + tool call + tool return) is returned as it completes,
-        but individual tokens are not streamed.
+        Execute the agent loop in streaming mode, yielding chunks as they become available.
+        If stream_tokens is True, individual tokens are streamed as they arrive from the LLM,
+        providing the lowest latency experience, otherwise each complete step (reasoning +
+        tool call + tool return) is yielded as it completes.
 
         Args:
             input_messages: List of new messages to process
             max_steps: Maximum number of agent steps to execute
+            stream_tokens: Whether to stream back individual tokens. Not all llm
+                providers offer native token streaming functionality; in these cases,
+                this api streams back steps rather than individual tokens.
             run_id: Optional job/run ID for tracking
             use_assistant_message: Whether to use assistant message format
             include_return_message_types: Filter for which message types to return
@@ -218,75 +222,17 @@ class LettaAgentV2(BaseAgentV2):
         Yields:
             str: JSON-formatted SSE data chunks for each completed step
         """
-        response = self.stream(
-            input_messages=input_messages,
-            llm_adapter=LettaLLMRequestAdapter(llm_client=self.llm_client, llm_config=self.agent_state.llm_config),
-            max_steps=max_steps,
-            run_id=run_id,
-            use_assistant_message=use_assistant_message,
-            include_return_message_types=include_return_message_types,
-            request_start_timestamp_ns=request_start_timestamp_ns,
-        )
-        async for chunk in response:
-            yield chunk
-
-    async def stream_tokens(
-        self,
-        input_messages: list[MessageCreate],
-        max_steps: int = DEFAULT_MAX_STEPS,
-        run_id: str | None = None,
-        use_assistant_message: bool = True,
-        include_return_message_types: list[MessageType] | None = None,
-        request_start_timestamp_ns: int | None = None,
-    ) -> AsyncGenerator[str]:
-        """
-        Execute the agent loop with token-level streaming for minimal TTFT.
-
-        Individual tokens are streamed as they arrive from the LLM, providing
-        the lowest latency experience. Falls back to step streaming for providers
-        that don't support token streaming.
-
-        Args:
-            input_messages: List of new messages to process
-            max_steps: Maximum number of agent steps to execute
-            run_id: Optional job/run ID for tracking
-            use_assistant_message: Whether to use assistant message format
-            include_return_message_types: Filter for which message types to return
-            request_start_timestamp_ns: Start time for tracking request duration
-
-        Yields:
-            str: JSON-formatted SSE data chunks for each token/chunk
-        """
-        response = self.stream(
-            input_messages=input_messages,
-            llm_adapter=LettaLLMStreamAdapter(
-                llm_client=self.llm_client,
-                llm_config=self.agent_state.llm_config,
-            ),
-            run_id=run_id,
-            use_assistant_message=use_assistant_message,
-            include_return_message_types=include_return_message_types,
-            request_start_timestamp_ns=request_start_timestamp_ns,
-        )
-        async for chunk in response:
-            yield chunk
-
-    async def stream(
-        self,
-        input_messages: list[MessageCreate],
-        llm_adapter: LettaLLMAdapter,
-        max_steps: int = DEFAULT_MAX_STEPS,
-        run_id: str | None = None,
-        use_assistant_message: bool = True,
-        include_return_message_types: list[MessageType] | None = None,
-        request_start_timestamp_ns: int | None = None,
-    ) -> AsyncGenerator[str]:
-        """
-        Main execution loop for the agent. This method returns an async generator, streaming
-        each token as it is returned from the underlying llm api. Not all llm providers offer
-        native token streaming functionality; in these cases, this api streams back steps
-        rather than individual tokens.
-        """
+        if stream_tokens:
+            llm_adapter = LettaLLMStreamAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+            )
+        else:
+            llm_adapter = LettaLLMRequestAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+            )
+
         try:
             self._initialize_state()
             in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
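Since LettaAgentV2.stream yields JSON-formatted SSE data strings, a rough sketch of forwarding them to a client follows; the FastAPI wiring and the load_agent_v2 helper are assumptions for illustration only, not part of this commit:

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse
    from letta.schemas.message import MessageCreate

    app = FastAPI()

    @app.post("/v2/agents/{agent_id}/messages/stream")
    async def stream_agent(agent_id: str, text: str, stream_tokens: bool = True):
        agent = load_agent_v2(agent_id)  # hypothetical loader, not shown in this diff
        generator = agent.stream(
            input_messages=[MessageCreate(role="user", content=text)],
            stream_tokens=stream_tokens,
        )
        # Each yielded chunk is already an SSE-formatted data string.
        return StreamingResponse(generator, media_type="text/event-stream")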