feat: add performance tracing to new agent loop (#2534)

This commit is contained in:
Sarah Wooders
2025-05-30 12:49:46 -07:00
committed by GitHub
parent e19955db6f
commit fe829a71bf
3 changed files with 168 additions and 20 deletions

View File

@@ -49,7 +49,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
from letta.settings import model_settings
from letta.system import package_function_response
from letta.tracing import log_event, trace_method, tracer
from letta.utils import validate_function_response
from letta.utils import log_telemetry, validate_function_response
logger = get_logger(__name__)
@@ -112,17 +112,34 @@ class LettaAgent(BaseAgent):
)
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
async def step(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
) -> LettaResponse:
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
)
_, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
_, new_in_context_messages, usage = await self._step(
agent_state=agent_state,
input_messages=input_messages,
max_steps=max_steps,
request_start_timestamp_ns=request_start_timestamp_ns,
)
return _create_letta_response(
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
)
@trace_method
async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
async def step_stream_no_tokens(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
):
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
)
@@ -136,8 +153,16 @@ class LettaAgent(BaseAgent):
actor=self.actor,
)
usage = LettaUsageStatistics()
# span for request
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for _ in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
current_in_context_messages,
@@ -150,6 +175,11 @@ class LettaAgent(BaseAgent):
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
# log llm request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
@@ -177,17 +207,29 @@ class LettaAgent(BaseAgent):
logger.info("No reasoning content found.")
reasoning = None
# log LLM request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
agent_state,
tool_rules_solver,
response.usage,
reasoning_content=reasoning,
agent_step_span=agent_step_span,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
@@ -221,12 +263,23 @@ class LettaAgent(BaseAgent):
force=False,
)
# log request time
if request_start_timestamp_ns:
now = get_utc_timestamp_ns()
request_ns = now - request_start_timestamp_ns
request_span.add_event(name="letta_request_ms", attributes={"duration_ms": request_ns // 1_000_000})
request_span.end()
# Emit the accumulated usage statistics
yield f"data: {usage.model_dump_json()}\n\n"
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
self,
agent_state: AgentState,
input_messages: List[MessageCreate],
max_steps: int = 10,
request_start_timestamp_ns: Optional[int] = None,
) -> Tuple[List[Message], List[Message], LettaUsageStatistics]:
"""
Carries out an invocation of the agent loop. In each step, the agent
@@ -244,9 +297,18 @@ class LettaAgent(BaseAgent):
put_inner_thoughts_first=True,
actor=self.actor,
)
# span for request
request_span = tracer.start_span("time_to_first_token")
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
usage = LettaUsageStatistics()
for _ in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
)
@@ -256,6 +318,11 @@ class LettaAgent(BaseAgent):
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# log LLM request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
@@ -287,11 +354,18 @@ class LettaAgent(BaseAgent):
response.usage,
reasoning_content=reasoning,
step_id=step_id,
agent_step_span=agent_step_span,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.step.llm_response.processed") # [4^]
# log step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
@@ -306,6 +380,13 @@ class LettaAgent(BaseAgent):
if not should_continue:
break
# log request time
if request_start_timestamp_ns:
now = get_utc_timestamp_ns()
request_ns = now - request_start_timestamp_ns
request_span.add_event(name="request_ms", attributes={"duration_ms": request_ns // 1_000_000})
request_span.end()
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
await self._rebuild_context_window(
@@ -353,17 +434,21 @@ class LettaAgent(BaseAgent):
actor=self.actor,
)
usage = LettaUsageStatistics()
first_chunk, ttft_span = True, None
if request_start_timestamp_ns is not None:
ttft_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
ttft_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
first_chunk, request_span = True, None
if request_start_timestamp_ns:
request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
provider_request_start_timestamp_ns = None
for _ in range(max_steps):
step_id = generate_step_id()
step_start = get_utc_timestamp_ns()
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming(
first_chunk,
ttft_span,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
@@ -389,14 +474,13 @@ class LettaAgent(BaseAgent):
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
async for chunk in interface.process(
stream, ttft_span=ttft_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
stream, ttft_span=request_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
):
# Measure time to first token
if first_chunk and ttft_span is not None:
if first_chunk and request_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
ttft_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
ttft_span.end()
request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
first_chunk = False
yield f"data: {chunk.model_dump_json()}\n\n"
@@ -413,6 +497,11 @@ class LettaAgent(BaseAgent):
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
persisted_input_messages = True
# log LLM request time
now = get_utc_timestamp_ns()
llm_request_ns = now - step_start
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
# Process resulting stream content
tool_call = interface.get_tool_call_object()
reasoning_content = interface.get_reasoning_content()
@@ -429,10 +518,17 @@ class LettaAgent(BaseAgent):
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
pre_computed_tool_message_id=interface.letta_tool_message_id,
step_id=step_id,
agent_step_span=agent_step_span,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
# log total step time
now = get_utc_timestamp_ns()
step_ns = now - step_start
agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
agent_step_span.end()
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
@@ -477,6 +573,13 @@ class LettaAgent(BaseAgent):
force=False,
)
# log time of entire request
if request_start_timestamp_ns:
now = get_utc_timestamp_ns()
request_ns = now - request_start_timestamp_ns
request_span.add_event(name="letta_request_ms", attributes={"duration_ms": request_ns // 1_000_000})
request_span.end()
# TODO: Also yield out a letta usage stats SSE
yield f"data: {usage.model_dump_json()}\n\n"
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@@ -710,6 +813,7 @@ class LettaAgent(BaseAgent):
pre_computed_tool_message_id: Optional[str] = None,
step_id: str | None = None,
new_in_context_messages: Optional[List[Message]] = None,
agent_step_span: Optional["Span"] = None,
) -> Tuple[List[Message], bool]:
"""
Now that streaming is done, handle the final AI response.
@@ -741,10 +845,23 @@ class LettaAgent(BaseAgent):
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
log_telemetry(
self.logger,
"_handle_ai_response execute tool start",
tool_name=tool_call_name,
tool_args=tool_args,
tool_call_id=tool_call_id,
request_heartbeat=request_heartbeat,
)
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
agent_state=agent_state,
agent_step_span=agent_step_span,
)
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
)
if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
@@ -819,7 +936,9 @@ class LettaAgent(BaseAgent):
return persisted_messages, continue_stepping
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult":
async def _execute_tool(
self, tool_name: str, tool_args: dict, agent_state: AgentState, agent_step_span: Optional["Span"] = None
) -> "ToolExecutionResult":
"""
Executes a tool and returns its ToolExecutionResult.
"""
@@ -835,6 +954,11 @@ class LettaAgent(BaseAgent):
)
# TODO: This is temporary. Move this logic and code to the executors.
if agent_step_span:
start_time = get_utc_timestamp_ns()
agent_step_span.add_event(name="tool_execution_started")
sandbox_env_vars = {var.key: var.value for var in agent_state.tool_exec_environment_variables}
tool_execution_manager = ToolExecutionManager(
agent_state=agent_state,
@@ -850,7 +974,19 @@ class LettaAgent(BaseAgent):
tool_execution_result = await tool_execution_manager.execute_tool_async(
function_name=tool_name, function_args=tool_args, tool=target_tool
)
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
if agent_step_span:
end_time = get_utc_timestamp_ns()
agent_step_span.add_event(
name="tool_execution_completed",
attributes={
"tool_name": target_tool.name,
"duration_ms": (end_time - start_time) // 1_000_000,
"success": tool_execution_result.success_flag,
"tool_type": target_tool.tool_type,
"tool_id": target_tool.id,
},
)
log_event(name=f"finish_{tool_name}_execution", attributes=tool_execution_result.model_dump())
return tool_execution_result
@trace_method

View File

@@ -1,4 +1,4 @@
from typing import AsyncGenerator, List, Tuple, Union
from typing import AsyncGenerator, List, Optional, Tuple, Union
from letta.agents.helpers import _create_letta_response, serialize_message_history
from letta.agents.letta_agent import LettaAgent
@@ -89,7 +89,7 @@ class VoiceSleeptimeAgent(LettaAgent):
)
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState):
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState, agent_step_span: Optional["Span"] = None):
"""
Executes a tool and returns (result, success_flag).
"""

View File

@@ -655,6 +655,7 @@ async def send_message(
This endpoint accepts a message from a user and processes it through the agent.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
request_start_timestamp_ns = get_utc_timestamp_ns()
# user_eligible = actor.organization_id not in ["org-4a3af5dd-4c6a-48cb-ac13-3f73ecaaa4bf", "org-4ab3f6e8-9a44-4bee-aeb6-c681cbbc7bf6"]
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
@@ -688,7 +689,12 @@ async def send_message(
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
result = await experimental_agent.step(
request.messages,
max_steps=10,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,
@@ -740,6 +746,7 @@ async def send_message_streaming(
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
not_letta_endpoint = not ("inference.letta.com" in agent.llm_config.model_endpoint)
request_start_timestamp_ns = get_utc_timestamp_ns()
if agent_eligible and feature_enabled and model_compatible:
if agent.enable_sleeptime:
@@ -782,7 +789,10 @@ async def send_message_streaming(
else:
result = StreamingResponseWithStatusCode(
experimental_agent.step_stream_no_tokens(
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
request.messages,
max_steps=10,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
),
media_type="text/event-stream",
)
@@ -815,6 +825,7 @@ async def process_message_background(
) -> None:
"""Background task to process the message and update job status."""
try:
request_start_timestamp_ns = get_utc_timestamp_ns()
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
@@ -825,6 +836,7 @@ async def process_message_background(
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
metadata={"job_id": job_id}, # Pass job_id through metadata
request_start_timestamp_ns=request_start_timestamp_ns,
)
# Update job status to completed