1005 lines
45 KiB
Python
1005 lines
45 KiB
Python
import asyncio
|
|
import json
|
|
import uuid
|
|
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
|
|
|
from openai import AsyncStream
|
|
from openai.types.chat import ChatCompletionChunk
|
|
|
|
from letta.agents.base_agent import BaseAgent
|
|
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
|
from letta.agents.helpers import (
|
|
_create_letta_response,
|
|
_prepare_in_context_messages_async,
|
|
_prepare_in_context_messages_no_persist_async,
|
|
generate_step_id,
|
|
)
|
|
from letta.errors import LLMContextWindowExceededError
|
|
from letta.helpers import ToolRulesSolver
|
|
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
|
from letta.helpers.tool_execution_helper import enable_strict_mode
|
|
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
|
|
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
|
|
from letta.llm_api.llm_client import LLMClient
|
|
from letta.llm_api.llm_client_base import LLMClientBase
|
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
|
from letta.log import get_logger
|
|
from letta.orm.enums import ToolType
|
|
from letta.schemas.agent import AgentState
|
|
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
|
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
|
from letta.schemas.letta_response import LettaResponse
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message, MessageCreate
|
|
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
|
from letta.schemas.provider_trace import ProviderTraceCreate
|
|
from letta.schemas.usage import LettaUsageStatistics
|
|
from letta.schemas.user import User
|
|
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
|
from letta.services.agent_manager import AgentManager
|
|
from letta.services.block_manager import BlockManager
|
|
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
|
from letta.services.message_manager import MessageManager
|
|
from letta.services.passage_manager import PassageManager
|
|
from letta.services.step_manager import NoopStepManager, StepManager
|
|
from letta.services.summarizer.enums import SummarizationMode
|
|
from letta.services.summarizer.summarizer import Summarizer
|
|
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
|
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
|
from letta.settings import model_settings
|
|
from letta.system import package_function_response
|
|
from letta.tracing import log_event, trace_method, tracer
|
|
from letta.utils import log_telemetry, validate_function_response
|
|
|
|
# Module-level logger for this file (used e.g. when no reasoning content is found in a response).
logger = get_logger(__name__)
|
|
|
|
|
|
class LettaAgent(BaseAgent):
|
|
|
|
def __init__(
    self,
    agent_id: str,
    message_manager: MessageManager,
    agent_manager: AgentManager,
    block_manager: BlockManager,
    passage_manager: PassageManager,
    actor: User,
    step_manager: Optional[StepManager] = None,
    telemetry_manager: Optional[TelemetryManager] = None,
    summary_block_label: str = "conversation_summary",
    message_buffer_limit: int = 60,  # TODO: Make this configurable
    message_buffer_min: int = 15,  # TODO: Make this configurable
    enable_summarization: bool = True,  # TODO: Make this configurable
    max_summarization_retries: int = 3,  # TODO: Make this configurable
):
    """Create an agent-loop runner bound to a single agent and actor.

    Args:
        agent_id: Id of the agent this loop drives.
        message_manager: Persistence for messages.
        agent_manager: Persistence for agent state.
        block_manager: Persistence for memory blocks (also handed to the summary agent).
        passage_manager: Persistence for archival passages (used for size counts).
        actor: User on whose behalf every manager call is made.
        step_manager: Step logger; a fresh ``NoopStepManager`` is created when omitted/None.
        telemetry_manager: Provider-trace sink; a fresh ``NoopTelemetryManager`` is created
            when omitted/None.
        summary_block_label: Label of the block the summary agent writes into.
        message_buffer_limit: Passed to the ``Summarizer`` as its buffer limit.
        message_buffer_min: Passed to the ``Summarizer`` as its buffer minimum.
        enable_summarization: Whether to create an ``EphemeralSummaryAgent`` (also requires an
            OpenAI API key in ``model_settings``).
        max_summarization_retries: How many times an LLM request is retried after a
            context-window overflow triggers a forced context rebuild.
    """
    super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)

    # TODO: Make this more general, factorable
    # Summarizer settings
    self.block_manager = block_manager
    self.passage_manager = passage_manager
    # Build the no-op defaults per instance rather than sharing one object that was created
    # once at import time as a default argument value.
    self.step_manager = step_manager if step_manager is not None else NoopStepManager()
    self.telemetry_manager = telemetry_manager if telemetry_manager is not None else NoopTelemetryManager()
    self.response_messages: List[Message] = []

    # Lazily loaded from the in-context messages when building the first request.
    self.last_function_response = None

    # Cached archival memory/message size (populated on first request build).
    self.num_messages = None
    self.num_archival_memories = None

    self.summarization_agent = None
    self.summary_block_label = summary_block_label
    self.max_summarization_retries = max_summarization_retries

    # TODO: Expand to more
    # The summary agent is only created when an OpenAI key is configured.
    if enable_summarization and model_settings.openai_api_key:
        self.summarization_agent = EphemeralSummaryAgent(
            target_block_label=self.summary_block_label,
            agent_id=agent_id,
            block_manager=self.block_manager,
            message_manager=self.message_manager,
            agent_manager=self.agent_manager,
            actor=self.actor,
        )

    self.summarizer = Summarizer(
        mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
        summarizer_agent=self.summarization_agent,
        # TODO: Make this configurable
        message_buffer_limit=message_buffer_limit,
        message_buffer_min=message_buffer_min,
    )
