diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py index 4ae64e91..986d5244 100644 --- a/letta/adapters/letta_llm_stream_adapter.py +++ b/letta/adapters/letta_llm_stream_adapter.py @@ -88,10 +88,22 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): # Extract optional parameters # ttft_span = kwargs.get('ttft_span', None) + request_start_ns = get_utc_timestamp_ns() + # Start the streaming request (map provider errors to common LLMError types) try: stream = await self.llm_client.stream_async(request_data, self.llm_config) except Exception as e: + self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns() + latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000) + await self.llm_client.log_provider_trace_async( + request_data=request_data, + response_json=None, + llm_config=self.llm_config, + latency_ms=latency_ms, + error_msg=str(e), + error_type=type(e).__name__, + ) raise self.llm_client.handle_llm_error(e) # Process the stream and yield chunks immediately for TTFT @@ -101,6 +113,16 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): # Yield each chunk immediately as it arrives yield chunk except Exception as e: + self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns() + latency_ms = int((self.llm_request_finish_timestamp_ns - request_start_ns) / 1_000_000) + await self.llm_client.log_provider_trace_async( + request_data=request_data, + response_json=None, + llm_config=self.llm_config, + latency_ms=latency_ms, + error_msg=str(e), + error_type=type(e).__name__, + ) raise self.llm_client.handle_llm_error(e) # After streaming completes, extract the accumulated data diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 3b359c72..ae0d49bd 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1152,6 +1152,8 @@ class LettaAgent(BaseAgent): "output_tokens": interface.output_tokens, }, }, + llm_config=agent_state.llm_config, + latency_ms=int(llm_request_ms), ) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( tool_call, diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index b5bac794..6f8dca74 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -670,6 +670,7 @@ class GoogleVertexClient(LLMClientBase): # "candidatesTokenCount": 27, # "totalTokenCount": 36 # } + usage = None if response.usage_metadata: # Extract usage via centralized method from letta.schemas.enums import ProviderType diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 754a19e8..8b506c05 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -82,6 +82,10 @@ class LLMClientBase: """Wrapper around request_async that logs telemetry for all requests including errors. Call set_telemetry_context() first to set agent_id, run_id, etc. + + Telemetry is logged via TelemetryManager which supports multiple backends + (postgres, clickhouse, socket, etc.) configured via + LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND. """ from letta.log import get_logger @@ -97,6 +101,7 @@ class LLMClientBase: error_type = type(e).__name__ raise finally: + # Log telemetry via configured backends if self._telemetry_manager and settings.track_provider_trace: if self.actor is None: logger.warning(f"Skipping telemetry: actor is None (call_type={self._telemetry_call_type})") @@ -116,7 +121,7 @@ class LLMClientBase: org_id=self._telemetry_org_id, user_id=self._telemetry_user_id, compaction_settings=self._telemetry_compaction_settings, - llm_config=self._telemetry_llm_config, + llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config, ), ) except Exception as e: @@ -130,10 +135,27 @@ class LLMClientBase: """ return await self.stream_async(request_data, llm_config) - async def log_provider_trace_async(self, request_data: dict, response_json: dict) -> None: + async def log_provider_trace_async( + self, + request_data: dict, + response_json: Optional[dict], + llm_config: Optional[LLMConfig] = None, + latency_ms: Optional[int] = None, + error_msg: Optional[str] = None, + error_type: Optional[str] = None, + ) -> None: """Log provider trace telemetry. Call after processing LLM response. Uses telemetry context set via set_telemetry_context(). + Telemetry is logged via TelemetryManager which supports multiple backends. + + Args: + request_data: The request payload sent to the LLM + response_json: The response payload from the LLM + llm_config: LLMConfig for extracting provider/model info + latency_ms: Latency in milliseconds (not used currently, kept for API compatibility) + error_msg: Error message if request failed (not used currently) + error_type: Error type if request failed (not used currently) """ from letta.log import get_logger @@ -146,6 +168,9 @@ class LLMClientBase: logger.warning(f"Skipping telemetry: actor is None (call_type={self._telemetry_call_type})") return + if response_json is None: + return + try: pydantic_actor = self.actor.to_pydantic() if hasattr(self.actor, "to_pydantic") else self.actor await self._telemetry_manager.create_provider_trace_async( @@ -161,7 +186,7 @@ class LLMClientBase: org_id=self._telemetry_org_id, user_id=self._telemetry_user_id, compaction_settings=self._telemetry_compaction_settings, - llm_config=self._telemetry_llm_config, + llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config, ), ) except Exception as e: diff --git a/letta/schemas/llm_trace.py b/letta/schemas/llm_trace.py new file mode 100644 index 00000000..537148b5 --- /dev/null +++ b/letta/schemas/llm_trace.py @@ -0,0 +1,167 @@ +"""Schema for LLM request/response traces stored in ClickHouse for analytics.""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.letta_base import LettaBase + + +class LLMTrace(LettaBase): + """ + LLM request/response trace for ClickHouse analytics. + + Stores LLM request/response payloads with denormalized columns for + fast cost analytics queries (token usage by org/agent/model). + + Attributes: + id (str): Unique trace identifier (UUID). + organization_id (str): The organization this trace belongs to. + project_id (str): The project this trace belongs to. + agent_id (str): ID of the agent that made the request. + run_id (str): ID of the run this trace is associated with. + step_id (str): ID of the step that generated this trace. + trace_id (str): OTEL trace ID for correlation. + + call_type (str): Type of LLM call ('agent_step', 'summarization', 'embedding'). + provider (str): LLM provider name ('openai', 'anthropic', etc.). + model (str): Model name/identifier used. + + request_size_bytes (int): Size of request_json in bytes. + response_size_bytes (int): Size of response_json in bytes. + prompt_tokens (int): Number of prompt tokens used. + completion_tokens (int): Number of completion tokens generated. + total_tokens (int): Total tokens (prompt + completion). + latency_ms (int): Request latency in milliseconds. + + is_error (bool): Whether the request resulted in an error. + error_type (str): Exception class name if error occurred. + error_message (str): Error message if error occurred. + + request_json (str): Full request payload as JSON string. + response_json (str): Full response payload as JSON string. + + created_at (datetime): Timestamp when the trace was created. + """ + + __id_prefix__ = "llm_trace" + + # Primary identifier (UUID portion of ProviderTrace.id, prefix stripped for ClickHouse) + id: str = Field(..., description="Trace UUID (strip 'provider_trace-' prefix to correlate)") + + # Context identifiers + organization_id: str = Field(..., description="Organization this trace belongs to") + project_id: Optional[str] = Field(default=None, description="Project this trace belongs to") + agent_id: Optional[str] = Field(default=None, description="Agent that made the request") + agent_tags: list[str] = Field(default_factory=list, description="Tags associated with the agent") + run_id: Optional[str] = Field(default=None, description="Run this trace is associated with") + step_id: Optional[str] = Field(default=None, description="Step that generated this trace") + trace_id: Optional[str] = Field(default=None, description="OTEL trace ID for correlation") + + # Request metadata (queryable) + call_type: str = Field(..., description="Type of LLM call: 'agent_step', 'summarization', 'embedding'") + provider: str = Field(..., description="LLM provider: 'openai', 'anthropic', 'google_ai', etc.") + model: str = Field(..., description="Model name/identifier") + is_byok: bool = Field(default=False, description="Whether this request used BYOK (Bring Your Own Key)") + + # Size metrics + request_size_bytes: int = Field(default=0, description="Size of request_json in bytes") + response_size_bytes: int = Field(default=0, description="Size of response_json in bytes") + + # Token usage + prompt_tokens: int = Field(default=0, description="Number of prompt tokens") + completion_tokens: int = Field(default=0, description="Number of completion tokens") + total_tokens: int = Field(default=0, description="Total tokens (prompt + completion)") + + # Cache and reasoning tokens (from LettaUsageStatistics) + cached_input_tokens: Optional[int] = Field(default=None, description="Number of input tokens served from cache") + cache_write_tokens: Optional[int] = Field(default=None, description="Number of tokens written to cache (Anthropic)") + reasoning_tokens: Optional[int] = Field(default=None, description="Number of reasoning/thinking tokens generated") + + # Latency + latency_ms: int = Field(default=0, description="Request latency in milliseconds") + + # Error tracking + is_error: bool = Field(default=False, description="Whether the request resulted in an error") + error_type: Optional[str] = Field(default=None, description="Exception class name if error") + error_message: Optional[str] = Field(default=None, description="Error message if error") + + # Raw payloads (JSON strings) + request_json: str = Field(..., description="Full request payload as JSON string") + response_json: str = Field(..., description="Full response payload as JSON string") + llm_config_json: str = Field(default="", description="LLM config as JSON string") + + # Timestamp + created_at: datetime = Field(default_factory=get_utc_time, description="When the trace was created") + + def to_clickhouse_row(self) -> tuple: + """Convert to a tuple for ClickHouse insertion.""" + return ( + self.id, + self.organization_id, + self.project_id or "", + self.agent_id or "", + self.agent_tags, + self.run_id or "", + self.step_id or "", + self.trace_id or "", + self.call_type, + self.provider, + self.model, + 1 if self.is_byok else 0, + self.request_size_bytes, + self.response_size_bytes, + self.prompt_tokens, + self.completion_tokens, + self.total_tokens, + self.cached_input_tokens, + self.cache_write_tokens, + self.reasoning_tokens, + self.latency_ms, + 1 if self.is_error else 0, + self.error_type or "", + self.error_message or "", + self.request_json, + self.response_json, + self.llm_config_json, + self.created_at, + ) + + @classmethod + def clickhouse_columns(cls) -> list[str]: + """Return column names for ClickHouse insertion.""" + return [ + "id", + "organization_id", + "project_id", + "agent_id", + "agent_tags", + "run_id", + "step_id", + "trace_id", + "call_type", + "provider", + "model", + "is_byok", + "request_size_bytes", + "response_size_bytes", + "prompt_tokens", + "completion_tokens", + "total_tokens", + "cached_input_tokens", + "cache_write_tokens", + "reasoning_tokens", + "latency_ms", + "is_error", + "error_type", + "error_message", + "request_json", + "response_json", + "llm_config_json", + "created_at", + ] diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 19bdb656..60f1b123 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -221,6 +221,17 @@ async def lifespan(app_: FastAPI): except Exception as e: logger.warning(f"[Worker {worker_id}] SQLAlchemy instrumentation shutdown failed: {e}") + # Shutdown LLM raw trace writer (closes ClickHouse connection) + if settings.store_llm_traces: + try: + from letta.services.llm_trace_writer import get_llm_trace_writer + + writer = get_llm_trace_writer() + await writer.shutdown_async() + logger.info(f"[Worker {worker_id}] LLM raw trace writer shutdown completed") + except Exception as e: + logger.warning(f"[Worker {worker_id}] LLM raw trace writer shutdown failed: {e}") + logger.info(f"[Worker {worker_id}] Lifespan shutdown completed") diff --git a/letta/services/llm_trace_reader.py b/letta/services/llm_trace_reader.py new file mode 100644 index 00000000..10105e90 --- /dev/null +++ b/letta/services/llm_trace_reader.py @@ -0,0 +1,462 @@ +"""ClickHouse reader for LLM analytics traces. + +Reads LLM traces from ClickHouse for debugging, analytics, and auditing. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import datetime +from typing import Any, List, Optional +from urllib.parse import urlparse + +from letta.helpers.singleton import singleton +from letta.log import get_logger +from letta.schemas.llm_trace import LLMTrace +from letta.settings import settings + +logger = get_logger(__name__) + + +def _parse_clickhouse_endpoint(endpoint: str) -> tuple[str, int, bool]: + """Return (host, port, secure) for clickhouse_connect.get_client. + + Supports: + - http://host:port -> (host, port, False) + - https://host:port -> (host, port, True) + - host:port -> (host, port, False) # Default to insecure for local dev + - host -> (host, 8123, False) # Default HTTP port, insecure + """ + parsed = urlparse(endpoint) + + if parsed.scheme in ("http", "https"): + host = parsed.hostname or "" + port = parsed.port or (8443 if parsed.scheme == "https" else 8123) + secure = parsed.scheme == "https" + return host, port, secure + + # Fallback: accept raw hostname (possibly with :port) + # Default to insecure (HTTP) for local development + if ":" in endpoint: + host, port_str = endpoint.rsplit(":", 1) + return host, int(port_str), False + + return endpoint, 8123, False + + +@dataclass(frozen=True) +class LLMTraceRow: + """Raw row from ClickHouse query.""" + + id: str + organization_id: str + project_id: str + agent_id: str + agent_tags: List[str] + run_id: str + step_id: str + trace_id: str + call_type: str + provider: str + model: str + is_byok: bool + request_size_bytes: int + response_size_bytes: int + prompt_tokens: int + completion_tokens: int + total_tokens: int + cached_input_tokens: Optional[int] + cache_write_tokens: Optional[int] + reasoning_tokens: Optional[int] + latency_ms: int + is_error: bool + error_type: str + error_message: str + request_json: str + response_json: str + llm_config_json: str + created_at: datetime + + +@singleton +class LLMTraceReader: + """ + ClickHouse reader for raw LLM traces. + + Provides query methods for debugging, analytics, and auditing. + + Usage: + reader = LLMTraceReader() + trace = await reader.get_by_step_id_async(step_id="step-xxx", organization_id="org-xxx") + traces = await reader.list_by_agent_async(agent_id="agent-xxx", organization_id="org-xxx") + """ + + def __init__(self): + self._client = None + + def _get_client(self): + """Initialize ClickHouse client on first use (lazy loading).""" + if self._client is not None: + return self._client + + import clickhouse_connect + + if not settings.clickhouse_endpoint: + raise ValueError("CLICKHOUSE_ENDPOINT is required") + + host, port, secure = _parse_clickhouse_endpoint(settings.clickhouse_endpoint) + if not host: + raise ValueError("Invalid CLICKHOUSE_ENDPOINT") + + database = settings.clickhouse_database or "otel" + username = settings.clickhouse_username or "default" + password = settings.clickhouse_password + if not password: + raise ValueError("CLICKHOUSE_PASSWORD is required") + + self._client = clickhouse_connect.get_client( + host=host, + port=port, + username=username, + password=password, + database=database, + secure=secure, + verify=True, + ) + return self._client + + def _row_to_trace(self, row: tuple) -> LLMTrace: + """Convert a ClickHouse row tuple to LLMTrace.""" + return LLMTrace( + id=row[0], + organization_id=row[1], + project_id=row[2] or None, + agent_id=row[3] or None, + agent_tags=list(row[4]) if row[4] else [], + run_id=row[5] or None, + step_id=row[6] or None, + trace_id=row[7] or None, + call_type=row[8], + provider=row[9], + model=row[10], + is_byok=bool(row[11]), + request_size_bytes=row[12], + response_size_bytes=row[13], + prompt_tokens=row[14], + completion_tokens=row[15], + total_tokens=row[16], + cached_input_tokens=row[17], + cache_write_tokens=row[18], + reasoning_tokens=row[19], + latency_ms=row[20], + is_error=bool(row[21]), + error_type=row[22] or None, + error_message=row[23] or None, + request_json=row[24], + response_json=row[25], + llm_config_json=row[26] or "", + created_at=row[27], + ) + + def _query_sync(self, query: str, parameters: dict[str, Any]) -> List[tuple]: + """Execute a query synchronously.""" + client = self._get_client() + result = client.query(query, parameters=parameters) + return result.result_rows if result else [] + + # ------------------------------------------------------------------------- + # Query Methods + # ------------------------------------------------------------------------- + + async def get_by_step_id_async( + self, + step_id: str, + organization_id: str, + ) -> Optional[LLMTrace]: + """ + Get the most recent trace for a step. + + Args: + step_id: The step ID to look up + organization_id: Organization ID for access control + + Returns: + LLMTrace if found, None otherwise + """ + query = """ + SELECT + id, organization_id, project_id, agent_id, agent_tags, run_id, step_id, trace_id, + call_type, provider, model, is_byok, + request_size_bytes, response_size_bytes, + prompt_tokens, completion_tokens, total_tokens, + cached_input_tokens, cache_write_tokens, reasoning_tokens, + latency_ms, + is_error, error_type, error_message, + request_json, response_json, llm_config_json, + created_at + FROM llm_traces + WHERE step_id = %(step_id)s + AND organization_id = %(organization_id)s + ORDER BY created_at DESC + LIMIT 1 + """ + + rows = await asyncio.to_thread( + self._query_sync, + query, + {"step_id": step_id, "organization_id": organization_id}, + ) + + if not rows: + return None + + return self._row_to_trace(rows[0]) + + async def get_by_id_async( + self, + trace_id: str, + organization_id: str, + ) -> Optional[LLMTrace]: + """ + Get a trace by its ID. + + Args: + trace_id: The trace ID (UUID) + organization_id: Organization ID for access control + + Returns: + LLMTrace if found, None otherwise + """ + query = """ + SELECT + id, organization_id, project_id, agent_id, agent_tags, run_id, step_id, trace_id, + call_type, provider, model, is_byok, + request_size_bytes, response_size_bytes, + prompt_tokens, completion_tokens, total_tokens, + cached_input_tokens, cache_write_tokens, reasoning_tokens, + latency_ms, + is_error, error_type, error_message, + request_json, response_json, llm_config_json, + created_at + FROM llm_traces + WHERE id = %(trace_id)s + AND organization_id = %(organization_id)s + LIMIT 1 + """ + + rows = await asyncio.to_thread( + self._query_sync, + query, + {"trace_id": trace_id, "organization_id": organization_id}, + ) + + if not rows: + return None + + return self._row_to_trace(rows[0]) + + async def list_by_agent_async( + self, + agent_id: str, + organization_id: str, + limit: int = 100, + offset: int = 0, + call_type: Optional[str] = None, + is_error: Optional[bool] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> List[LLMTrace]: + """ + List traces for an agent with optional filters. + + Args: + agent_id: Agent ID to filter by + organization_id: Organization ID for access control + limit: Maximum number of results (default 100) + offset: Offset for pagination + call_type: Filter by call type ('agent_step', 'summarization') + is_error: Filter by error status + start_date: Filter by created_at >= start_date + end_date: Filter by created_at <= end_date + + Returns: + List of LLMTrace objects + """ + conditions = [ + "agent_id = %(agent_id)s", + "organization_id = %(organization_id)s", + ] + params: dict[str, Any] = { + "agent_id": agent_id, + "organization_id": organization_id, + "limit": limit, + "offset": offset, + } + + if call_type: + conditions.append("call_type = %(call_type)s") + params["call_type"] = call_type + + if is_error is not None: + conditions.append("is_error = %(is_error)s") + params["is_error"] = 1 if is_error else 0 + + if start_date: + conditions.append("created_at >= %(start_date)s") + params["start_date"] = start_date + + if end_date: + conditions.append("created_at <= %(end_date)s") + params["end_date"] = end_date + + where_clause = " AND ".join(conditions) + + query = f""" + SELECT + id, organization_id, project_id, agent_id, agent_tags, run_id, step_id, trace_id, + call_type, provider, model, is_byok, + request_size_bytes, response_size_bytes, + prompt_tokens, completion_tokens, total_tokens, + cached_input_tokens, cache_write_tokens, reasoning_tokens, + latency_ms, + is_error, error_type, error_message, + request_json, response_json, llm_config_json, + created_at + FROM llm_traces + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT %(limit)s OFFSET %(offset)s + """ + + rows = await asyncio.to_thread(self._query_sync, query, params) + return [self._row_to_trace(row) for row in rows] + + async def get_usage_stats_async( + self, + organization_id: str, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + group_by: str = "model", # 'model', 'agent_id', 'call_type' + ) -> List[dict[str, Any]]: + """ + Get aggregated usage statistics. + + Args: + organization_id: Organization ID for access control + start_date: Filter by created_at >= start_date + end_date: Filter by created_at <= end_date + group_by: Field to group by ('model', 'agent_id', 'call_type') + + Returns: + List of aggregated stats dicts + """ + valid_group_by = {"model", "agent_id", "call_type", "provider"} + if group_by not in valid_group_by: + raise ValueError(f"group_by must be one of {valid_group_by}") + + conditions = ["organization_id = %(organization_id)s"] + params: dict[str, Any] = {"organization_id": organization_id} + + if start_date: + conditions.append("created_at >= %(start_date)s") + params["start_date"] = start_date + + if end_date: + conditions.append("created_at <= %(end_date)s") + params["end_date"] = end_date + + where_clause = " AND ".join(conditions) + + query = f""" + SELECT + {group_by}, + count() as request_count, + sum(total_tokens) as total_tokens, + sum(prompt_tokens) as prompt_tokens, + sum(completion_tokens) as completion_tokens, + avg(latency_ms) as avg_latency_ms, + sum(request_size_bytes) as total_request_bytes, + sum(response_size_bytes) as total_response_bytes, + countIf(is_error = 1) as error_count + FROM llm_traces + WHERE {where_clause} + GROUP BY {group_by} + ORDER BY total_tokens DESC + """ + + rows = await asyncio.to_thread(self._query_sync, query, params) + + return [ + { + group_by: row[0], + "request_count": row[1], + "total_tokens": row[2], + "prompt_tokens": row[3], + "completion_tokens": row[4], + "avg_latency_ms": row[5], + "total_request_bytes": row[6], + "total_response_bytes": row[7], + "error_count": row[8], + } + for row in rows + ] + + async def find_large_requests_async( + self, + organization_id: str, + min_size_bytes: int = 1_000_000, # 1MB default + limit: int = 100, + ) -> List[LLMTrace]: + """ + Find traces with large request payloads (for debugging). + + Args: + organization_id: Organization ID for access control + min_size_bytes: Minimum request size in bytes (default 1MB) + limit: Maximum number of results + + Returns: + List of LLMTrace objects with large requests + """ + query = """ + SELECT + id, organization_id, project_id, agent_id, agent_tags, run_id, step_id, trace_id, + call_type, provider, model, is_byok, + request_size_bytes, response_size_bytes, + prompt_tokens, completion_tokens, total_tokens, + cached_input_tokens, cache_write_tokens, reasoning_tokens, + latency_ms, + is_error, error_type, error_message, + request_json, response_json, llm_config_json, + created_at + FROM llm_traces + WHERE organization_id = %(organization_id)s + AND request_size_bytes >= %(min_size_bytes)s + ORDER BY request_size_bytes DESC + LIMIT %(limit)s + """ + + rows = await asyncio.to_thread( + self._query_sync, + query, + { + "organization_id": organization_id, + "min_size_bytes": min_size_bytes, + "limit": limit, + }, + ) + + return [self._row_to_trace(row) for row in rows] + + +# Module-level instance for easy access +_reader_instance: Optional[LLMTraceReader] = None + + +def get_llm_trace_reader() -> LLMTraceReader: + """Get the singleton LLMTraceReader instance.""" + global _reader_instance + if _reader_instance is None: + _reader_instance = LLMTraceReader() + return _reader_instance diff --git a/letta/services/llm_trace_writer.py b/letta/services/llm_trace_writer.py new file mode 100644 index 00000000..169100a2 --- /dev/null +++ b/letta/services/llm_trace_writer.py @@ -0,0 +1,205 @@ +"""ClickHouse writer for LLM analytics traces. + +Writes LLM traces to ClickHouse with denormalized columns for cost analytics. +Uses ClickHouse's async_insert feature for server-side batching. +""" + +from __future__ import annotations + +import asyncio +import atexit +from typing import TYPE_CHECKING, Optional +from urllib.parse import urlparse + +from letta.helpers.singleton import singleton +from letta.log import get_logger +from letta.settings import settings + +if TYPE_CHECKING: + from letta.schemas.llm_trace import LLMTrace + +logger = get_logger(__name__) + +# Retry configuration +MAX_RETRIES = 3 +INITIAL_BACKOFF_SECONDS = 1.0 + + +def _parse_clickhouse_endpoint(endpoint: str) -> tuple[str, int, bool]: + """Return (host, port, secure) for clickhouse_connect.get_client. + + Supports: + - http://host:port -> (host, port, False) + - https://host:port -> (host, port, True) + - host:port -> (host, port, False) # Default to insecure for local dev + - host -> (host, 8123, False) # Default HTTP port, insecure + """ + parsed = urlparse(endpoint) + + if parsed.scheme in ("http", "https"): + host = parsed.hostname or "" + port = parsed.port or (8443 if parsed.scheme == "https" else 8123) + secure = parsed.scheme == "https" + return host, port, secure + + # Fallback: accept raw hostname (possibly with :port) + # Default to insecure (HTTP) for local development + if ":" in endpoint: + host, port_str = endpoint.rsplit(":", 1) + return host, int(port_str), False + + return endpoint, 8123, False + + +@singleton +class LLMTraceWriter: + """ + Direct ClickHouse writer for raw LLM traces. + + Uses ClickHouse's async_insert feature for server-side batching. + Each trace is inserted directly and ClickHouse handles batching + for optimal write performance. + + Usage: + writer = LLMTraceWriter() + await writer.write_async(trace) + + Configuration (via settings): + - store_llm_traces: Enable/disable (default: False) + """ + + def __init__(self): + self._client = None + self._shutdown = False + self._write_lock = asyncio.Lock() # Serialize writes - clickhouse_connect isn't thread-safe + + # Check if ClickHouse is configured - if not, writing is disabled + self._enabled = bool(settings.clickhouse_endpoint and settings.clickhouse_password) + + # Register shutdown handler + atexit.register(self._sync_shutdown) + + def _get_client(self): + """Initialize ClickHouse client on first use (lazy loading). + + Configures async_insert with wait_for_async_insert=1 for reliable + server-side batching with acknowledgment. + """ + if self._client is not None: + return self._client + + # Import lazily so OSS users who never enable this don't pay import cost + import clickhouse_connect + + host, port, secure = _parse_clickhouse_endpoint(settings.clickhouse_endpoint) + database = settings.clickhouse_database or "otel" + username = settings.clickhouse_username or "default" + + self._client = clickhouse_connect.get_client( + host=host, + port=port, + username=username, + password=settings.clickhouse_password, + database=database, + secure=secure, + verify=True, + settings={ + # Enable server-side batching + "async_insert": 1, + # Wait for acknowledgment (reliable) + "wait_for_async_insert": 1, + # Flush after 1 second if batch not full + "async_insert_busy_timeout_ms": 1000, + }, + ) + logger.info(f"LLMTraceWriter: Connected to ClickHouse at {host}:{port}/{database} (async_insert enabled)") + return self._client + + async def write_async(self, trace: "LLMTrace") -> None: + """ + Write a trace to ClickHouse (fire-and-forget with retry). + + ClickHouse's async_insert handles batching server-side for optimal + write performance. This method retries on failure with exponential + backoff. + + Args: + trace: The LLMTrace to write + """ + if not self._enabled or self._shutdown: + return + + # Fire-and-forget with create_task to not block the request path + try: + asyncio.create_task(self._write_with_retry(trace)) + except RuntimeError: + # No running event loop (shouldn't happen in normal async context) + pass + + async def _write_with_retry(self, trace: "LLMTrace") -> None: + """Write a single trace with retry on failure.""" + from letta.schemas.llm_trace import LLMTrace + + for attempt in range(MAX_RETRIES): + try: + client = self._get_client() + row = trace.to_clickhouse_row() + columns = LLMTrace.clickhouse_columns() + + # Serialize writes - clickhouse_connect client isn't thread-safe + async with self._write_lock: + # Run synchronous insert in thread pool + await asyncio.to_thread( + client.insert, + "llm_traces", + [row], + column_names=columns, + ) + return # Success + + except Exception as e: + if attempt < MAX_RETRIES - 1: + backoff = INITIAL_BACKOFF_SECONDS * (2**attempt) + logger.warning(f"LLMTraceWriter: Retry {attempt + 1}/{MAX_RETRIES}, backoff {backoff}s: {e}") + await asyncio.sleep(backoff) + else: + logger.error(f"LLMTraceWriter: Dropping trace after {MAX_RETRIES} retries: {e}") + + async def shutdown_async(self) -> None: + """Gracefully shutdown the writer.""" + self._shutdown = True + + # Close client + if self._client: + try: + self._client.close() + except Exception as e: + logger.warning(f"LLMTraceWriter: Error closing client: {e}") + self._client = None + + logger.info("LLMTraceWriter: Shutdown complete") + + def _sync_shutdown(self) -> None: + """Synchronous shutdown handler for atexit.""" + if not self._enabled or self._shutdown: + return + + self._shutdown = True + + if self._client: + try: + self._client.close() + except Exception: + pass + + +# Module-level instance for easy access +_writer_instance: Optional[LLMTraceWriter] = None + + +def get_llm_trace_writer() -> LLMTraceWriter: + """Get the singleton LLMTraceWriter instance.""" + global _writer_instance + if _writer_instance is None: + _writer_instance = LLMTraceWriter() + return _writer_instance diff --git a/letta/services/provider_trace_backends/clickhouse.py b/letta/services/provider_trace_backends/clickhouse.py index 1c5731f7..6b3c286b 100644 --- a/letta/services/provider_trace_backends/clickhouse.py +++ b/letta/services/provider_trace_backends/clickhouse.py @@ -1,17 +1,32 @@ -"""ClickHouse provider trace backend.""" +"""ClickHouse provider trace backend. +Writes traces to the llm_traces table with denormalized columns for cost analytics. +Reads from the OTEL traces table (will eventually cut over to llm_traces). +""" + +import json +import uuid +from typing import TYPE_CHECKING, Optional + +from letta.log import get_logger from letta.schemas.provider_trace import ProviderTrace from letta.schemas.user import User from letta.services.clickhouse_provider_traces import ClickhouseProviderTraceReader from letta.services.provider_trace_backends.base import ProviderTraceBackendClient +from letta.settings import settings + +if TYPE_CHECKING: + from letta.schemas.llm_trace import LLMTrace + +logger = get_logger(__name__) class ClickhouseProviderTraceBackend(ProviderTraceBackendClient): """ - Store provider traces in ClickHouse. + ClickHouse backend for provider traces. - Writes flow through OTEL instrumentation, so create_async is a no-op. - Only reads are performed directly against ClickHouse. + - Writes go to llm_traces table (denormalized for cost analytics) + - Reads come from OTEL traces table (will cut over to llm_traces later) """ def __init__(self): @@ -21,9 +36,29 @@ class ClickhouseProviderTraceBackend(ProviderTraceBackendClient): self, actor: User, provider_trace: ProviderTrace, - ) -> ProviderTrace: - # ClickHouse writes flow through OTEL instrumentation, not direct writes. - # Return a ProviderTrace with the same ID for consistency across backends. + ) -> ProviderTrace | None: + """Write provider trace to ClickHouse llm_traces table.""" + if not settings.store_llm_traces: + # Return minimal trace for consistency if writes disabled + return ProviderTrace( + id=provider_trace.id, + step_id=provider_trace.step_id, + request_json=provider_trace.request_json or {}, + response_json=provider_trace.response_json or {}, + ) + + try: + from letta.schemas.llm_trace import LLMTrace + from letta.services.llm_trace_writer import get_llm_trace_writer + + trace = self._convert_to_trace(actor, provider_trace) + if trace: + writer = get_llm_trace_writer() + await writer.write_async(trace) + + except Exception as e: + logger.debug(f"Failed to write trace to ClickHouse: {e}") + return ProviderTrace( id=provider_trace.id, step_id=provider_trace.step_id, @@ -36,7 +71,125 @@ class ClickhouseProviderTraceBackend(ProviderTraceBackendClient): step_id: str, actor: User, ) -> ProviderTrace | None: + """Read from OTEL traces table (will cut over to llm_traces later).""" return await self._reader.get_provider_trace_by_step_id_async( step_id=step_id, organization_id=actor.organization_id, ) + + def _convert_to_trace( + self, + actor: User, + provider_trace: ProviderTrace, + ) -> Optional["LLMTrace"]: + """Convert ProviderTrace to LLMTrace for analytics storage.""" + from letta.schemas.llm_trace import LLMTrace + + # Serialize JSON fields + request_json_str = json.dumps(provider_trace.request_json, default=str) + response_json_str = json.dumps(provider_trace.response_json, default=str) + llm_config_json_str = json.dumps(provider_trace.llm_config, default=str) if provider_trace.llm_config else "{}" + + # Extract provider and model from llm_config + llm_config = provider_trace.llm_config or {} + provider = llm_config.get("model_endpoint_type", "unknown") + model = llm_config.get("model", "unknown") + is_byok = llm_config.get("provider_category") == "byok" + + # Extract usage from response (generic parsing for common formats) + usage = self._extract_usage(provider_trace.response_json, provider) + + # Check for error in response + is_error = "error" in provider_trace.response_json + error_type = None + error_message = None + if is_error: + # error_type may be at top level or inside error dict + error_type = provider_trace.response_json.get("error_type") + error_data = provider_trace.response_json.get("error", {}) + if isinstance(error_data, dict): + error_type = error_type or error_data.get("type") + error_message = error_data.get("message", str(error_data))[:1000] + else: + error_message = str(error_data)[:1000] + + # Extract UUID from provider_trace.id (strip "provider_trace-" prefix) + trace_id = provider_trace.id + if not trace_id: + logger.warning("ProviderTrace missing id - trace correlation across backends will fail") + trace_id = str(uuid.uuid4()) + elif trace_id.startswith("provider_trace-"): + trace_id = trace_id[len("provider_trace-") :] + + return LLMTrace( + id=trace_id, + organization_id=provider_trace.org_id or actor.organization_id, + project_id=None, + agent_id=provider_trace.agent_id, + agent_tags=provider_trace.agent_tags or [], + run_id=provider_trace.run_id, + step_id=provider_trace.step_id, + trace_id=None, + call_type=provider_trace.call_type or "unknown", + provider=provider, + model=model, + is_byok=is_byok, + request_size_bytes=len(request_json_str.encode("utf-8")), + response_size_bytes=len(response_json_str.encode("utf-8")), + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + cached_input_tokens=usage.get("cached_input_tokens"), + cache_write_tokens=usage.get("cache_write_tokens"), + reasoning_tokens=usage.get("reasoning_tokens"), + latency_ms=0, # Not available in ProviderTrace + is_error=is_error, + error_type=error_type, + error_message=error_message, + request_json=request_json_str, + response_json=response_json_str, + llm_config_json=llm_config_json_str, + ) + + def _extract_usage(self, response_json: dict, provider: str) -> dict: + """Extract usage statistics from response JSON. + + Handles common formats from OpenAI, Anthropic, and other providers. + """ + usage = {} + + # OpenAI format: response.usage + if "usage" in response_json: + u = response_json["usage"] + usage["prompt_tokens"] = u.get("prompt_tokens", 0) + usage["completion_tokens"] = u.get("completion_tokens", 0) + usage["total_tokens"] = u.get("total_tokens", 0) + + # OpenAI reasoning tokens + if "completion_tokens_details" in u: + details = u["completion_tokens_details"] + usage["reasoning_tokens"] = details.get("reasoning_tokens") + + # OpenAI cached tokens + if "prompt_tokens_details" in u: + details = u["prompt_tokens_details"] + usage["cached_input_tokens"] = details.get("cached_tokens") + + # Anthropic format: response.usage with cache fields + if provider == "anthropic" and "usage" in response_json: + u = response_json["usage"] + # input_tokens can be 0 when all tokens come from cache + input_tokens = u.get("input_tokens", 0) + cache_read = u.get("cache_read_input_tokens", 0) + cache_write = u.get("cache_creation_input_tokens", 0) + # Total prompt = input + cached (for cost analytics) + usage["prompt_tokens"] = input_tokens + cache_read + cache_write + usage["completion_tokens"] = u.get("output_tokens", usage.get("completion_tokens", 0)) + usage["cached_input_tokens"] = cache_read if cache_read else None + usage["cache_write_tokens"] = cache_write if cache_write else None + + # Recalculate total if not present + if "total_tokens" not in usage or usage["total_tokens"] == 0: + usage["total_tokens"] = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) + + return usage diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 64e9f8ba..9ff685bb 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -561,6 +561,7 @@ async def simple_summary( "output_tokens": getattr(interface, "output_tokens", None), }, }, + llm_config=summarizer_llm_config, ) if not text: diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py index deddb20b..94d6ef40 100644 --- a/letta/services/telemetry_manager.py +++ b/letta/services/telemetry_manager.py @@ -20,10 +20,10 @@ class TelemetryManager: Supports multiple backends for dual-write scenarios (e.g., migration). Configure via LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND (comma-separated): - postgres: Store in PostgreSQL (default) - - clickhouse: Store in ClickHouse via OTEL instrumentation - - socket: Store via Unix socket to Crouton sidecar (which writes to GCS) + - clickhouse: Store in ClickHouse (writes to llm_traces table, reads from OTEL traces) + - socket: Store via Unix socket to external sidecar - Example: LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND=postgres,socket + Example: LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND=postgres,clickhouse Multi-backend behavior: - Writes: Sent to ALL configured backends concurrently via asyncio.gather. diff --git a/letta/settings.py b/letta/settings.py index f52fca3c..5f8d455f 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -331,6 +331,13 @@ class Settings(BaseSettings): track_agent_run: bool = Field(default=True, description="Enable tracking agent run with cancellation support") track_provider_trace: bool = Field(default=True, description="Enable tracking raw llm request and response at each step") + # LLM trace storage for analytics (direct ClickHouse, bypasses OTEL for large payloads) + # TTL is configured in the ClickHouse DDL (default 90 days) + store_llm_traces: bool = Field( + default=False, + description="Enable storing LLM traces in ClickHouse for cost analytics", + ) + # FastAPI Application Settings uvicorn_workers: int = 1 uvicorn_reload: bool = False diff --git a/tests/integration_test_clickhouse_llm_traces.py b/tests/integration_test_clickhouse_llm_traces.py new file mode 100644 index 00000000..f928c098 --- /dev/null +++ b/tests/integration_test_clickhouse_llm_traces.py @@ -0,0 +1,350 @@ +""" +Integration tests for ClickHouse-backed LLM raw traces. + +Validates that: +1) Agent message requests are stored in ClickHouse (request_json contains the message) +2) Summarization traces are stored and retrievable by step_id +3) Error traces are stored with is_error, error_type, and error_message +4) llm_config_json is properly stored +5) Cache and usage statistics are stored (cached_input_tokens, cache_write_tokens, reasoning_tokens) +""" + +import asyncio +import json +import os +import time +import uuid + +import pytest + +from letta.agents.letta_agent_v3 import LettaAgentV3 +from letta.config import LettaConfig +from letta.schemas.agent import CreateAgent +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message_content import TextContent +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message, MessageCreate +from letta.schemas.run import Run +from letta.server.server import SyncServer +from letta.services.llm_trace_reader import get_llm_trace_reader +from letta.services.provider_trace_backends import get_provider_trace_backends +from letta.services.summarizer.summarizer import simple_summary +from letta.settings import settings, telemetry_settings + + +def _require_clickhouse_env() -> dict[str, str]: + endpoint = os.getenv("CLICKHOUSE_ENDPOINT") + password = os.getenv("CLICKHOUSE_PASSWORD") + if not endpoint or not password: + pytest.skip("ClickHouse env vars not set (CLICKHOUSE_ENDPOINT, CLICKHOUSE_PASSWORD)") + return { + "endpoint": endpoint, + "password": password, + "username": os.getenv("CLICKHOUSE_USERNAME", "default"), + "database": os.getenv("CLICKHOUSE_DATABASE", "otel"), + } + + +def _anthropic_llm_config() -> LLMConfig: + return LLMConfig( + model="claude-3-5-haiku-20241022", + model_endpoint_type="anthropic", + model_endpoint="https://api.anthropic.com/v1", + context_window=200000, + max_tokens=2048, + put_inner_thoughts_in_kwargs=False, + enable_reasoner=False, + ) + + +@pytest.fixture +async def server(): + config = LettaConfig.load() + config.save() + server = SyncServer(init_with_default_org_and_user=True) + await server.init_async() + await server.tool_manager.upsert_base_tools_async(actor=server.default_user) + yield server + + +@pytest.fixture +async def actor(server: SyncServer): + return server.default_user + + +@pytest.fixture +def clickhouse_settings(): + env = _require_clickhouse_env() + + original_values = { + "endpoint": settings.clickhouse_endpoint, + "username": settings.clickhouse_username, + "password": settings.clickhouse_password, + "database": settings.clickhouse_database, + "store_llm_traces": settings.store_llm_traces, + "provider_trace_backend": telemetry_settings.provider_trace_backend, + } + + settings.clickhouse_endpoint = env["endpoint"] + settings.clickhouse_username = env["username"] + settings.clickhouse_password = env["password"] + settings.clickhouse_database = env["database"] + settings.store_llm_traces = True + + # Configure telemetry to use clickhouse backend (set the underlying field, not the property) + telemetry_settings.provider_trace_backend = "clickhouse" + # Clear the cached backends so they get recreated with new settings + get_provider_trace_backends.cache_clear() + + yield + + settings.clickhouse_endpoint = original_values["endpoint"] + settings.clickhouse_username = original_values["username"] + settings.clickhouse_password = original_values["password"] + settings.clickhouse_database = original_values["database"] + settings.store_llm_traces = original_values["store_llm_traces"] + telemetry_settings.provider_trace_backend = original_values["provider_trace_backend"] + # Clear cache again to restore original backends + get_provider_trace_backends.cache_clear() + + +async def _wait_for_raw_trace(step_id: str, organization_id: str, timeout_seconds: int = 30): + """Wait for a trace to appear in ClickHouse. + + With async_insert + wait_for_async_insert=1, traces should appear quickly, + but we poll to handle any propagation delay. + """ + reader = get_llm_trace_reader() + deadline = time.time() + timeout_seconds + + while time.time() < deadline: + trace = await reader.get_by_step_id_async(step_id=step_id, organization_id=organization_id) + if trace is not None: + return trace + await asyncio.sleep(0.5) + + raise AssertionError(f"Timed out waiting for raw trace with step_id={step_id}") + + +@pytest.mark.asyncio +async def test_agent_message_stored_in_clickhouse(server: SyncServer, actor, clickhouse_settings): + """Test that agent step traces are stored with all fields including llm_config_json.""" + message_text = f"ClickHouse trace test {uuid.uuid4()}" + llm_config = _anthropic_llm_config() + + agent_state = await server.agent_manager.create_agent_async( + CreateAgent( + name=f"clickhouse_agent_{uuid.uuid4().hex[:8]}", + llm_config=llm_config, + embedding_config=EmbeddingConfig.default_config(model_name="letta"), + ), + actor=actor, + ) + + agent = LettaAgentV3(agent_state=agent_state, actor=actor) + run = await server.run_manager.create_run( + Run(agent_id=agent_state.id), + actor=actor, + ) + run_id = run.id + response = await agent.step( + [MessageCreate(role=MessageRole.user, content=[TextContent(text=message_text)])], + run_id=run_id, + ) + + step_id = next(msg.step_id for msg in reversed(response.messages) if msg.step_id) + trace = await _wait_for_raw_trace(step_id=step_id, organization_id=actor.organization_id) + + # Basic trace fields + assert trace.step_id == step_id + assert message_text in trace.request_json + assert trace.is_error is False + assert trace.error_type is None + assert trace.error_message is None + + # Verify llm_config_json is stored and contains expected fields + assert trace.llm_config_json, "llm_config_json should not be empty" + config_data = json.loads(trace.llm_config_json) + assert config_data.get("model") == llm_config.model + assert "context_window" in config_data + assert "max_tokens" in config_data + + # Token usage should be populated + assert trace.prompt_tokens > 0 + assert trace.completion_tokens >= 0 + assert trace.total_tokens > 0 + + +@pytest.mark.asyncio +async def test_summary_stored_with_content_and_usage(server: SyncServer, actor, clickhouse_settings): + """Test that summarization traces are stored with content, usage, and cache info.""" + step_id = f"step-{uuid.uuid4()}" + llm_config = _anthropic_llm_config() + summary_source_messages = [ + Message(role=MessageRole.system, content=[TextContent(text="System prompt")]), + Message(role=MessageRole.user, content=[TextContent(text="User message 1")]), + Message(role=MessageRole.assistant, content=[TextContent(text="Assistant response 1")]), + Message(role=MessageRole.user, content=[TextContent(text="User message 2")]), + ] + + summary_text = await simple_summary( + messages=summary_source_messages, + llm_config=llm_config, + actor=actor, + agent_id=f"agent-{uuid.uuid4()}", + agent_tags=["test", "clickhouse"], + run_id=f"run-{uuid.uuid4()}", + step_id=step_id, + compaction_settings={"mode": "partial_evict", "message_buffer_limit": 60}, + ) + + trace = await _wait_for_raw_trace(step_id=step_id, organization_id=actor.organization_id) + + # Basic assertions + assert trace.step_id == step_id + assert trace.call_type == "summarization" + assert trace.is_error is False + + # Verify llm_config_json is stored + assert trace.llm_config_json, "llm_config_json should not be empty" + config_data = json.loads(trace.llm_config_json) + assert config_data.get("model") == llm_config.model + + # Verify summary content in response + summary_in_response = False + try: + response_payload = json.loads(trace.response_json) + if isinstance(response_payload, dict): + if "choices" in response_payload: + content = response_payload.get("choices", [{}])[0].get("message", {}).get("content", "") + summary_in_response = summary_text.strip() in (content or "") + elif "content" in response_payload: + summary_in_response = summary_text.strip() in (response_payload.get("content") or "") + except Exception: + summary_in_response = False + + assert summary_in_response or summary_text in trace.response_json + + # Token usage should be populated + assert trace.prompt_tokens > 0 + assert trace.total_tokens > 0 + + # Cache fields may or may not be populated depending on provider response + # Just verify they're accessible (not erroring) + _ = trace.cached_input_tokens + _ = trace.cache_write_tokens + _ = trace.reasoning_tokens + + +@pytest.mark.asyncio +async def test_error_trace_stored_in_clickhouse(server: SyncServer, actor, clickhouse_settings): + """Test that error traces are stored with is_error=True and error details.""" + from letta.llm_api.anthropic_client import AnthropicClient + + step_id = f"step-error-{uuid.uuid4()}" + + # Create a client with invalid config to trigger an error + invalid_llm_config = LLMConfig( + model="invalid-model-that-does-not-exist", + model_endpoint_type="anthropic", + model_endpoint="https://api.anthropic.com/v1", + context_window=200000, + max_tokens=2048, + ) + + from letta.services.telemetry_manager import TelemetryManager + + client = AnthropicClient() + client.set_telemetry_context( + telemetry_manager=TelemetryManager(), + agent_id=f"agent-{uuid.uuid4()}", + run_id=f"run-{uuid.uuid4()}", + step_id=step_id, + call_type="agent_step", + org_id=actor.organization_id, + ) + client.actor = actor + + # Make a request that will fail + request_data = { + "model": invalid_llm_config.model, + "messages": [{"role": "user", "content": "test"}], + "max_tokens": 100, + } + + try: + await client.request_async_with_telemetry(request_data, invalid_llm_config) + except Exception: + pass # Expected to fail + + # Wait for the error trace to be written + trace = await _wait_for_raw_trace(step_id=step_id, organization_id=actor.organization_id) + + # Verify error fields + assert trace.step_id == step_id + assert trace.is_error is True + assert trace.error_type is not None, "error_type should be set for error traces" + assert trace.error_message is not None, "error_message should be set for error traces" + + # Verify llm_config_json is still stored even for errors + assert trace.llm_config_json, "llm_config_json should be stored even for error traces" + config_data = json.loads(trace.llm_config_json) + assert config_data.get("model") == invalid_llm_config.model + + +@pytest.mark.asyncio +async def test_cache_tokens_stored_for_anthropic(server: SyncServer, actor, clickhouse_settings): + """Test that Anthropic cache tokens (cached_input_tokens, cache_write_tokens) are stored. + + Note: This test verifies the fields are properly stored when present in the response. + Actual cache token values depend on Anthropic's prompt caching behavior. + """ + message_text = f"Cache test {uuid.uuid4()}" + llm_config = _anthropic_llm_config() + + agent_state = await server.agent_manager.create_agent_async( + CreateAgent( + name=f"cache_test_agent_{uuid.uuid4().hex[:8]}", + llm_config=llm_config, + embedding_config=EmbeddingConfig.default_config(model_name="letta"), + ), + actor=actor, + ) + + agent = LettaAgentV3(agent_state=agent_state, actor=actor) + run = await server.run_manager.create_run( + Run(agent_id=agent_state.id), + actor=actor, + ) + + # Make two requests - second may benefit from caching + response1 = await agent.step( + [MessageCreate(role=MessageRole.user, content=[TextContent(text=message_text)])], + run_id=run.id, + ) + step_id_1 = next(msg.step_id for msg in reversed(response1.messages) if msg.step_id) + + response2 = await agent.step( + [MessageCreate(role=MessageRole.user, content=[TextContent(text="Follow up question")])], + run_id=run.id, + ) + step_id_2 = next(msg.step_id for msg in reversed(response2.messages) if msg.step_id) + + # Check traces for both requests + trace1 = await _wait_for_raw_trace(step_id=step_id_1, organization_id=actor.organization_id) + trace2 = await _wait_for_raw_trace(step_id=step_id_2, organization_id=actor.organization_id) + + # Verify cache fields are accessible (may be None if no caching occurred) + # The important thing is they're stored correctly when present + for trace in [trace1, trace2]: + assert trace.prompt_tokens > 0 + # Cache fields should be stored (may be None or int) + assert trace.cached_input_tokens is None or isinstance(trace.cached_input_tokens, int) + assert trace.cache_write_tokens is None or isinstance(trace.cache_write_tokens, int) + assert trace.reasoning_tokens is None or isinstance(trace.reasoning_tokens, int) + + # Verify llm_config_json + assert trace.llm_config_json + config_data = json.loads(trace.llm_config_json) + assert config_data.get("model") == llm_config.model