feat: add gating to provider trace persistence in db (#4223)

* feat: make provider trace fetch result nullable

* feat: add flag for persisting provider trace to db
This commit is contained in:
cthomas
2025-08-26 15:58:26 -07:00
committed by GitHub
parent 4b6e52b0e5
commit 3d62f14bac
2 changed files with 58 additions and 53 deletions

View File

@@ -349,16 +349,17 @@ class LettaAgent(BaseAgent):
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# stream step
# TODO: improve TTFT
@@ -646,17 +647,18 @@ class LettaAgent(BaseAgent):
agent_step_span.end()
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
step_progression = StepProgression.LOGGED_TRACE
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
step_progression = StepProgression.FINISHED
@@ -1007,31 +1009,32 @@ class LettaAgent(BaseAgent):
# Log LLM Trace
# We are piecing together the streamed response here.
# Content here does not match the actual response schema as streams come in chunks.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
if settings.track_provider_trace:
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
},
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
},
},
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
step_id=step_id, # Use original step_id for telemetry
organization_id=self.actor.organization_id,
),
)
step_progression = StepProgression.LOGGED_TRACE
# yields tool response as this is handled from Letta and not the response from the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]

View File

@@ -15,6 +15,7 @@ from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import settings
if TYPE_CHECKING:
from letta.orm import User
@@ -90,15 +91,16 @@ class LLMClientBase:
try:
log_event(name="llm_request_sent", attributes=request_data)
response_data = await self.request_async(request_data, llm_config)
await telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
if settings.track_provider_trace and telemetry_manager:
await telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e: