feat: add performance tracing to new agent loop (#2534)
@@ -49,7 +49,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
 from letta.settings import model_settings
 from letta.system import package_function_response
 from letta.tracing import log_event, trace_method, tracer
-from letta.utils import validate_function_response
+from letta.utils import log_telemetry, validate_function_response
 
 logger = get_logger(__name__)
 
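The new `log_telemetry` import is the structured-logging counterpart to the spans added below. Its real implementation lives in `letta.utils`; a minimal sketch of the call shape used later in this diff (logger, event label, keyword context), purely as an assumption about what it does:

```python
import logging
from typing import Any

# Hypothetical sketch of the log_telemetry call shape used in this diff;
# the real letta.utils implementation may differ.
def log_telemetry(logger: logging.Logger, event: str, **kwargs: Any) -> None:
    # Render keyword context as stable key=value pairs after the event label.
    context = " ".join(f"{k}={v!r}" for k, v in sorted(kwargs.items()))
    logger.info("TELEMETRY %s %s", event, context)
```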
@@ -112,17 +112,34 @@ class LettaAgent(BaseAgent):
         )
 
     @trace_method
-    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
+    async def step(
+        self,
+        input_messages: List[MessageCreate],
+        max_steps: int = 10,
+        use_assistant_message: bool = True,
+        request_start_timestamp_ns: Optional[int] = None,
+    ) -> LettaResponse:
         agent_state = await self.agent_manager.get_agent_by_id_async(
             agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
         )
-        _, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
+        _, new_in_context_messages, usage = await self._step(
+            agent_state=agent_state,
+            input_messages=input_messages,
+            max_steps=max_steps,
+            request_start_timestamp_ns=request_start_timestamp_ns,
+        )
         return _create_letta_response(
             new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
         )
 
     @trace_method
-    async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
+    async def step_stream_no_tokens(
+        self,
+        input_messages: List[MessageCreate],
+        max_steps: int = 10,
+        use_assistant_message: bool = True,
+        request_start_timestamp_ns: Optional[int] = None,
+    ):
         agent_state = await self.agent_manager.get_agent_by_id_async(
             agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
         )
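With the signature change, callers opt into end-to-end timing by capturing a nanosecond timestamp at the request boundary and threading it down. A hedged usage sketch; `agent` and `messages` are hypothetical stand-ins, and `time.time_ns()` stands in for Letta's `get_utc_timestamp_ns()`:

```python
import time

# Hypothetical caller: agent is a constructed LettaAgent and messages a
# List[MessageCreate]; both are elided here.
async def handle_request(agent, messages):
    return await agent.step(
        messages,
        max_steps=10,
        use_assistant_message=True,
        request_start_timestamp_ns=time.time_ns(),  # captured at the boundary
    )
```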
@@ -136,8 +153,16 @@ class LettaAgent(BaseAgent):
             actor=self.actor,
         )
         usage = LettaUsageStatistics()
 
+        # span for request
+        request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
+        request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
+
         for _ in range(max_steps):
+            step_id = generate_step_id()
+            step_start = get_utc_timestamp_ns()
+            agent_step_span = tracer.start_span("agent_step", start_time=step_start)
+            agent_step_span.set_attributes({"step_id": step_id})
 
             request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
                 current_in_context_messages,
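The span creation above leans on OpenTelemetry's support for backdating a span to a caller-supplied nanosecond timestamp, so `time_to_first_token` covers work done before the agent loop was reached. A self-contained sketch of that pattern; the config values here are illustrative, not Letta's:

```python
import time
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

request_start_timestamp_ns = time.time_ns()  # captured at the HTTP boundary
# ... request parsing and agent setup would happen here ...

# Backdate the span so it includes work done before this point.
span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
llm_config = {"model": "gpt-4", "temperature": 0.7, "max_tokens": None}
# Flatten config into span attributes, skipping unset values (as the diff does).
span.set_attributes({f"llm_config.{k}": v for k, v in llm_config.items() if v is not None})
span.end()
```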
@@ -150,6 +175,11 @@ class LettaAgent(BaseAgent):
 
             log_event("agent.stream_no_tokens.llm_response.received")  # [3^]
 
+            # log llm request time
+            now = get_utc_timestamp_ns()
+            llm_request_ns = now - step_start
+            agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
+
             response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
 
             # update usage
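The timing bookkeeping here is a recurring pattern throughout this diff: snapshot a nanosecond clock, subtract, and attach the delta to the step span as an event. A hypothetical helper capturing the same idiom:

```python
import time
from opentelemetry.trace import Span

# Hypothetical helper mirroring the inline pattern above; not part of the diff.
def record_duration_ms(span: Span, name: str, start_ns: int) -> None:
    elapsed_ns = time.time_ns() - start_ns
    # Integer division truncates: a 0.9 ms operation is recorded as duration_ms=0.
    span.add_event(name=name, attributes={"duration_ms": elapsed_ns // 1_000_000})
```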
@@ -177,17 +207,29 @@ class LettaAgent(BaseAgent):
                 logger.info("No reasoning content found.")
                 reasoning = None
 
+            # log LLM request time
+            now = get_utc_timestamp_ns()
+            llm_request_ns = now - step_start
+            agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
+
             persisted_messages, should_continue = await self._handle_ai_response(
                 tool_call,
                 agent_state,
                 tool_rules_solver,
                 response.usage,
                 reasoning_content=reasoning,
+                agent_step_span=agent_step_span,
             )
             self.response_messages.extend(persisted_messages)
             new_in_context_messages.extend(persisted_messages)
             log_event("agent.stream_no_tokens.llm_response.processed")  # [4^]
 
+            # log step time
+            now = get_utc_timestamp_ns()
+            step_ns = now - step_start
+            agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
+            agent_step_span.end()
+
             # Log LLM Trace
             await self.telemetry_manager.create_provider_trace_async(
                 actor=self.actor,
@@ -221,12 +263,23 @@ class LettaAgent(BaseAgent):
                 force=False,
             )
 
+        # log request time
+        if request_start_timestamp_ns:
+            now = get_utc_timestamp_ns()
+            request_ns = now - request_start_timestamp_ns
+            request_span.add_event(name="letta_request_ms", attributes={"duration_ms": request_ns // 1_000_000})
+            request_span.end()
+
         # Return back usage
         yield f"data: {usage.model_dump_json()}\n\n"
         yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
 
     async def _step(
-        self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
+        self,
+        agent_state: AgentState,
+        input_messages: List[MessageCreate],
+        max_steps: int = 10,
+        request_start_timestamp_ns: Optional[int] = None,
     ) -> Tuple[List[Message], List[Message], LettaUsageStatistics]:
         """
         Carries out an invocation of the agent loop. In each step, the agent
@@ -244,9 +297,18 @@ class LettaAgent(BaseAgent):
             put_inner_thoughts_first=True,
             actor=self.actor,
         )
 
+        # span for request
+        request_span = tracer.start_span("time_to_first_token")
+        request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
+
         usage = LettaUsageStatistics()
         for _ in range(max_steps):
+            step_id = generate_step_id()
+            step_start = get_utc_timestamp_ns()
+            agent_step_span = tracer.start_span("agent_step", start_time=step_start)
+            agent_step_span.set_attributes({"step_id": step_id})
+
             request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
                 current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
             )
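Each iteration gets its own `agent_step_span` tagged with a fresh `step_id`, so one request fans out into per-step spans. A generic, runnable sketch of that loop shape; `do_step` is hypothetical, and unlike this diff the sketch closes the span in a `finally` so an exception cannot leak it:

```python
import time
import uuid
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

def do_step() -> bool:
    # Hypothetical unit of agent work; False means stop stepping.
    return False

def run_steps(max_steps: int = 10) -> None:
    for _ in range(max_steps):
        step_id = f"step-{uuid.uuid4().hex[:8]}"  # stand-in for generate_step_id()
        step_start = time.time_ns()
        step_span = tracer.start_span("agent_step", start_time=step_start)
        step_span.set_attributes({"step_id": step_id})
        try:
            should_continue = do_step()
        finally:
            # Record the wall-clock cost of the iteration before closing the span.
            step_span.add_event(
                name="step_ms",
                attributes={"duration_ms": (time.time_ns() - step_start) // 1_000_000},
            )
            step_span.end()
        if not should_continue:
            break

run_steps()
```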
@@ -256,6 +318,11 @@ class LettaAgent(BaseAgent):
 
             response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
 
+            # log LLM request time
+            now = get_utc_timestamp_ns()
+            llm_request_ns = now - step_start
+            agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
+
             # TODO: add run_id
             usage.step_count += 1
             usage.completion_tokens += response.usage.completion_tokens
@@ -287,11 +354,18 @@ class LettaAgent(BaseAgent):
                 response.usage,
                 reasoning_content=reasoning,
                 step_id=step_id,
+                agent_step_span=agent_step_span,
             )
             self.response_messages.extend(persisted_messages)
             new_in_context_messages.extend(persisted_messages)
             log_event("agent.step.llm_response.processed")  # [4^]
 
+            # log step time
+            now = get_utc_timestamp_ns()
+            step_ns = now - step_start
+            agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
+            agent_step_span.end()
+
             # Log LLM Trace
             await self.telemetry_manager.create_provider_trace_async(
                 actor=self.actor,
@@ -306,6 +380,13 @@ class LettaAgent(BaseAgent):
             if not should_continue:
                 break
 
+        # log request time
+        if request_start_timestamp_ns:
+            now = get_utc_timestamp_ns()
+            request_ns = now - request_start_timestamp_ns
+            request_span.add_event(name="request_ms", attributes={"duration_ms": request_ns // 1_000_000})
+            request_span.end()
+
         # Extend the in context message ids
         if not agent_state.message_buffer_autoclear:
             await self._rebuild_context_window(
@@ -353,17 +434,21 @@ class LettaAgent(BaseAgent):
             actor=self.actor,
         )
         usage = LettaUsageStatistics()
-        first_chunk, ttft_span = True, None
-        if request_start_timestamp_ns is not None:
-            ttft_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
-            ttft_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
+        first_chunk, request_span = True, None
+        if request_start_timestamp_ns:
+            request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
+            request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
 
         provider_request_start_timestamp_ns = None
         for _ in range(max_steps):
+            step_id = generate_step_id()
+            step_start = get_utc_timestamp_ns()
+            agent_step_span = tracer.start_span("agent_step", start_time=step_start)
+            agent_step_span.set_attributes({"step_id": step_id})
 
             request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming(
                 first_chunk,
-                ttft_span,
+                agent_step_span,
                 request_start_timestamp_ns,
                 current_in_context_messages,
                 new_in_context_messages,
@@ -389,14 +474,13 @@ class LettaAgent(BaseAgent):
                 raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
 
             async for chunk in interface.process(
-                stream, ttft_span=ttft_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
+                stream, ttft_span=request_span, provider_request_start_timestamp_ns=provider_request_start_timestamp_ns
             ):
                 # Measure time to first token
-                if first_chunk and ttft_span is not None:
+                if first_chunk and request_span is not None:
                     now = get_utc_timestamp_ns()
                     ttft_ns = now - request_start_timestamp_ns
-                    ttft_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
-                    ttft_span.end()
+                    request_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
                     first_chunk = False
 
                 yield f"data: {chunk.model_dump_json()}\n\n"
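The renamed `request_span` now outlives the first chunk: the `time_to_first_token_ms` event is recorded once, but the span stays open until the whole request completes (the old `ttft_span.end()` is gone). A minimal sketch of first-chunk timing over an async stream; all names here are illustrative:

```python
import asyncio
import time

async def fake_stream():
    # Hypothetical stand-in for the LLM chunk stream.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield {"chunk": i}

async def consume(request_start_ns: int) -> None:
    first_chunk = True
    async for chunk in fake_stream():
        if first_chunk:
            # In the diff this is request_span.add_event(...); the span stays open.
            ttft_ns = time.time_ns() - request_start_ns
            print(f"ttft_ms={ttft_ns // 1_000_000}")
            first_chunk = False
        # ... forward the chunk to the client here ...

asyncio.run(consume(time.time_ns()))
```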
@@ -413,6 +497,11 @@ class LettaAgent(BaseAgent):
                 await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
                 persisted_input_messages = True
 
+            # log LLM request time
+            now = get_utc_timestamp_ns()
+            llm_request_ns = now - step_start
+            agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ns // 1_000_000})
+
             # Process resulting stream content
             tool_call = interface.get_tool_call_object()
             reasoning_content = interface.get_reasoning_content()
@@ -429,10 +518,17 @@ class LettaAgent(BaseAgent):
                 pre_computed_assistant_message_id=interface.letta_assistant_message_id,
                 pre_computed_tool_message_id=interface.letta_tool_message_id,
                 step_id=step_id,
+                agent_step_span=agent_step_span,
             )
             self.response_messages.extend(persisted_messages)
             new_in_context_messages.extend(persisted_messages)
 
+            # log total step time
+            now = get_utc_timestamp_ns()
+            step_ns = now - step_start
+            agent_step_span.add_event(name="step_ms", attributes={"duration_ms": step_ns // 1_000_000})
+            agent_step_span.end()
+
             # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
             # log_event("agent.stream.llm_response.processed")  # [4^]
 
@@ -477,6 +573,13 @@ class LettaAgent(BaseAgent):
                 force=False,
             )
 
+        # log time of entire request
+        if request_start_timestamp_ns:
+            now = get_utc_timestamp_ns()
+            request_ns = now - request_start_timestamp_ns
+            request_span.add_event(name="letta_request_ms", attributes={"duration_ms": request_ns // 1_000_000})
+            request_span.end()
+
         # TODO: Also yield out a letta usage stats SSE
         yield f"data: {usage.model_dump_json()}\n\n"
         yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
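Both terminal yields use Server-Sent Events framing: a `data:` field terminated by a blank line. A sketch with plain dicts standing in for the Pydantic models; the real done marker is whatever `MessageStreamStatus.done` serializes to:

```python
import json
from typing import Any, Iterator

def sse_event(payload: Any) -> str:
    # One SSE event: a "data:" field followed by a blank line.
    return f"data: {json.dumps(payload)}\n\n"

def finish_stream(usage: dict) -> Iterator[str]:
    # Mirrors the two terminal yields: usage stats, then a done marker.
    yield sse_event(usage)
    yield sse_event("done")  # illustrative stand-in for MessageStreamStatus.done
```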
@@ -710,6 +813,7 @@ class LettaAgent(BaseAgent):
         pre_computed_tool_message_id: Optional[str] = None,
         step_id: str | None = None,
         new_in_context_messages: Optional[List[Message]] = None,
+        agent_step_span: Optional["Span"] = None,
     ) -> Tuple[List[Message], bool]:
         """
         Now that streaming is done, handle the final AI response.
@@ -741,10 +845,23 @@ class LettaAgent(BaseAgent):
 
         tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
 
+        log_telemetry(
+            self.logger,
+            "_handle_ai_response execute tool start",
+            tool_name=tool_call_name,
+            tool_args=tool_args,
+            tool_call_id=tool_call_id,
+            request_heartbeat=request_heartbeat,
+        )
+
         tool_execution_result = await self._execute_tool(
             tool_name=tool_call_name,
             tool_args=tool_args,
             agent_state=agent_state,
+            agent_step_span=agent_step_span,
         )
+        log_telemetry(
+            self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
+        )
 
         if tool_call_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
@@ -819,7 +936,9 @@ class LettaAgent(BaseAgent):
         return persisted_messages, continue_stepping
 
     @trace_method
-    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult":
+    async def _execute_tool(
+        self, tool_name: str, tool_args: dict, agent_state: AgentState, agent_step_span: Optional["Span"] = None
+    ) -> "ToolExecutionResult":
         """
         Executes a tool and returns (result, success_flag).
         """
@@ -835,6 +954,11 @@ class LettaAgent(BaseAgent):
         )
 
         # TODO: This temp. Move this logic and code to executors
+
+        if agent_step_span:
+            start_time = get_utc_timestamp_ns()
+            agent_step_span.add_event(name="tool_execution_started")
+
         sandbox_env_vars = {var.key: var.value for var in agent_state.tool_exec_environment_variables}
         tool_execution_manager = ToolExecutionManager(
             agent_state=agent_state,
@@ -850,7 +974,19 @@ class LettaAgent(BaseAgent):
         tool_execution_result = await tool_execution_manager.execute_tool_async(
             function_name=tool_name, function_args=tool_args, tool=target_tool
         )
-        log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
+        if agent_step_span:
+            end_time = get_utc_timestamp_ns()
+            agent_step_span.add_event(
+                name="tool_execution_completed",
+                attributes={
+                    "tool_name": target_tool.name,
+                    "duration_ms": (end_time - start_time) // 1_000_000,
+                    "success": tool_execution_result.success_flag,
+                    "tool_type": target_tool.tool_type,
+                    "tool_id": target_tool.id,
+                },
+            )
+        log_event(name=f"finish_{tool_name}_execution", attributes=tool_execution_result.model_dump())
        return tool_execution_result
 
     @trace_method
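`_execute_tool` now brackets the sandboxed call with a `tool_execution_started` event and a `tool_execution_completed` event carrying duration and outcome. The same bracketing in isolation, with a hypothetical `run_tool` in place of `ToolExecutionManager.execute_tool_async`:

```python
import time
from opentelemetry.trace import Span

def run_tool(name: str, args: dict) -> bool:
    # Hypothetical tool body; returns a success flag.
    return True

def timed_tool_execution(span: Span, name: str, args: dict) -> bool:
    start_ns = time.time_ns()
    span.add_event(name="tool_execution_started")
    success = run_tool(name, args)
    span.add_event(
        name="tool_execution_completed",
        attributes={
            "tool_name": name,
            "duration_ms": (time.time_ns() - start_ns) // 1_000_000,
            "success": success,
        },
    )
    return success
```

Recording events on the existing step span, rather than opening a child span per tool call, keeps the trace shallow while still capturing per-tool duration and outcome.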
@@ -1,4 +1,4 @@
-from typing import AsyncGenerator, List, Tuple, Union
+from typing import AsyncGenerator, List, Optional, Tuple, Union
 
 from letta.agents.helpers import _create_letta_response, serialize_message_history
 from letta.agents.letta_agent import LettaAgent
@@ -89,7 +89,7 @@ class VoiceSleeptimeAgent(LettaAgent):
         )
 
     @trace_method
-    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState):
+    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState, agent_step_span: Optional["Span"] = None):
         """
         Executes a tool and returns (result, success_flag).
         """
@@ -655,6 +655,7 @@ async def send_message(
     This endpoint accepts a message from a user and processes it through the agent.
     """
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
+    request_start_timestamp_ns = get_utc_timestamp_ns()
     # user_eligible = actor.organization_id not in ["org-4a3af5dd-4c6a-48cb-ac13-3f73ecaaa4bf", "org-4ab3f6e8-9a44-4bee-aeb6-c681cbbc7bf6"]
     # TODO: This is redundant, remove soon
     agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
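The timestamp is captured at the top of the route handler, before the agent is even loaded, so setup cost counts toward `time_to_first_token`. A minimal FastAPI-flavored sketch; the route path and handler body are illustrative, not Letta's:

```python
import time
from fastapi import FastAPI

app = FastAPI()

@app.post("/agents/{agent_id}/messages")
async def send_message(agent_id: str) -> dict:
    # Capture first, before any agent loading or request parsing.
    request_start_timestamp_ns = time.time_ns()
    # ... load the agent, then thread the timestamp down, e.g.:
    # result = await agent.step(..., request_start_timestamp_ns=request_start_timestamp_ns)
    return {"ok": True}
```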
@@ -688,7 +689,12 @@ async def send_message(
             telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
         )
 
-        result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
+        result = await experimental_agent.step(
+            request.messages,
+            max_steps=10,
+            use_assistant_message=request.use_assistant_message,
+            request_start_timestamp_ns=request_start_timestamp_ns,
+        )
     else:
         result = await server.send_message_to_agent(
             agent_id=agent_id,
@@ -740,6 +746,7 @@ async def send_message_streaming(
     model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
     model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
     not_letta_endpoint = not ("inference.letta.com" in agent.llm_config.model_endpoint)
+    request_start_timestamp_ns = get_utc_timestamp_ns()
 
     if agent_eligible and feature_enabled and model_compatible:
         if agent.enable_sleeptime:
@@ -782,7 +789,10 @@ async def send_message_streaming(
         else:
             result = StreamingResponseWithStatusCode(
                 experimental_agent.step_stream_no_tokens(
-                    request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
+                    request.messages,
+                    max_steps=10,
+                    use_assistant_message=request.use_assistant_message,
+                    request_start_timestamp_ns=request_start_timestamp_ns,
                 ),
                 media_type="text/event-stream",
             )
@@ -815,6 +825,7 @@ async def process_message_background(
 ) -> None:
     """Background task to process the message and update job status."""
     try:
+        request_start_timestamp_ns = get_utc_timestamp_ns()
         result = await server.send_message_to_agent(
             agent_id=agent_id,
             actor=actor,
@@ -825,6 +836,7 @@ async def process_message_background(
             assistant_message_tool_name=assistant_message_tool_name,
             assistant_message_tool_kwarg=assistant_message_tool_kwarg,
             metadata={"job_id": job_id},  # Pass job_id through metadata
+            request_start_timestamp_ns=request_start_timestamp_ns,
         )
 
         # Update job status to completed