feat: add gating to provider trace persistence in db (#4223)
* feat: make provider trace fetch result nullable
* feat: add flag for persisting provider trace to db
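For context, every gate added below reads one boolean flag off the shared `settings` object imported from `letta.settings`. A minimal sketch of how such a flag might be declared, assuming a pydantic-settings style configuration class (the class name, default value, and surrounding structure are illustrative, not taken from this diff):

    # Illustrative sketch only -- the real letta/settings.py defines many
    # more fields. Assumes pydantic-settings is the configuration backend.
    from pydantic_settings import BaseSettings

    class Settings(BaseSettings):
        # When False, provider request/response payloads are never
        # persisted to the database; the gated calls below are skipped.
        track_provider_trace: bool = False

    settings = Settings()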
@@ -349,16 +349,17 @@ class LettaAgent(BaseAgent):
         agent_step_span.end()

         # Log LLM Trace
-        await self.telemetry_manager.create_provider_trace_async(
-            actor=self.actor,
-            provider_trace_create=ProviderTraceCreate(
-                request_json=request_data,
-                response_json=response_data,
-                step_id=step_id,  # Use original step_id for telemetry
-                organization_id=self.actor.organization_id,
-            ),
-        )
-        step_progression = StepProgression.LOGGED_TRACE
+        if settings.track_provider_trace:
+            await self.telemetry_manager.create_provider_trace_async(
+                actor=self.actor,
+                provider_trace_create=ProviderTraceCreate(
+                    request_json=request_data,
+                    response_json=response_data,
+                    step_id=step_id,  # Use original step_id for telemetry
+                    organization_id=self.actor.organization_id,
+                ),
+            )
+        step_progression = StepProgression.LOGGED_TRACE

         # stream step
         # TODO: improve TTFT
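Note that `step_progression = StepProgression.LOGGED_TRACE` stays outside the new `if` block, so the step lifecycle advances identically whether or not the trace is persisted; only the database write is gated.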
@@ -646,17 +647,18 @@ class LettaAgent(BaseAgent):
         agent_step_span.end()

         # Log LLM Trace
-        await self.telemetry_manager.create_provider_trace_async(
-            actor=self.actor,
-            provider_trace_create=ProviderTraceCreate(
-                request_json=request_data,
-                response_json=response_data,
-                step_id=step_id,  # Use original step_id for telemetry
-                organization_id=self.actor.organization_id,
-            ),
-        )
-        step_progression = StepProgression.LOGGED_TRACE
+        if settings.track_provider_trace:
+            await self.telemetry_manager.create_provider_trace_async(
+                actor=self.actor,
+                provider_trace_create=ProviderTraceCreate(
+                    request_json=request_data,
+                    response_json=response_data,
+                    step_id=step_id,  # Use original step_id for telemetry
+                    organization_id=self.actor.organization_id,
+                ),
+            )
+        step_progression = StepProgression.LOGGED_TRACE

         MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
         step_progression = StepProgression.FINISHED
@@ -1007,31 +1009,32 @@ class LettaAgent(BaseAgent):
         # Log LLM Trace
         # We are piecing together the streamed response here.
         # Content here does not match the actual response schema as streams come in chunks.
-        await self.telemetry_manager.create_provider_trace_async(
-            actor=self.actor,
-            provider_trace_create=ProviderTraceCreate(
-                request_json=request_data,
-                response_json={
-                    "content": {
-                        "tool_call": tool_call.model_dump_json(),
-                        "reasoning": [content.model_dump_json() for content in reasoning_content],
-                    },
-                    "id": interface.message_id,
-                    "model": interface.model,
-                    "role": "assistant",
-                    # "stop_reason": "",
-                    # "stop_sequence": None,
-                    "type": "message",
-                    "usage": {
-                        "input_tokens": usage.prompt_tokens,
-                        "output_tokens": usage.completion_tokens,
-                    },
-                },
-                step_id=step_id,  # Use original step_id for telemetry
-                organization_id=self.actor.organization_id,
-            ),
-        )
-        step_progression = StepProgression.LOGGED_TRACE
+        if settings.track_provider_trace:
+            await self.telemetry_manager.create_provider_trace_async(
+                actor=self.actor,
+                provider_trace_create=ProviderTraceCreate(
+                    request_json=request_data,
+                    response_json={
+                        "content": {
+                            "tool_call": tool_call.model_dump_json(),
+                            "reasoning": [content.model_dump_json() for content in reasoning_content],
+                        },
+                        "id": interface.message_id,
+                        "model": interface.model,
+                        "role": "assistant",
+                        # "stop_reason": "",
+                        # "stop_sequence": None,
+                        "type": "message",
+                        "usage": {
+                            "input_tokens": usage.prompt_tokens,
+                            "output_tokens": usage.completion_tokens,
+                        },
+                    },
+                    step_id=step_id,  # Use original step_id for telemetry
+                    organization_id=self.actor.organization_id,
+                ),
+            )
+        step_progression = StepProgression.LOGGED_TRACE

         # yields tool response as this is handled from Letta and not the response from the LLM provider
         tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
@@ -15,6 +15,7 @@ from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.provider_trace import ProviderTraceCreate
 from letta.services.telemetry_manager import TelemetryManager
+from letta.settings import settings

 if TYPE_CHECKING:
     from letta.orm import User
@@ -90,15 +91,16 @@ class LLMClientBase:
         try:
             log_event(name="llm_request_sent", attributes=request_data)
             response_data = await self.request_async(request_data, llm_config)
-            await telemetry_manager.create_provider_trace_async(
-                actor=self.actor,
-                provider_trace_create=ProviderTraceCreate(
-                    request_json=request_data,
-                    response_json=response_data,
-                    step_id=step_id,
-                    organization_id=self.actor.organization_id,
-                ),
-            )
+            if settings.track_provider_trace and telemetry_manager:
+                await telemetry_manager.create_provider_trace_async(
+                    actor=self.actor,
+                    provider_trace_create=ProviderTraceCreate(
+                        request_json=request_data,
+                        response_json=response_data,
+                        step_id=step_id,
+                        organization_id=self.actor.organization_id,
+                    ),
+                )

             log_event(name="llm_response_received", attributes=response_data)
         except Exception as e:
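The `LLMClientBase` gate additionally checks that `telemetry_manager` is truthy before calling it, presumably because the manager is optional at this call site. A minimal standalone sketch of the combined pattern, using simplified hypothetical names rather than Letta's actual classes:

    # Standalone illustration of the gating pattern from this commit.
    import asyncio
    from types import SimpleNamespace

    settings = SimpleNamespace(track_provider_trace=False)

    async def persist_trace(telemetry_manager, payload):
        # Skipped unless the flag is on AND a manager instance exists.
        if settings.track_provider_trace and telemetry_manager:
            await telemetry_manager.create_provider_trace_async(payload)

    # With the flag off or the manager None, this is a no-op.
    asyncio.run(persist_trace(None, {"request": {}, "response": {}}))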