feat: track run state in stream

This commit is contained in:
Andy Li
2025-08-08 16:31:15 -07:00
committed by GitHub
parent 7645310570
commit c6002744e6
2 changed files with 26 additions and 11 deletions

View File

@@ -220,6 +220,7 @@ class LettaAgent(BaseAgent):
actor=self.actor,
)
stop_reason = None
job_update_metadata = None
usage = LettaUsageStatistics()
# span for request
@@ -367,6 +368,7 @@ class LettaAgent(BaseAgent):
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
job_update_metadata = {"error": str(e)}
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
@@ -429,7 +431,7 @@ class LettaAgent(BaseAgent):
self.logger.error("Invalid StepProgression value")
if settings.track_stop_reason:
await self._log_request(request_start_timestamp_ns, request_span)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -447,7 +449,7 @@ class LettaAgent(BaseAgent):
force=False,
)
await self._log_request(request_start_timestamp_ns, request_span)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
# Return back usage
for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
@@ -485,6 +487,7 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
stop_reason = None
job_update_metadata = None
usage = LettaUsageStatistics()
for i in range(max_steps):
# If dry run, build request data and return it without making LLM call
@@ -622,6 +625,7 @@ class LettaAgent(BaseAgent):
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
job_update_metadata = {"error": str(e)}
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
@@ -680,7 +684,7 @@ class LettaAgent(BaseAgent):
self.logger.error("Invalid StepProgression value")
if settings.track_stop_reason:
await self._log_request(request_start_timestamp_ns, request_span)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -698,7 +702,7 @@ class LettaAgent(BaseAgent):
force=False,
)
await self._log_request(request_start_timestamp_ns, request_span)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
return current_in_context_messages, new_in_context_messages, stop_reason, usage
@@ -748,6 +752,7 @@ class LettaAgent(BaseAgent):
actor=self.actor,
)
stop_reason = None
job_update_metadata = None
usage = LettaUsageStatistics()
first_chunk, request_span = True, None
if request_start_timestamp_ns:
@@ -977,6 +982,7 @@ class LettaAgent(BaseAgent):
except Exception as e:
# Handle any unexpected errors during step processing
self.logger.error(f"Error during step processing: {e}")
job_update_metadata = {"error": str(e)}
# This indicates we failed after we decided to stop stepping, which indicates a bug with our flow.
if not stop_reason:
@@ -1039,7 +1045,7 @@ class LettaAgent(BaseAgent):
# Do tracking for failure cases. Can consolidate with success conditions later.
if settings.track_stop_reason:
await self._log_request(request_start_timestamp_ns, request_span)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=True)
except Exception as e:
self.logger.error("Failed to update step: %s", e)
@@ -1056,20 +1062,28 @@ class LettaAgent(BaseAgent):
force=False,
)
await self._log_request(request_start_timestamp_ns, request_span)
await self._log_request(request_start_timestamp_ns, request_span, job_update_metadata, is_error=False)
for finish_chunk in self.get_finish_chunks_for_stream(usage, stop_reason):
yield f"data: {finish_chunk}\n\n"
async def _log_request(self, request_start_timestamp_ns: int, request_span: "Span | None"):
async def _log_request(
self, request_start_timestamp_ns: int, request_span: "Span | None", job_update_metadata: dict | None, is_error: bool
):
if request_start_timestamp_ns:
now_ns, now = get_utc_timestamp_ns(), get_utc_time()
duration_ns = now_ns - request_start_timestamp_ns
if request_span:
request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)})
await self._update_agent_last_run_metrics(now, ns_to_ms(duration_ns))
if self.current_run_id:
if settings.track_agent_run and self.current_run_id:
await self.job_manager.record_response_duration(self.current_run_id, duration_ns, self.actor)
await self.job_manager.safe_update_job_status_async(
job_id=self.current_run_id,
new_status=JobStatus.failed if is_error else JobStatus.completed,
actor=self.actor,
metadata=job_update_metadata,
)
if request_span:
request_span.end()

View File

@@ -1,4 +1,4 @@
from enum import Enum
from enum import Enum, StrEnum
class ProviderType(str, Enum):
@@ -42,7 +42,7 @@ class OptionState(str, Enum):
DEFAULT = "default"
class JobStatus(str, Enum):
class JobStatus(StrEnum):
"""
Status of the job.
"""
@@ -63,7 +63,8 @@ class JobStatus(str, Enum):
class AgentStepStatus(str, Enum):
"""
Status of the job.
Status of agent step.
TODO (cliandy): consolidate this with job status
"""
paused = "paused"