|
|
|
|
@trace_method
async def step(
    self,
    input_messages: List[MessageCreate],
    max_steps: int = 10,
    use_assistant_message: bool = True,
    request_start_timestamp_ns: Optional[int] = None,
) -> LettaResponse:
    """Run the non-streaming agent loop and package its output as a ``LettaResponse``.

    Loads the agent (with tools, memory and tool environment variables), delegates the
    actual loop to ``_step``, and converts the newly produced in-context messages plus the
    accumulated usage into the response object.
    """
    relationships = ["tools", "memory", "tool_exec_environment_variables"]
    agent_state = await self.agent_manager.get_agent_by_id_async(
        agent_id=self.agent_id,
        include_relationships=relationships,
        actor=self.actor,
    )

    loop_result = await self._step(
        agent_state=agent_state,
        input_messages=input_messages,
        max_steps=max_steps,
        request_start_timestamp_ns=request_start_timestamp_ns,
    )
    _, produced_messages, usage_stats = loop_result

    return _create_letta_response(
        new_in_context_messages=produced_messages,
        use_assistant_message=use_assistant_message,
        usage=usage_stats,
    )
|
|
|
|
@trace_method
async def step_stream_no_tokens(
    self,
    input_messages: List[MessageCreate],
    max_steps: int = 10,
    use_assistant_message: bool = True,
    request_start_timestamp_ns: Optional[int] = None,
):
    """Run the agent loop, streaming whole messages (not tokens) as SSE lines.

    Each step issues one non-streaming LLM request, handles the resulting tool call, and
    yields every non-user message persisted during that step as a ``data: ...`` line.
    After the loop, a usage-statistics event and the ``done`` status event are yielded.

    Args:
        input_messages: New messages to add to the conversation before stepping.
        max_steps: Maximum number of LLM round-trips.
        use_assistant_message: Forwarded when converting persisted messages for output.
        request_start_timestamp_ns: Start of the overall request; when truthy, a
            request-level span is recorded from this time and ended at the end.

    Raises:
        ValueError: If an LLM response contains no tool call.
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(
        agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
    )
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=self.actor,
    )
    usage = LettaUsageStatistics()

    # Span for the whole request. Only start it when we have a timestamp, because that is the
    # only case in which it gets ended below (otherwise the span would leak).
    request_span = None
    if request_start_timestamp_ns:
        request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
        request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})

    for _ in range(max_steps):
        step_id = generate_step_id()
        step_start = get_utc_timestamp_ns()
        agent_step_span = tracer.start_span("agent_step", start_time=step_start)
        agent_step_span.set_attributes({"step_id": step_id})

        # May summarize/trim the context and retry on context-window overflow, so the message
        # lists are re-bound from its result.
        request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
            current_in_context_messages,
            new_in_context_messages,
            agent_state,
            llm_client,
            tool_rules_solver,
        )
        in_context_messages = current_in_context_messages + new_in_context_messages

        log_event("agent.stream_no_tokens.llm_response.received")  # [3^]

        # log LLM request time (recorded once per step)
        now = get_utc_timestamp_ns()
        llm_request_ns = now - step_start
        agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})

        response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)

        # update usage
        # TODO: add run_id
        usage.step_count += 1
        usage.completion_tokens += response.usage.completion_tokens
        usage.prompt_tokens += response.usage.prompt_tokens
        usage.total_tokens += response.usage.total_tokens

        if not response.choices[0].message.tool_calls:
            # TODO: make into a real error
            raise ValueError("No tool calls found in response, model must make a tool call")
        tool_call = response.choices[0].message.tool_calls[0]
        if response.choices[0].message.reasoning_content:
            reasoning = [
                ReasoningContent(
                    reasoning=response.choices[0].message.reasoning_content,
                    is_native=True,
                    signature=response.choices[0].message.reasoning_content_signature,
                )
            ]
        elif response.choices[0].message.content:
            reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons
        else:
            logger.info("No reasoning content found.")
            reasoning = None

        # Pass step_id (as _step and step_stream do) so the persisted step and the provider
        # trace logged below share the same id.
        persisted_messages, should_continue = await self._handle_ai_response(
            tool_call,
            agent_state,
            tool_rules_solver,
            response.usage,
            reasoning_content=reasoning,
            step_id=step_id,
            agent_step_span=agent_step_span,
        )
        self.response_messages.extend(persisted_messages)
        new_in_context_messages.extend(persisted_messages)
        log_event("agent.stream_no_tokens.llm_response.processed")  # [4^]

        # log step time
        now = get_utc_timestamp_ns()
        step_ns = now - step_start
        agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
        agent_step_span.end()

        # Log LLM Trace
        await self.telemetry_manager.create_provider_trace_async(
            actor=self.actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=request_data,
                response_json=response_data,
                step_id=step_id,
                organization_id=self.actor.organization_id,
            ),
        )

        # stream step
        # TODO: improve TTFT
        filter_user_messages = [m for m in persisted_messages if m.role != "user"]
        letta_messages = Message.to_letta_messages_from_list(
            filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
        )
        for message in letta_messages:
            yield f"data: {message.model_dump_json()}\n\n"

        if not should_continue:
            break

    # Extend the in context message ids
    if not agent_state.message_buffer_autoclear:
        await self._rebuild_context_window(
            in_context_messages=current_in_context_messages,
            new_letta_messages=new_in_context_messages,
            llm_config=agent_state.llm_config,
            total_tokens=usage.total_tokens,
            force=False,
        )

    # log request time
    if request_span is not None:
        now = get_utc_timestamp_ns()
        request_ns = now - request_start_timestamp_ns
        request_span.add_event(name="letta_request_ms", attributes={"duration_ms": request_ns // 1_000_000})
        request_span.end()

    # Return back usage
    yield f"data: {usage.model_dump_json()}\n\n"
    yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
|
|
|
async def _step(
    self,
    agent_state: AgentState,
    input_messages: List[MessageCreate],
    max_steps: int = 10,
    request_start_timestamp_ns: Optional[int] = None,
) -> Tuple[List[Message], List[Message], LettaUsageStatistics]:
    """
    Carries out an invocation of the agent loop. In each step, the agent
        1. Rebuilds its memory
        2. Generates a request for the LLM
        3. Fetches a response from the LLM
        4. Processes the response

    Returns:
        ``(current_in_context_messages, new_in_context_messages, usage)`` where the message
        lists are as last returned by ``_build_and_request_from_llm`` (new messages extended
        with everything persisted during the loop) and ``usage`` accumulates token counts.

    Raises:
        ValueError: If an LLM response contains no tool call.
    """
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=self.actor,
    )

    # Span for the whole request. It is only ended when a request timestamp is available, so
    # only start it in that case (previously it was always started and could leak).
    request_span = None
    if request_start_timestamp_ns:
        request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
        request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})

    usage = LettaUsageStatistics()
    for _ in range(max_steps):
        step_id = generate_step_id()
        step_start = get_utc_timestamp_ns()
        agent_step_span = tracer.start_span("agent_step", start_time=step_start)
        agent_step_span.set_attributes({"step_id": step_id})

        # May summarize/trim the context and retry on context-window overflow.
        request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
            current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
        )
        in_context_messages = current_in_context_messages + new_in_context_messages

        log_event("agent.step.llm_response.received")  # [3^]

        response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)

        # log LLM request time
        now = get_utc_timestamp_ns()
        llm_request_ns = now - step_start
        agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})

        # TODO: add run_id
        usage.step_count += 1
        usage.completion_tokens += response.usage.completion_tokens
        usage.prompt_tokens += response.usage.prompt_tokens
        usage.total_tokens += response.usage.total_tokens

        if not response.choices[0].message.tool_calls:
            # TODO: make into a real error
            raise ValueError("No tool calls found in response, model must make a tool call")
        tool_call = response.choices[0].message.tool_calls[0]
        if response.choices[0].message.reasoning_content:
            reasoning = [
                ReasoningContent(
                    reasoning=response.choices[0].message.reasoning_content,
                    is_native=True,
                    signature=response.choices[0].message.reasoning_content_signature,
                )
            ]
        elif response.choices[0].message.content:
            reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons
        else:
            logger.info("No reasoning content found.")
            reasoning = None

        persisted_messages, should_continue = await self._handle_ai_response(
            tool_call,
            agent_state,
            tool_rules_solver,
            response.usage,
            reasoning_content=reasoning,
            step_id=step_id,
            agent_step_span=agent_step_span,
        )
        self.response_messages.extend(persisted_messages)
        new_in_context_messages.extend(persisted_messages)
        log_event("agent.step.llm_response.processed")  # [4^]

        # log step time
        now = get_utc_timestamp_ns()
        step_ns = now - step_start
        agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
        agent_step_span.end()

        # Log LLM Trace
        await self.telemetry_manager.create_provider_trace_async(
            actor=self.actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=request_data,
                response_json=response_data,
                step_id=step_id,
                organization_id=self.actor.organization_id,
            ),
        )

        if not should_continue:
            break

    # log request time
    if request_span is not None:
        now = get_utc_timestamp_ns()
        request_ns = now - request_start_timestamp_ns
        request_span.add_event(name="request_ms", attributes={"duration_ms": request_ns // 1_000_000})
        request_span.end()

    # Extend the in context message ids
    if not agent_state.message_buffer_autoclear:
        await self._rebuild_context_window(
            in_context_messages=current_in_context_messages,
            new_letta_messages=new_in_context_messages,
            llm_config=agent_state.llm_config,
            total_tokens=usage.total_tokens,
            force=False,
        )

    return current_in_context_messages, new_in_context_messages, usage
|
|
|
|
@trace_method
async def step_stream(
    self,
    input_messages: List[MessageCreate],
    max_steps: int = 10,
    use_assistant_message: bool = True,
    request_start_timestamp_ns: Optional[int] = None,
) -> AsyncGenerator[str, None]:
    """
    Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.

    Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
        1. Rebuilds its memory
        2. Generates a request for the LLM
        3. Fetches a response from the LLM
        4. Processes the response

    Yields:
        SSE ``data: ...`` lines: provider chunks as they stream, the tool return when the loop
        continues (or assistant messages are disabled), then usage and the ``done`` status.

    Raises:
        ValueError: If the model endpoint type is neither ``anthropic`` nor ``openai``.
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(
        agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
    )
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
        input_messages, agent_state, self.message_manager, self.actor
    )

    # Special strategy to lower TTFT
    # Delay persistence of the initial input message as much as possible
    persisted_input_messages = False
    # Copy so that later in-place mutation of new_in_context_messages (the .extend below)
    # can never change which messages count as "initial inputs".
    initial_messages = list(new_in_context_messages)

    tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
    llm_client = LLMClient.create(
        provider_type=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
        actor=self.actor,
    )
    usage = LettaUsageStatistics()
    first_chunk, request_span = True, None
    if request_start_timestamp_ns:
        request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
        request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})

    # NOTE(review): never updated. _build_and_request_from_llm_streaming computes the provider
    # request start time locally but does not return it, so the interface always receives None.
    provider_request_start_timestamp_ns = None
    for _ in range(max_steps):
        step_id = generate_step_id()
        step_start = get_utc_timestamp_ns()
        agent_step_span = tracer.start_span("agent_step", start_time=step_start)
        agent_step_span.set_attributes({"step_id": step_id})

        request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming(
            first_chunk,
            agent_step_span,
            request_start_timestamp_ns,
            current_in_context_messages,
            new_in_context_messages,
            agent_state,
            llm_client,
            tool_rules_solver,
        )
        log_event("agent.stream.llm_response.received")  # [3^]

        # TODO: THIS IS INCREDIBLY UGLY
        # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
        if agent_state.llm_config.model_endpoint_type == "anthropic":
            interface = AnthropicStreamingInterface(
                use_assistant_message=use_assistant_message,
                put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
            )
        elif agent_state.llm_config.model_endpoint_type == "openai":
            interface = OpenAIStreamingInterface(
                use_assistant_message=use_assistant_message,
                put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
            )
        else:
            raise ValueError(f"Streaming not supported for {agent_state.llm_config}")

        async for chunk in interface.process(
            stream, ttft_span=request_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
        ):
            # Measure time to first token (request_span is only set when a start timestamp exists)
            if first_chunk and request_span is not None:
                now = get_utc_timestamp_ns()
                ttft_ns = now - request_start_timestamp_ns
                request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
                first_chunk = False

            yield f"data: {chunk.model_dump_json()}\n\n"

        # update usage
        usage.step_count += 1
        usage.completion_tokens += interface.output_tokens
        usage.prompt_tokens += interface.input_tokens
        usage.total_tokens += interface.input_tokens + interface.output_tokens

        # Persist input messages if not already
        # Special strategy to lower TTFT
        if not persisted_input_messages:
            await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
            persisted_input_messages = True

        # log LLM request time
        now = get_utc_timestamp_ns()
        llm_request_ns = now - step_start
        agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})

        # Process resulting stream content
        tool_call = interface.get_tool_call_object()
        reasoning_content = interface.get_reasoning_content()
        persisted_messages, should_continue = await self._handle_ai_response(
            tool_call,
            agent_state,
            tool_rules_solver,
            UsageStatistics(
                completion_tokens=interface.output_tokens,
                prompt_tokens=interface.input_tokens,
                total_tokens=interface.input_tokens + interface.output_tokens,
            ),
            reasoning_content=reasoning_content,
            pre_computed_assistant_message_id=interface.letta_assistant_message_id,
            pre_computed_tool_message_id=interface.letta_tool_message_id,
            step_id=step_id,
            agent_step_span=agent_step_span,
        )
        self.response_messages.extend(persisted_messages)
        new_in_context_messages.extend(persisted_messages)

        # log total step time
        now = get_utc_timestamp_ns()
        step_ns = now - step_start
        agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
        agent_step_span.end()

        # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
        # log_event("agent.stream.llm_response.processed") # [4^]

        # Log LLM Trace
        # TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
        await self.telemetry_manager.create_provider_trace_async(
            actor=self.actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=request_data,
                response_json={
                    "content": {
                        "tool_call": tool_call.model_dump_json(),
                        # Guard against the interface returning no reasoning at all.
                        "reasoning": [content.model_dump_json() for content in reasoning_content or []],
                    },
                    "id": interface.message_id,
                    "model": interface.model,
                    "role": "assistant",
                    # "stop_reason": "",
                    # "stop_sequence": None,
                    "type": "message",
                    "usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
                },
                step_id=step_id,
                organization_id=self.actor.organization_id,
            ),
        )

        if not use_assistant_message or should_continue:
            # Emit the most recent tool return, if the step produced one (avoids an IndexError
            # mid-stream when it did not).
            tool_messages = [msg for msg in persisted_messages if msg.role == "tool"]
            if tool_messages:
                tool_return = tool_messages[-1].to_letta_messages()[0]
                yield f"data: {tool_return.model_dump_json()}\n\n"

        if not should_continue:
            break

    # Extend the in context message ids
    if not agent_state.message_buffer_autoclear:
        await self._rebuild_context_window(
            in_context_messages=current_in_context_messages,
            new_letta_messages=new_in_context_messages,
            llm_config=agent_state.llm_config,
            total_tokens=usage.total_tokens,
            force=False,
        )

    # log time of entire request
    if request_start_timestamp_ns:
        now = get_utc_timestamp_ns()
        request_ns = now - request_start_timestamp_ns
        request_span.add_event(name="letta_request_ms", attributes={"duration_ms": request_ns // 1_000_000})
        request_span.end()

    # TODO: Also yield out a letta usage stats SSE
    yield f"data: {usage.model_dump_json()}\n\n"
    yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
|
|
|
async def _build_and_request_from_llm(
    self,
    current_in_context_messages: List[Message],
    new_in_context_messages: List[Message],
    agent_state: AgentState,
    llm_client: LLMClientBase,
    tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, Dict, List[Message], List[Message]]:
    """Build a request and send it (non-streaming), retrying after context recovery.

    On failure, ``_handle_llm_error`` either rebuilds the context window (and the pending new
    messages are dropped from the next request) or raises. After
    ``self.max_summarization_retries`` retries, the last error is re-raised.

    Returns:
        ``(request_data, response_data, current_in_context_messages, new_in_context_messages)``
        with the message lists as actually used for the successful request.
    """
    last_attempt = self.max_summarization_retries
    for attempt_no in range(last_attempt + 1):
        try:
            log_event("agent.stream_no_tokens.messages.refreshed")
            # Create LLM request data
            request_payload = await self._create_llm_request_data_async(
                llm_client=llm_client,
                in_context_messages=current_in_context_messages + new_in_context_messages,
                agent_state=agent_state,
                tool_rules_solver=tool_rules_solver,
            )
            log_event("agent.stream_no_tokens.llm_request.created")

            # Attempt LLM request
            raw_response = await llm_client.request_async(request_payload, agent_state.llm_config)
        except Exception as err:
            if attempt_no == last_attempt:
                raise err

            # Recover (summarize/trim) and try again with only the rebuilt context.
            current_in_context_messages = await self._handle_llm_error(
                err,
                llm_client=llm_client,
                in_context_messages=current_in_context_messages,
                new_letta_messages=new_in_context_messages,
                llm_config=agent_state.llm_config,
                force=True,
            )
            new_in_context_messages = []
            log_event(f"agent.stream_no_tokens.retry_attempt.{attempt_no + 1}")
        else:
            return request_payload, raw_response, current_in_context_messages, new_in_context_messages
|
|
|
|
async def _build_and_request_from_llm_streaming(
    self,
    first_chunk: bool,
    ttft_span: "Span",
    request_start_timestamp_ns: int,
    current_in_context_messages: List[Message],
    new_in_context_messages: List[Message],
    agent_state: AgentState,
    llm_client: LLMClientBase,
    tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message]]:
    """Build a request and open a provider stream, retrying after context recovery.

    When this is the first chunk of the request and a TTFT span is present, the delay
    between the request start and the provider call is recorded on that span. Failures go
    through ``_handle_llm_error``; after ``self.max_summarization_retries`` retries the last
    error is re-raised.

    Returns:
        ``(request_data, stream, current_in_context_messages, new_in_context_messages)``.
    """
    last_attempt = self.max_summarization_retries
    for attempt_no in range(last_attempt + 1):
        try:
            log_event("agent.stream_no_tokens.messages.refreshed")
            # Create LLM request data
            request_payload = await self._create_llm_request_data_async(
                llm_client=llm_client,
                in_context_messages=current_in_context_messages + new_in_context_messages,
                agent_state=agent_state,
                tool_rules_solver=tool_rules_solver,
            )
            log_event("agent.stream.llm_request.created")  # [2^]

            should_mark_provider_start = first_chunk and ttft_span is not None
            if should_mark_provider_start:
                provider_start_ns = get_utc_timestamp_ns()
                offset_ns = provider_start_ns - request_start_timestamp_ns
                ttft_span.add_event(
                    name="provider_req_start_ns",
                    attributes={"provider_req_start_ms": offset_ns // 1_000_000},
                )

            # Attempt LLM request
            provider_stream = await llm_client.stream_async(request_payload, agent_state.llm_config)
        except Exception as err:
            if attempt_no == last_attempt:
                raise err

            # Recover (summarize/trim) and try again with only the rebuilt context.
            current_in_context_messages = await self._handle_llm_error(
                err,
                llm_client=llm_client,
                in_context_messages=current_in_context_messages,
                new_letta_messages=new_in_context_messages,
                llm_config=agent_state.llm_config,
                force=True,
            )
            new_in_context_messages = []
            log_event(f"agent.stream_no_tokens.retry_attempt.{attempt_no + 1}")
        else:
            return request_payload, provider_stream, current_in_context_messages, new_in_context_messages
|
|
|
|
@trace_method
async def _handle_llm_error(
    self,
    e: Exception,
    llm_client: LLMClientBase,
    in_context_messages: List[Message],
    new_letta_messages: List[Message],
    llm_config: LLMConfig,
    force: bool,
) -> List[Message]:
    """Recover from a context-window overflow, or re-raise any other LLM error.

    Returns:
        The rebuilt in-context message list from ``_rebuild_context_window``.

    Raises:
        Whatever ``llm_client.handle_llm_error(e)`` returns, for non-overflow errors.
    """
    # Anything other than an overflow is translated by the client and raised immediately.
    if not isinstance(e, LLMContextWindowExceededError):
        raise llm_client.handle_llm_error(e)

    return await self._rebuild_context_window(
        in_context_messages=in_context_messages,
        new_letta_messages=new_letta_messages,
        llm_config=llm_config,
        force=force,
    )
|
|
|
|
@trace_method
async def _rebuild_context_window(
    self,
    in_context_messages: List[Message],
    new_letta_messages: List[Message],
    llm_config: LLMConfig,
    total_tokens: Optional[int] = None,
    force: bool = False,
) -> List[Message]:
    """Run the summarizer over the context and persist the resulting message ids.

    When ``force`` is set or ``total_tokens`` exceeds ``llm_config.context_window``, the
    summarizer is called with ``force=True, clear=True``; otherwise with its defaults.
    The agent's in-context message ids are then replaced with the summarizer's output.

    Returns:
        The new in-context message list produced by the summarizer.
    """
    # If total tokens is reached, we truncate down
    # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
    over_limit = total_tokens is not None and total_tokens > llm_config.context_window
    if force or over_limit:
        # Log the actual reason: a forced rebuild (e.g. after a provider overflow error) may
        # arrive with total_tokens=None, which previously produced "Total tokens None exceeds ...".
        if over_limit:
            self.logger.warning(
                f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
            )
        else:
            self.logger.warning(
                f"Forced context window rebuild requested (total tokens: {total_tokens}, context window: {llm_config.context_window}), forcefully clearing message history."
            )
        new_in_context_messages, _ = self.summarizer.summarize(
            in_context_messages=in_context_messages, new_letta_messages=new_letta_messages, force=True, clear=True
        )
    else:
        new_in_context_messages, _ = self.summarizer.summarize(
            in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
        )
    await self.agent_manager.set_in_context_messages_async(
        agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
    )

    return new_in_context_messages
|
|
|
|
@trace_method
async def summarize_conversation_history(self) -> AgentState:
    """Force a summarization pass over the agent's current in-context messages.

    Loads the agent, fetches the messages referenced by its ``message_ids``, runs the
    summarizer with ``force=True`` and no new messages, and stores the resulting ids.

    Returns:
        The ``AgentState`` returned by ``set_in_context_messages_async``.
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
    current_messages = await self.message_manager.get_messages_by_ids_async(
        message_ids=agent_state.message_ids,
        actor=self.actor,
    )

    summarized_messages, _was_updated = self.summarizer.summarize(
        in_context_messages=current_messages,
        new_letta_messages=[],
        force=True,
    )
    summarized_ids = [msg.id for msg in summarized_messages]

    return await self.agent_manager.set_in_context_messages_async(
        agent_id=self.agent_id,
        message_ids=summarized_ids,
        actor=self.actor,
    )
|
|
|
|
@trace_method
async def _create_llm_request_data_async(
    self,
    llm_client: LLMClientBase,
    in_context_messages: List[Message],
    agent_state: AgentState,
    tool_rules_solver: ToolRulesSolver,
) -> dict:
    """Build the provider request payload for the current context.

    Steps:
        1. Fetch (once, then reuse the cached values on ``self``) the message and archival
           passage counts, and rebuild memory via ``_rebuild_memory_async``.
        2. Keep only tools of the supported ``ToolType`` values.
        3. Ask the tool-rules solver which tools are allowed; if it allows none, allow all.
        4. Force a tool call when exactly one tool is allowed.
        5. Apply strict mode and runtime schema overrides, then delegate to
           ``llm_client.build_request_data``.
    """
    self.num_messages, self.num_archival_memories = await asyncio.gather(
        (
            self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
            if self.num_messages is None
            # Already cached: an awaitable that immediately yields the cached value.
            else asyncio.sleep(0, result=self.num_messages)
        ),
        (
            self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
            if self.num_archival_memories is None
            else asyncio.sleep(0, result=self.num_archival_memories)
        ),
    )
    in_context_messages = await self._rebuild_memory_async(
        in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
    )

    supported_tool_types = {
        ToolType.CUSTOM,
        ToolType.LETTA_CORE,
        ToolType.LETTA_MEMORY_CORE,
        ToolType.LETTA_MULTI_AGENT_CORE,
        ToolType.LETTA_SLEEPTIME_CORE,
        ToolType.LETTA_VOICE_SLEEPTIME_CORE,
        ToolType.LETTA_BUILTIN,
        ToolType.EXTERNAL_COMPOSIO,
        ToolType.EXTERNAL_MCP,
    }
    tools = [t for t in agent_state.tools if t.tool_type in supported_tool_types]
    tool_names = {t.name for t in tools}

    # Mirror the sync agent loop: get allowed tools or allow all if none are allowed
    if self.last_function_response is None:
        self.last_function_response = self._load_last_function_response(in_context_messages)
    valid_tool_names = tool_rules_solver.get_allowed_tool_names(
        # Pass a fresh set so the solver can never mutate our fallback set.
        available_tools=set(tool_names),
        last_function_response=self.last_function_response,
    ) or list(tool_names)

    # TODO: Copied from legacy agent loop, so please be cautious
    # Set force tool
    force_tool_call = None
    if len(valid_tool_names) == 1:
        force_tool_call = valid_tool_names[0]

    # Build the membership set once instead of once per tool.
    valid_tool_name_set = set(valid_tool_names)
    allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in valid_tool_name_set]
    allowed_tools = runtime_override_tool_json_schema(
        tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
    )

    return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)
|
|
|
|
@trace_method
async def _handle_ai_response(
    self,
    tool_call: ToolCall,
    agent_state: AgentState,
    tool_rules_solver: ToolRulesSolver,
    usage: UsageStatistics,
    reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
    pre_computed_assistant_message_id: Optional[str] = None,
    pre_computed_tool_message_id: Optional[str] = None,
    step_id: str | None = None,
    new_in_context_messages: Optional[List[Message]] = None,
    agent_step_span: Optional["Span"] = None,
) -> Tuple[List[Message], bool]:
    """
    Execute the tool call produced by the LLM, persist the step and its messages, and decide whether to keep stepping.

    Args:
        tool_call: Tool call emitted by the model; ``function.arguments`` is expected to be a JSON string.
        agent_state: Agent whose tools and LLM config are used for execution and logging.
        tool_rules_solver: Records the call and applies terminal / child / continue tool rules.
        usage: LLM usage statistics logged with the step.
        reasoning_content: Optional reasoning/text content forwarded to message creation.
        pre_computed_assistant_message_id: Optional message ID forwarded to message creation.
        pre_computed_tool_message_id: Optional message ID forwarded to message creation.
        step_id: ID under which the step is logged.
        new_in_context_messages: Not used by this method.
        agent_step_span: Optional tracing span forwarded to ``_execute_tool``.

    Returns:
        A tuple ``(persisted_messages, continue_stepping)``.
    """
    tool_call_name = tool_call.function.name
    tool_call_args_str = tool_call.function.arguments

    # Decode the tool arguments. Some models double-encode them (a JSON string whose
    # content is itself a JSON object), so unwrap one extra string layer. Anything that
    # does not end up as a dict is treated as "no arguments" instead of crashing the step.
    try:
        tool_args = json.loads(tool_call_args_str) if tool_call_args_str else {}
        if isinstance(tool_args, str):
            tool_args = json.loads(tool_args)
    except json.JSONDecodeError:
        tool_args = {}
    if not isinstance(tool_args, dict):
        tool_args = {}

    # Get request heartbeats and coerce to bool
    request_heartbeat = tool_args.pop("request_heartbeat", False)
    # Pre-emptively pop out inner_thoughts
    tool_args.pop(INNER_THOUGHTS_KWARG, "")

    # So this is necessary, because sometimes non-structured outputs makes mistakes
    if not isinstance(request_heartbeat, bool):
        if isinstance(request_heartbeat, str):
            request_heartbeat = request_heartbeat.lower() == "true"
        else:
            request_heartbeat = bool(request_heartbeat)

    tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"

    log_telemetry(
        self.logger,
        "_handle_ai_response execute tool start",
        tool_name=tool_call_name,
        tool_args=tool_args,
        tool_call_id=tool_call_id,
        request_heartbeat=request_heartbeat,
    )

    tool_execution_result = await self._execute_tool(
        tool_name=tool_call_name,
        tool_args=tool_args,
        agent_state=agent_state,
        agent_step_span=agent_step_span,
    )
    log_telemetry(
        self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
    )

    if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
        # with certain functions we rely on the paging mechanism to handle overflow
        truncate = False
    else:
        # but by default, we add a truncation safeguard to prevent bad functions from
        # overflowing the agent context window
        truncate = True

    # get the function response limit
    target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None)
    if target_tool is None:
        # Unknown tool name (e.g. hallucinated by the model): _execute_tool returns a short
        # "not found" error result, so there is nothing to truncate and no limit to read.
        # NOTE(review): assumes validate_function_response ignores return_char_limit when truncate=False — confirm.
        return_char_limit = None
        truncate = False
    else:
        return_char_limit = target_tool.return_char_limit
    function_response_string = validate_function_response(
        tool_execution_result.func_return, return_char_limit=return_char_limit, truncate=truncate
    )
    function_response = package_function_response(
        was_success=tool_execution_result.success_flag,
        response_string=function_response_string,
    )

    # 4. Register tool call with tool rule solver
    # Resolve whether or not to continue stepping
    continue_stepping = request_heartbeat
    tool_rules_solver.register_tool_call(tool_name=tool_call_name)
    if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name):
        continue_stepping = False
    elif tool_rules_solver.has_children_tools(tool_name=tool_call_name):
        continue_stepping = True
    elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
        continue_stepping = True

    # 5a. Persist Steps to DB
    # Following agent loop to persist this before messages
    # TODO (cliandy): determine what should match old loop w/provider_id, job_id
    # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
    logged_step = await self.step_manager.log_step_async(
        actor=self.actor,
        agent_id=agent_state.id,
        provider_name=agent_state.llm_config.model_endpoint_type,
        provider_category=agent_state.llm_config.provider_category or "base",
        model=agent_state.llm_config.model,
        model_endpoint=agent_state.llm_config.model_endpoint,
        context_window_limit=agent_state.llm_config.context_window,
        usage=usage,
        provider_id=None,
        job_id=None,
        step_id=step_id,
    )

    # 5b. Persist Messages to DB
    tool_call_messages = create_letta_messages_from_llm_response(
        agent_id=agent_state.id,
        model=agent_state.llm_config.model,
        function_name=tool_call_name,
        function_arguments=tool_args,
        tool_execution_result=tool_execution_result,
        tool_call_id=tool_call_id,
        function_call_success=tool_execution_result.success_flag,
        function_response=function_response_string,
        actor=self.actor,
        add_heartbeat_request_system_message=continue_stepping,
        reasoning_content=reasoning_content,
        pre_computed_assistant_message_id=pre_computed_assistant_message_id,
        pre_computed_tool_message_id=pre_computed_tool_message_id,
        step_id=logged_step.id if logged_step else None,  # TODO (cliandy): eventually move over other agent loops
    )

    persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
    # Stores the packaged (JSON) response produced by package_function_response above.
    self.last_function_response = function_response

    return persisted_messages, continue_stepping
|
|
|
|
@trace_method
async def _execute_tool(
    self, tool_name: str, tool_args: dict, agent_state: AgentState, agent_step_span: Optional["Span"] = None
) -> "ToolExecutionResult":
    """
    Run one tool call for the agent and return its ToolExecutionResult.

    If no tool named ``tool_name`` is attached to ``agent_state``, an error result is
    returned and nothing is executed. When ``agent_step_span`` is provided, a
    "tool_execution_started" event is added before execution. A
    "tool_execution_completed" event is added afterwards, carrying the tool name,
    duration in ms, success flag, tool type and tool id.
    """
    from letta.schemas.tool_execution_result import ToolExecutionResult

    # Special memory case
    matching_tool = next((candidate for candidate in agent_state.tools if candidate.name == tool_name), None)
    if not matching_tool:
        # TODO: fix this error message
        return ToolExecutionResult(func_return=f"Tool {tool_name} not found", status="error")

    # TODO: This temp. Move this logic and code to executors
    started_ns = None
    if agent_step_span:
        started_ns = get_utc_timestamp_ns()
        agent_step_span.add_event(name="tool_execution_started")

    env_overrides = {env_var.key: env_var.value for env_var in agent_state.tool_exec_environment_variables}
    executor = ToolExecutionManager(
        agent_state=agent_state,
        message_manager=self.message_manager,
        agent_manager=self.agent_manager,
        block_manager=self.block_manager,
        passage_manager=self.passage_manager,
        sandbox_env_vars=env_overrides,
        actor=self.actor,
    )

    # TODO: Integrate sandbox result
    log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
    result = await executor.execute_tool_async(function_name=tool_name, function_args=tool_args, tool=matching_tool)

    if agent_step_span:
        elapsed_ms = (get_utc_timestamp_ns() - started_ns) // 1_000_000
        completion_attributes = {
            "tool_name": matching_tool.name,
            "duration_ms": elapsed_ms,
            "success": result.success_flag,
            "tool_type": matching_tool.tool_type,
            "tool_id": matching_tool.id,
        }
        agent_step_span.add_event(name="tool_execution_completed", attributes=completion_attributes)

    log_event(name=f"finish_{tool_name}_execution", attributes=result.model_dump())
    return result
|
|
|
|
@trace_method
def _load_last_function_response(self, in_context_messages: List[Message]):
    """
    Return the ``"message"`` field of the most recent tool response in ``in_context_messages``.

    Messages are scanned newest-first. Only tool-role messages with exactly one
    TextContent part are considered, and their text is parsed as JSON. A tool message
    whose JSON object has a missing or falsy "message" is skipped, and the scan moves
    on to older messages.

    NOTE(review): this returns the inner "message" value, whereas _handle_ai_response
    stores the full packaged response in self.last_function_response — confirm that
    consumers of last_function_response accept both forms.

    Returns:
        The first truthy "message" value found, or None if there is none.

    Raises:
        ValueError: if a candidate tool message's text is not a JSON object.
    """
    for msg in reversed(in_context_messages):
        is_single_text_tool_msg = (
            msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent)
        )
        if not is_single_text_tool_msg:
            continue

        text_content = msg.content[0].text
        try:
            response_json = json.loads(text_content)
        except json.JSONDecodeError as err:
            raise ValueError(f"Invalid JSON format in message: {text_content}") from err

        # Valid JSON that is not an object (e.g. "42" or "[...]") has no .get(); report it
        # with the same ValueError rather than letting an AttributeError escape.
        if not isinstance(response_json, dict):
            raise ValueError(f"Invalid JSON format in message: {text_content}")

        if response_json.get("message"):
            return response_json["message"]
    return None
|