chore: add tracing to new agent loop

Author: Caren Thomas
Date: 2025-09-09 13:14:30 -07:00
parent a5b34bca78
commit b95c68414b
2 changed files with 26 additions and 0 deletions
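
The change is mechanical: a `@trace_method` decorator is applied to every step of the agent loop in `LettaAgentV2` and its `SleeptimeMultiAgentV3` subclass (plus the matching import in the second file), so each phase of a run emits a span. As a rough sketch of what a decorator like this typically does (an illustrative assumption, not the actual implementation in `letta.otel.tracing`), it opens an OpenTelemetry span named after the method and must handle both sync and async callables, since the diff applies it to both:

# Illustrative sketch only: Letta's real trace_method lives in
# letta.otel.tracing and may differ. Assumes an OpenTelemetry tracer.
import functools
import inspect

from opentelemetry import trace

tracer = trace.get_tracer(__name__)

def trace_method(func):
    """Run the wrapped method inside a span named after it."""
    if inspect.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            with tracer.start_as_current_span(func.__qualname__):
                return await func(*args, **kwargs)
        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        with tracer.start_as_current_span(func.__qualname__):
            return func(*args, **kwargs)
    return sync_wrapper

# NOTE: stream() below is an async generator, so the real decorator would also
# need an inspect.isasyncgenfunction branch that yields from inside the span.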

letta/agents/letta_agent_v2.py

@@ -119,6 +119,7 @@ class LettaAgentV2(BaseAgentV2):
             agent_id=self.agent_state.id,
         )
 
+    @trace_method
     async def build_request(self, input_messages: list[MessageCreate]) -> dict:
         """
         Build the request data for an LLM call without actually executing it.
@@ -146,6 +147,7 @@ class LettaAgentV2(BaseAgentV2):
         return request
 
+    @trace_method
     async def step(
         self,
         input_messages: list[MessageCreate],
@@ -210,6 +212,7 @@ class LettaAgentV2(BaseAgentV2):
         self._request_checkpoint_finish(request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns)
         return LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
 
+    @trace_method
     async def stream(
         self,
         input_messages: list[MessageCreate],
@@ -298,6 +301,7 @@ class LettaAgentV2(BaseAgentV2):
         for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
             yield f"data: {finish_chunk}\n\n"
 
+    @trace_method
     async def _step(
         self,
         messages: list[Message],
@@ -588,6 +592,7 @@ class LettaAgentV2(BaseAgentV2):
             return maybe_approval_request, maybe_approval_response
         return None, None
 
+    @trace_method
     async def _check_run_cancellation(self, run_id) -> bool:
         try:
             job = await self.job_manager.get_job_by_id_async(job_id=run_id, actor=self.actor)
@@ -597,6 +602,7 @@ class LettaAgentV2(BaseAgentV2):
             self.logger.warning(f"Failed to check job cancellation status for job {run_id}: {e}")
             return False
 
+    @trace_method
     async def _refresh_messages(self, in_context_messages: list[Message]):
         num_messages = await self.message_manager.size_async(
             agent_id=self.agent_state.id,
@@ -614,6 +620,7 @@ class LettaAgentV2(BaseAgentV2):
         in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, self.agent_state.llm_config)
         return in_context_messages
 
+    @trace_method
     async def _rebuild_memory(
         self,
         in_context_messages: list[Message],
@@ -702,6 +709,7 @@ class LettaAgentV2(BaseAgentV2):
         else:
             return in_context_messages
 
+    @trace_method
     async def _get_valid_tools(self, in_context_messages: list[Message]):
         tools = self.agent_state.tools
         self.last_function_response = self._load_last_function_response(in_context_messages)
@@ -720,6 +728,7 @@ class LettaAgentV2(BaseAgentV2):
         )
         return allowed_tools
 
+    @trace_method
     def _load_last_function_response(self, in_context_messages: list[Message]):
         """Load the last function response from message history"""
         for msg in reversed(in_context_messages):
@@ -733,6 +742,7 @@ class LettaAgentV2(BaseAgentV2):
                     raise ValueError(f"Invalid JSON format in message: {text_content}")
         return None
 
+    @trace_method
     def _request_checkpoint_start(self, request_start_timestamp_ns: int | None) -> Span | None:
         if request_start_timestamp_ns is not None:
             request_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
@@ -742,6 +752,7 @@ class LettaAgentV2(BaseAgentV2):
             return request_span
         return None
 
+    @trace_method
     def _request_checkpoint_ttft(self, request_span: Span | None, request_start_timestamp_ns: int | None) -> Span | None:
         if request_span:
             ttft_ns = get_utc_timestamp_ns() - request_start_timestamp_ns
@@ -749,6 +760,7 @@ class LettaAgentV2(BaseAgentV2):
             return request_span
         return None
 
+    @trace_method
     def _request_checkpoint_finish(self, request_span: Span | None, request_start_timestamp_ns: int | None) -> None:
         if request_span is not None:
             duration_ns = get_utc_timestamp_ns() - request_start_timestamp_ns
@@ -756,6 +768,7 @@ class LettaAgentV2(BaseAgentV2):
             request_span.end()
         return None
 
+    @trace_method
     def _step_checkpoint_start(self, step_id: str) -> Tuple[StepProgression, StepMetrics, Span]:
         step_start_ns = get_utc_timestamp_ns()
         step_metrics = StepMetrics(id=step_id, step_start_ns=step_start_ns)
@@ -763,6 +776,7 @@ class LettaAgentV2(BaseAgentV2):
         agent_step_span.set_attributes({"step_id": step_id})
         return StepProgression.START, step_metrics, agent_step_span
 
+    @trace_method
     def _step_checkpoint_llm_request_start(self, step_metrics: StepMetrics, agent_step_span: Span) -> Tuple[StepProgression, StepMetrics]:
         llm_request_start_ns = get_utc_timestamp_ns()
         step_metrics.llm_request_start_ns = llm_request_start_ns
@@ -772,6 +786,7 @@ class LettaAgentV2(BaseAgentV2):
         )
         return StepProgression.START, step_metrics
 
+    @trace_method
     def _step_checkpoint_llm_request_finish(
         self, step_metrics: StepMetrics, agent_step_span: Span, llm_request_finish_timestamp_ns: int
     ) -> Tuple[StepProgression, StepMetrics]:
@@ -780,6 +795,7 @@ class LettaAgentV2(BaseAgentV2):
         agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": ns_to_ms(llm_request_ns)})
         return StepProgression.RESPONSE_RECEIVED, step_metrics
 
+    @trace_method
     def _step_checkpoint_finish(
         self, step_metrics: StepMetrics, agent_step_span: Span | None, run_id: str | None
     ) -> Tuple[StepProgression, StepMetrics]:
@@ -798,6 +814,7 @@ class LettaAgentV2(BaseAgentV2):
         self.usage.prompt_tokens += step_usage_stats.prompt_tokens
         self.usage.total_tokens += step_usage_stats.total_tokens
 
+    @trace_method
     async def _handle_ai_response(
         self,
         tool_call: ToolCall,
@@ -973,6 +990,7 @@ class LettaAgentV2(BaseAgentV2):
         return persisted_messages, continue_stepping, stop_reason
 
+    @trace_method
     def _decide_continuation(
         self,
         agent_state: AgentState,
@@ -1145,6 +1163,7 @@ class LettaAgentV2(BaseAgentV2):
         )
         return task
 
+    @trace_method
     async def _log_request(
         self,
         request_start_timestamp_ns: int,
@@ -1170,6 +1189,7 @@ class LettaAgentV2(BaseAgentV2):
         if request_span:
             request_span.end()
 
+    @trace_method
     async def _update_agent_last_run_metrics(self, completion_time: datetime, duration_ms: float) -> None:
         if not settings.track_last_agent_run:
             return

letta/groups/sleeptime_multi_agent_v3.py

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
 from letta.agents.letta_agent_v2 import LettaAgentV2
 from letta.constants import DEFAULT_MAX_STEPS
 from letta.groups.helpers import stringify_message
+from letta.otel.tracing import trace_method
 from letta.schemas.agent import AgentState
 from letta.schemas.enums import JobStatus
 from letta.schemas.group import Group, ManagerType
@@ -33,6 +34,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
         # Additional manager classes
         self.group_manager = GroupManager()
 
+    @trace_method
     async def step(
         self,
         input_messages: list[MessageCreate],
@@ -61,6 +63,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
         response.usage.run_ids = self.run_ids
         return response
 
+    @trace_method
     async def stream(
         self,
         input_messages: list[MessageCreate],
@@ -90,6 +93,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
         await self.run_sleeptime_agents(use_assistant_message=use_assistant_message)
 
+    @trace_method
     async def run_sleeptime_agents(self, use_assistant_message: bool = True):
         # Get response messages
         last_response_messages = self.response_messages
@@ -120,6 +124,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
             print(f"Sleeptime agent processing failed: {e!s}")
             raise e
 
+    @trace_method
     async def _issue_background_task(
         self,
         sleeptime_agent_id: str,
@@ -149,6 +154,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
         )
         return run.id
 
+    @trace_method
     async def _participant_agent_step(
         self,
         foreground_agent_id: str,
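
To see the new spans locally, and assuming the tracing module is backed by the standard OpenTelemetry SDK (suggested by the `letta.otel.tracing` module path and the `tracer.start_span(...)` call in the diff above, so treat the setup below as an assumption), a minimal console exporter can be registered before running an agent:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Print every finished span (e.g. the "time_to_first_token" span started in
# _request_checkpoint_start) to stdout as agents run.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)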