"""ClickHouse provider trace backend.
|
|
|
|
Writes and reads from the llm_traces table with denormalized columns for cost analytics.
|
|
"""
|
|
|
|
import json
|
|
import uuid
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from letta.log import get_logger
|
|
from letta.schemas.provider_trace import ProviderTrace
|
|
from letta.schemas.user import User
|
|
from letta.services.clickhouse_provider_traces import ClickhouseProviderTraceReader
|
|
from letta.services.provider_trace_backends.base import ProviderTraceBackendClient
|
|
from letta.settings import settings
|
|
|
|
if TYPE_CHECKING:
|
|
from letta.schemas.llm_trace import LLMTrace
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class ClickhouseProviderTraceBackend(ProviderTraceBackendClient):
    """ClickHouse backend for provider traces (reads and writes from llm_traces table)."""

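    # Rough usage sketch (illustrative only; `actor` and `provider_trace` come from the caller):
    #   backend = ClickhouseProviderTraceBackend()
    #   await backend.create_async(actor=actor, provider_trace=provider_trace)
    #   stored = await backend.get_by_step_id_async(step_id=provider_trace.step_id, actor=actor)
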
    def __init__(self):
        self._reader = ClickhouseProviderTraceReader()

    async def create_async(
        self,
        actor: User,
        provider_trace: ProviderTrace,
    ) -> ProviderTrace | None:
        """Write provider trace to ClickHouse llm_traces table."""
        if not settings.store_llm_traces:
            # Return minimal trace for consistency if writes disabled
            return ProviderTrace(
                id=provider_trace.id,
                step_id=provider_trace.step_id,
                request_json=provider_trace.request_json or {},
                response_json=provider_trace.response_json or {},
            )

        try:
            from letta.services.llm_trace_writer import get_llm_trace_writer

            trace = self._convert_to_trace(actor, provider_trace)
            if trace:
                writer = get_llm_trace_writer()
                await writer.write_async(trace)
        except Exception as e:
            logger.debug(f"Failed to write trace to ClickHouse: {e}")

        return ProviderTrace(
            id=provider_trace.id,
            step_id=provider_trace.step_id,
            request_json=provider_trace.request_json or {},
            response_json=provider_trace.response_json or {},
        )

    async def get_by_step_id_async(
        self,
        step_id: str,
        actor: User,
    ) -> ProviderTrace | None:
        """Read provider trace from llm_traces table by step_id."""
        return await self._reader.get_provider_trace_by_step_id_async(
            step_id=step_id,
            organization_id=actor.organization_id,
        )

    def _convert_to_trace(
        self,
        actor: User,
        provider_trace: ProviderTrace,
    ) -> Optional["LLMTrace"]:
        """Convert ProviderTrace to LLMTrace for analytics storage."""
        from letta.schemas.llm_trace import LLMTrace

        # Serialize JSON fields
        request_json_str = json.dumps(provider_trace.request_json, default=str)
        response_json_str = json.dumps(provider_trace.response_json, default=str)
        llm_config_json_str = json.dumps(provider_trace.llm_config, default=str) if provider_trace.llm_config else "{}"

        # Extract provider and model from llm_config
        llm_config = provider_trace.llm_config or {}
        provider = llm_config.get("model_endpoint_type", "unknown")
        model = llm_config.get("model", "unknown")
        is_byok = llm_config.get("provider_category") == "byok"

        # Extract usage from response (generic parsing for common formats)
        usage = self._extract_usage(provider_trace.response_json, provider)

        # Check for error in response - must have actual error content, not just null
        # OpenAI Responses API returns {"error": null} on success
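        # A failing payload might look roughly like this (illustrative values only):
        #   {"error": {"type": "rate_limit_error", "message": "Rate limit exceeded"}}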
        error_data = provider_trace.response_json.get("error")
        error_type = provider_trace.response_json.get("error_type")
        error_message = None
        is_error = bool(error_data) or bool(error_type)
        if is_error:
            if isinstance(error_data, dict):
                error_type = error_type or error_data.get("type")
                error_message = error_data.get("message", str(error_data))[:1000]
            elif error_data:
                error_message = str(error_data)[:1000]

        # Extract UUID from provider_trace.id (strip "provider_trace-" prefix)
        trace_id = provider_trace.id
        if not trace_id:
            logger.warning("ProviderTrace missing id - trace correlation across backends will fail")
            trace_id = str(uuid.uuid4())
        elif trace_id.startswith("provider_trace-"):
            trace_id = trace_id[len("provider_trace-") :]

        return LLMTrace(
            id=trace_id,
            organization_id=provider_trace.org_id or actor.organization_id,
            project_id=None,
            agent_id=provider_trace.agent_id,
            agent_tags=provider_trace.agent_tags or [],
            run_id=provider_trace.run_id,
            step_id=provider_trace.step_id,
            trace_id=None,
            call_type=provider_trace.call_type or "unknown",
            provider=provider,
            model=model,
            is_byok=is_byok,
            request_size_bytes=len(request_json_str.encode("utf-8")),
            response_size_bytes=len(response_json_str.encode("utf-8")),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            cached_input_tokens=usage.get("cached_input_tokens"),
            cache_write_tokens=usage.get("cache_write_tokens"),
            reasoning_tokens=usage.get("reasoning_tokens"),
            latency_ms=0,  # Not available in ProviderTrace
            is_error=is_error,
            error_type=error_type,
            error_message=error_message,
            request_json=request_json_str,
            response_json=response_json_str,
            llm_config_json=llm_config_json_str,
        )

    def _extract_usage(self, response_json: dict, provider: str) -> dict:
        """Extract usage statistics from response JSON.

        Handles common formats from OpenAI, Anthropic, and other providers.
        """
        usage = {}

        # OpenAI format: response.usage
        if "usage" in response_json:
            u = response_json["usage"]
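            # A typical OpenAI-style usage block, roughly (illustrative numbers):
            #   {"prompt_tokens": 120, "completion_tokens": 40, "total_tokens": 160,
            #    "prompt_tokens_details": {"cached_tokens": 100},
            #    "completion_tokens_details": {"reasoning_tokens": 12}}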
            usage["prompt_tokens"] = u.get("prompt_tokens", 0)
            usage["completion_tokens"] = u.get("completion_tokens", 0)
            usage["total_tokens"] = u.get("total_tokens", 0)

            # OpenAI reasoning tokens
            if "completion_tokens_details" in u:
                details = u["completion_tokens_details"]
                usage["reasoning_tokens"] = details.get("reasoning_tokens")

            # OpenAI cached tokens
            if "prompt_tokens_details" in u:
                details = u["prompt_tokens_details"]
                usage["cached_input_tokens"] = details.get("cached_tokens")

        # Anthropic format: response.usage with cache fields
        if provider == "anthropic" and "usage" in response_json:
            u = response_json["usage"]
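            # A typical Anthropic-style usage block, roughly (illustrative numbers):
            #   {"input_tokens": 3, "output_tokens": 40,
            #    "cache_read_input_tokens": 900, "cache_creation_input_tokens": 0}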
            # input_tokens can be 0 when all tokens come from cache
            input_tokens = u.get("input_tokens", 0)
            cache_read = u.get("cache_read_input_tokens", 0)
            cache_write = u.get("cache_creation_input_tokens", 0)
            # Total prompt = input + cached (for cost analytics)
            usage["prompt_tokens"] = input_tokens + cache_read + cache_write
            usage["completion_tokens"] = u.get("output_tokens", usage.get("completion_tokens", 0))
            usage["cached_input_tokens"] = cache_read if cache_read else None
            usage["cache_write_tokens"] = cache_write if cache_write else None

        # Recalculate total if not present
        if "total_tokens" not in usage or usage["total_tokens"] == 0:
            usage["total_tokens"] = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)

        return usage