Files
letta-server/letta/services/provider_trace_backends/clickhouse.py
Kian Jones 25d54dd896 chore: enable F821, F401, W293 (#9503)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import
2026-02-24 10:55:08 -08:00

188 lines
7.7 KiB
Python

"""ClickHouse provider trace backend.
Writes and reads from the llm_traces table with denormalized columns for cost analytics.
"""
import json
import uuid
from typing import TYPE_CHECKING, Optional
from letta.log import get_logger
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.user import User
from letta.services.clickhouse_provider_traces import ClickhouseProviderTraceReader
from letta.services.provider_trace_backends.base import ProviderTraceBackendClient
from letta.settings import settings
if TYPE_CHECKING:
from letta.schemas.llm_trace import LLMTrace
logger = get_logger(__name__)
class ClickhouseProviderTraceBackend(ProviderTraceBackendClient):
"""ClickHouse backend for provider traces (reads and writes from llm_traces table)."""
def __init__(self):
self._reader = ClickhouseProviderTraceReader()
async def create_async(
self,
actor: User,
provider_trace: ProviderTrace,
) -> ProviderTrace | None:
"""Write provider trace to ClickHouse llm_traces table."""
if not settings.store_llm_traces:
# Return minimal trace for consistency if writes disabled
return ProviderTrace(
id=provider_trace.id,
step_id=provider_trace.step_id,
request_json=provider_trace.request_json or {},
response_json=provider_trace.response_json or {},
)
try:
from letta.services.llm_trace_writer import get_llm_trace_writer
trace = self._convert_to_trace(actor, provider_trace)
if trace:
writer = get_llm_trace_writer()
await writer.write_async(trace)
except Exception as e:
logger.debug(f"Failed to write trace to ClickHouse: {e}")
return ProviderTrace(
id=provider_trace.id,
step_id=provider_trace.step_id,
request_json=provider_trace.request_json or {},
response_json=provider_trace.response_json or {},
)
async def get_by_step_id_async(
self,
step_id: str,
actor: User,
) -> ProviderTrace | None:
"""Read provider trace from llm_traces table by step_id."""
return await self._reader.get_provider_trace_by_step_id_async(
step_id=step_id,
organization_id=actor.organization_id,
)
def _convert_to_trace(
self,
actor: User,
provider_trace: ProviderTrace,
) -> Optional["LLMTrace"]:
"""Convert ProviderTrace to LLMTrace for analytics storage."""
from letta.schemas.llm_trace import LLMTrace
# Serialize JSON fields
request_json_str = json.dumps(provider_trace.request_json, default=str)
response_json_str = json.dumps(provider_trace.response_json, default=str)
llm_config_json_str = json.dumps(provider_trace.llm_config, default=str) if provider_trace.llm_config else "{}"
# Extract provider and model from llm_config
llm_config = provider_trace.llm_config or {}
provider = llm_config.get("model_endpoint_type", "unknown")
model = llm_config.get("model", "unknown")
is_byok = llm_config.get("provider_category") == "byok"
# Extract usage from response (generic parsing for common formats)
usage = self._extract_usage(provider_trace.response_json, provider)
# Check for error in response - must have actual error content, not just null
# OpenAI Responses API returns {"error": null} on success
error_data = provider_trace.response_json.get("error")
error_type = provider_trace.response_json.get("error_type")
error_message = None
is_error = bool(error_data) or bool(error_type)
if is_error:
if isinstance(error_data, dict):
error_type = error_type or error_data.get("type")
error_message = error_data.get("message", str(error_data))[:1000]
elif error_data:
error_message = str(error_data)[:1000]
# Extract UUID from provider_trace.id (strip "provider_trace-" prefix)
trace_id = provider_trace.id
if not trace_id:
logger.warning("ProviderTrace missing id - trace correlation across backends will fail")
trace_id = str(uuid.uuid4())
elif trace_id.startswith("provider_trace-"):
trace_id = trace_id[len("provider_trace-") :]
return LLMTrace(
id=trace_id,
organization_id=provider_trace.org_id or actor.organization_id,
project_id=None,
agent_id=provider_trace.agent_id,
agent_tags=provider_trace.agent_tags or [],
run_id=provider_trace.run_id,
step_id=provider_trace.step_id,
trace_id=None,
call_type=provider_trace.call_type or "unknown",
provider=provider,
model=model,
is_byok=is_byok,
request_size_bytes=len(request_json_str.encode("utf-8")),
response_size_bytes=len(response_json_str.encode("utf-8")),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
cached_input_tokens=usage.get("cached_input_tokens"),
cache_write_tokens=usage.get("cache_write_tokens"),
reasoning_tokens=usage.get("reasoning_tokens"),
latency_ms=0, # Not available in ProviderTrace
is_error=is_error,
error_type=error_type,
error_message=error_message,
request_json=request_json_str,
response_json=response_json_str,
llm_config_json=llm_config_json_str,
)
def _extract_usage(self, response_json: dict, provider: str) -> dict:
"""Extract usage statistics from response JSON.
Handles common formats from OpenAI, Anthropic, and other providers.
"""
usage = {}
# OpenAI format: response.usage
if "usage" in response_json:
u = response_json["usage"]
usage["prompt_tokens"] = u.get("prompt_tokens", 0)
usage["completion_tokens"] = u.get("completion_tokens", 0)
usage["total_tokens"] = u.get("total_tokens", 0)
# OpenAI reasoning tokens
if "completion_tokens_details" in u:
details = u["completion_tokens_details"]
usage["reasoning_tokens"] = details.get("reasoning_tokens")
# OpenAI cached tokens
if "prompt_tokens_details" in u:
details = u["prompt_tokens_details"]
usage["cached_input_tokens"] = details.get("cached_tokens")
# Anthropic format: response.usage with cache fields
if provider == "anthropic" and "usage" in response_json:
u = response_json["usage"]
# input_tokens can be 0 when all tokens come from cache
input_tokens = u.get("input_tokens", 0)
cache_read = u.get("cache_read_input_tokens", 0)
cache_write = u.get("cache_creation_input_tokens", 0)
# Total prompt = input + cached (for cost analytics)
usage["prompt_tokens"] = input_tokens + cache_read + cache_write
usage["completion_tokens"] = u.get("output_tokens", usage.get("completion_tokens", 0))
usage["cached_input_tokens"] = cache_read if cache_read else None
usage["cache_write_tokens"] = cache_write if cache_write else None
# Recalculate total if not present
if "total_tokens" not in usage or usage["total_tokens"] == 0:
usage["total_tokens"] = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
return usage