134 lines
4.0 KiB
Python
134 lines
4.0 KiB
Python
import asyncio
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
from urllib.parse import urlparse
|
|
|
|
from letta.helpers.singleton import singleton
|
|
from letta.schemas.provider_trace import ProviderTrace
|
|
from letta.settings import settings
|
|
|
|
|
|
def _parse_json_maybe(value: str | None) -> dict[str, Any]:
|
|
if not value:
|
|
return {}
|
|
try:
|
|
parsed = json.loads(value)
|
|
return parsed if isinstance(parsed, dict) else {"_value": parsed}
|
|
except Exception:
|
|
# Preserve the raw payload if parsing fails (e.g. non-JSON string)
|
|
return {"_raw": value}
|
|
|
|
|
|
def _parse_clickhouse_endpoint(endpoint: str) -> tuple[str, int, bool]:
|
|
"""Return (host, port, secure) for clickhouse_connect.get_client."""
|
|
parsed = urlparse(endpoint)
|
|
|
|
if parsed.scheme in ("http", "https"):
|
|
host = parsed.hostname or ""
|
|
port = parsed.port or (8443 if parsed.scheme == "https" else 8123)
|
|
secure = parsed.scheme == "https"
|
|
return host, port, secure
|
|
|
|
# Fallback: accept raw hostname (possibly with :port)
|
|
if ":" in endpoint:
|
|
host, port_str = endpoint.rsplit(":", 1)
|
|
return host, int(port_str), True
|
|
|
|
return endpoint, 8443, True
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ClickhouseProviderTraceRow:
|
|
created_at: Any
|
|
id: str
|
|
step_id: str
|
|
request_json: str | None
|
|
response_json: str | None
|
|
|
|
|
|
@singleton
|
|
class ClickhouseProviderTraceReader:
|
|
def __init__(self):
|
|
self._client = None
|
|
|
|
def _get_client(self):
|
|
if self._client is not None:
|
|
return self._client
|
|
|
|
# Import lazily so OSS users who never enable this flag don't pay import cost.
|
|
import clickhouse_connect
|
|
|
|
if not settings.clickhouse_endpoint:
|
|
raise ValueError("CLICKHOUSE_ENDPOINT is required")
|
|
|
|
host, port, secure = _parse_clickhouse_endpoint(settings.clickhouse_endpoint)
|
|
if not host:
|
|
raise ValueError("Invalid CLICKHOUSE_ENDPOINT")
|
|
|
|
database = settings.clickhouse_database or "otel"
|
|
username = settings.clickhouse_username or "default"
|
|
password = settings.clickhouse_password
|
|
if not password:
|
|
raise ValueError("CLICKHOUSE_PASSWORD is required")
|
|
|
|
self._client = clickhouse_connect.get_client(
|
|
host=host,
|
|
port=port,
|
|
username=username,
|
|
password=password,
|
|
database=database,
|
|
secure=secure,
|
|
verify=True,
|
|
)
|
|
return self._client
|
|
|
|
def _query_latest_row_for_step_id_sync(self, step_id: str, organization_id: str) -> ClickhouseProviderTraceRow | None:
|
|
client = self._get_client()
|
|
query = """
|
|
SELECT
|
|
created_at,
|
|
id,
|
|
step_id,
|
|
request_json,
|
|
response_json
|
|
FROM llm_traces
|
|
WHERE step_id = %(step_id)s
|
|
AND organization_id = %(organization_id)s
|
|
ORDER BY created_at DESC
|
|
LIMIT 1
|
|
"""
|
|
|
|
result = client.query(
|
|
query,
|
|
parameters={
|
|
"step_id": step_id,
|
|
"organization_id": organization_id,
|
|
},
|
|
)
|
|
|
|
if not result or not result.result_rows:
|
|
return None
|
|
|
|
row = result.result_rows[0]
|
|
return ClickhouseProviderTraceRow(
|
|
created_at=row[0],
|
|
id=row[1],
|
|
step_id=row[2],
|
|
request_json=row[3],
|
|
response_json=row[4],
|
|
)
|
|
|
|
async def get_provider_trace_by_step_id_async(self, *, step_id: str, organization_id: str) -> ProviderTrace | None:
|
|
row = await asyncio.to_thread(self._query_latest_row_for_step_id_sync, step_id, organization_id)
|
|
if row is None:
|
|
return None
|
|
|
|
return ProviderTrace(
|
|
id=f"provider_trace-{row.id}",
|
|
step_id=row.step_id,
|
|
request_json=_parse_json_maybe(row.request_json),
|
|
response_json=_parse_json_maybe(row.response_json),
|
|
created_at=row.created_at,
|
|
)
|