Files
letta-server/tests/test_provider_trace_backends.py
Kian Jones 9418ab9815 feat: add provider trace backend abstraction for multi-backend telemetry (#8814)
* feat: add provider trace backend abstraction for multi-backend telemetry

Introduces a pluggable backend system for provider traces:
- Base class with async/sync create and read interfaces
- PostgreSQL backend (existing behavior)
- ClickHouse backend (via OTEL instrumentation)
- Socket backend (writes to Unix socket for crouton sidecar)
- Factory for instantiating backends from config

Refactors TelemetryManager to use backends with support for:
- Multi-backend writes (concurrent via asyncio.gather)
- Primary backend for reads (first in config list)
- Graceful error handling per backend

Config: LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND (comma-separated)
Example: "postgres,socket" for dual-write to Postgres and crouton

🐙 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* feat: add protocol version to socket backend records

Adds PROTOCOL_VERSION constant to socket backend:
- Included in every telemetry record sent to crouton
- Must match ProtocolVersion in apps/crouton/main.go
- Enables crouton to detect and reject incompatible messages

🐙 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix: remove organization_id from ProviderTraceCreate calls

The organization_id is now handled via the actor parameter in the
telemetry manager, not through ProviderTraceCreate schema. This fixes
validation errors after changing ProviderTraceCreate to inherit from
BaseProviderTrace which forbids extra fields.

🐙 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* consolidate provider trace

* add clickhouse-connect to fix bug on main lmao

* auto-generated SDK changes, deployment details, ClickHouse prefix bug fix, and added fields to the runs trace return API

* auto-generated SDK changes, deployment details, ClickHouse prefix bug fix, and added fields to the runs trace return API

* consolidate provider trace

* consolidate provider trace bug fix

---------

Co-authored-by: Letta <noreply@letta.com>
2026-01-19 15:54:43 -08:00

333 lines
12 KiB
Python

"""Unit tests for provider trace backends."""
import asyncio
import json
import os
import socket
import tempfile
import threading
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.user import User
from letta.services.provider_trace_backends.base import ProviderTraceBackend, ProviderTraceBackendClient
from letta.services.provider_trace_backends.socket import SocketProviderTraceBackend
@pytest.fixture
def mock_actor():
    """Provide a ``User`` with fixed user and organization IDs for use as the actor."""
    actor_fields = {
        "id": "user-00000000-0000-4000-8000-000000000000",
        "organization_id": "org-00000000-0000-4000-8000-000000000000",
        "name": "test_user",
    }
    return User(**actor_fields)
@pytest.fixture
def sample_provider_trace():
    """Provide a ``ProviderTrace`` holding one chat request/response pair with token usage."""
    model_name = "gpt-4o-mini"
    request_payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": "Hello"}],
    }
    response_payload = {
        "id": "chatcmpl-xyz",
        "model": model_name,
        "choices": [{"message": {"content": "Hi!"}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
    }
    return ProviderTrace(
        request_json=request_payload,
        response_json=response_payload,
        step_id="step-test-789",
        run_id="run-test-abc",
    )
class TestProviderTraceBackendEnum:
    """Checks on the members of the ``ProviderTraceBackend`` enum."""

    def test_enum_values(self):
        """Each member's ``.value`` is the expected lowercase backend name."""
        expected_values = (
            (ProviderTraceBackend.POSTGRES, "postgres"),
            (ProviderTraceBackend.CLICKHOUSE, "clickhouse"),
            (ProviderTraceBackend.SOCKET, "socket"),
        )
        for member, raw_value in expected_values:
            assert member.value == raw_value

    def test_enum_string_comparison(self):
        """Members compare equal to their plain-string names."""
        string_pairs = (
            (ProviderTraceBackend.POSTGRES, "postgres"),
            (ProviderTraceBackend.SOCKET, "socket"),
        )
        for member, raw_value in string_pairs:
            assert member == raw_value
class TestProviderTrace:
    """Checks on the ``ProviderTrace`` schema."""

    def test_id_generation(self):
        """A trace built without an explicit ID gets one with the expected prefix."""
        expected_prefix = "provider_trace-"
        generated = ProviderTrace(
            request_json={"model": "test"},
            response_json={"id": "test"},
            step_id="step-123",
        )
        assert generated.id.startswith(expected_prefix)

    def test_id_uniqueness(self):
        """Two separately built traces do not share an ID."""
        first, second = (
            ProviderTrace(request_json={}, response_json={}, step_id=step)
            for step in ("step-1", "step-2")
        )
        assert first.id != second.id

    def test_optional_fields(self):
        """Optional telemetry fields passed to the constructor are readable back."""
        optional_fields = {
            "agent_id": "agent-456",
            "agent_tags": ["env:dev", "team:ml"],
            "call_type": "summarization",
            "run_id": "run-789",
        }
        trace = ProviderTrace(
            request_json={},
            response_json={},
            step_id="step-123",
            **optional_fields,
        )
        for field_name, expected in optional_fields.items():
            assert getattr(trace, field_name) == expected
class TestSocketProviderTraceBackend:
    """Tests for SocketProviderTraceBackend."""

    @staticmethod
    def _capture_record(trace):
        """Return the single record ``_send_to_crouton`` hands to ``_send_async``.

        ``_send_async`` is patched out, so no real socket is touched.

        Args:
            trace: The ``ProviderTrace`` to feed into ``_send_to_crouton``.

        Returns:
            The record passed to the patched ``_send_async``.
        """
        backend = SocketProviderTraceBackend(socket_path="/fake/path")
        captured_records = []
        delivered = threading.Event()

        def capture_record(record):
            captured_records.append(record)
            delivered.set()

        with patch.object(backend, "_send_async", side_effect=capture_record):
            backend._send_to_crouton(trace)
            # Waiting on the event keeps the capture deterministic whether the
            # record is handed over inline or from another thread.
            delivered.wait(timeout=5.0)

        assert len(captured_records) == 1
        return captured_records[0]

    def test_init_default_path(self):
        """Test default socket path."""
        backend = SocketProviderTraceBackend()
        assert backend.socket_path == "/var/run/telemetry/telemetry.sock"

    def test_init_custom_path(self):
        """Test custom socket path."""
        backend = SocketProviderTraceBackend(socket_path="/tmp/custom.sock")
        assert backend.socket_path == "/tmp/custom.sock"

    @pytest.mark.asyncio
    async def test_create_async_returns_provider_trace(self, mock_actor, sample_provider_trace):
        """Test that create_async returns a ProviderTrace matching the input's id and step_id."""
        backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
        result = await backend.create_async(
            actor=mock_actor,
            provider_trace=sample_provider_trace,
        )
        assert isinstance(result, ProviderTrace)
        assert result.id == sample_provider_trace.id
        assert result.step_id == sample_provider_trace.step_id

    @pytest.mark.asyncio
    async def test_get_by_step_id_returns_none(self, mock_actor):
        """Test that read operations return None (write-only backend)."""
        backend = SocketProviderTraceBackend()
        result = await backend.get_by_step_id_async(
            step_id="step-123",
            actor=mock_actor,
        )
        assert result is None

    def test_send_to_socket_with_real_socket(self, sample_provider_trace):
        """Test sending data to a real Unix socket."""
        received_data = []
        with tempfile.TemporaryDirectory() as tmpdir:
            socket_path = os.path.join(tmpdir, "test.sock")

            # Create a simple socket server
            server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server_sock.bind(socket_path)
            server_sock.listen(1)
            server_sock.settimeout(5.0)

            def accept_connection():
                try:
                    conn, _ = server_sock.accept()
                    # ``with`` closes the connection even if reading or decoding fails.
                    with conn:
                        # Bound each read so a peer that keeps the connection
                        # open cannot stall the test.
                        conn.settimeout(1.0)
                        chunks = []
                        try:
                            # A stream socket may deliver the payload in several
                            # pieces, so read until the peer closes.
                            while True:
                                chunk = conn.recv(65536)
                                if not chunk:
                                    break
                                chunks.append(chunk)
                        except socket.timeout:
                            pass  # Keep whatever arrived before the peer went quiet
                        received_data.append(b"".join(chunks).decode())
                except socket.timeout:
                    pass  # No client connected in time; the assertions below will fail
                finally:
                    server_sock.close()

            # Start server in background
            server_thread = threading.Thread(target=accept_connection)
            server_thread.start()

            # Send data via backend
            backend = SocketProviderTraceBackend(socket_path=socket_path)
            backend._send_to_crouton(sample_provider_trace)

            # Wait for send to complete
            server_thread.join(timeout=5.0)

        # Verify data was received
        assert len(received_data) == 1
        record = json.loads(received_data[0].strip())
        assert record["provider_trace_id"] == sample_provider_trace.id
        assert record["model"] == "gpt-4o-mini"
        assert record["provider"] == "openai"
        assert record["input_tokens"] == 10
        assert record["output_tokens"] == 5
        assert record["context"]["step_id"] == "step-test-789"
        assert record["context"]["run_id"] == "run-test-abc"

    def test_send_to_nonexistent_socket_does_not_raise(self, sample_provider_trace):
        """Test that sending to nonexistent socket fails silently."""
        backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
        # Should not raise
        backend._send_to_crouton(sample_provider_trace)

    def test_record_extracts_usage_from_openai_response(self):
        """Test usage extraction from OpenAI-style response."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={
                "usage": {
                    "prompt_tokens": 100,
                    "completion_tokens": 50,
                }
            },
            step_id="step-123",
        )
        record = self._capture_record(trace)
        assert record["input_tokens"] == 100
        assert record["output_tokens"] == 50

    def test_record_extracts_usage_from_anthropic_response(self):
        """Test usage extraction from Anthropic-style response."""
        trace = ProviderTrace(
            request_json={"model": "claude-3"},
            response_json={
                "usage": {
                    "input_tokens": 100,
                    "output_tokens": 50,
                }
            },
            step_id="step-123",
        )
        record = self._capture_record(trace)
        assert record["input_tokens"] == 100
        assert record["output_tokens"] == 50

    def test_record_extracts_error_from_response(self):
        """Test error extraction from response."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={
                "error": {"message": "Rate limit exceeded"},
            },
            step_id="step-123",
        )
        record = self._capture_record(trace)
        assert record["error"] == "Rate limit exceeded"
        assert record["response"] is None
class TestBackendFactory:
    """Tests for backend factory."""

    def test_get_postgres_backend(self):
        """Test getting postgres backend."""
        from letta.services.provider_trace_backends.factory import _create_backend

        backend = _create_backend("postgres")
        assert backend.__class__.__name__ == "PostgresProviderTraceBackend"

    def test_get_socket_backend(self):
        """Test getting socket backend."""
        # Swap the settings object on ``letta.settings`` for a mock that
        # carries a socket path for the duration of the test.
        with patch("letta.settings.telemetry_settings") as mock_settings:
            mock_settings.socket_path = "/tmp/test.sock"
            from letta.services.provider_trace_backends.factory import _create_backend

            backend = _create_backend("socket")
            assert backend.__class__.__name__ == "SocketProviderTraceBackend"

    def test_get_multiple_backends(self):
        """Test getting multiple backends via environment."""
        from letta.services.provider_trace_backends.factory import (
            get_provider_trace_backends,
        )

        # Clear cache first
        get_provider_trace_backends.cache_clear()
        try:
            # This test just verifies the factory works - actual backend list
            # depends on env var LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND
            backends = get_provider_trace_backends()
            assert len(backends) >= 1
            assert all(hasattr(b, "create_async") and hasattr(b, "get_by_step_id_async") for b in backends)
        finally:
            # Do not leave the backends built here cached for later tests.
            get_provider_trace_backends.cache_clear()

    def test_unknown_backend_defaults_to_postgres(self):
        """Test that unknown backend type defaults to postgres."""
        from letta.services.provider_trace_backends.factory import _create_backend

        backend = _create_backend("unknown_backend")
        assert backend.__class__.__name__ == "PostgresProviderTraceBackend"
class TestTelemetrySettings:
    """Checks on how ``TelemetrySettings`` exposes the configured trace backends."""

    def test_provider_trace_backends_parsing(self):
        """A comma-separated value is exposed as a list of backend names."""
        from letta.settings import TelemetrySettings

        # Build a fresh settings object with the value passed in directly
        parsed = TelemetrySettings(provider_trace_backend="postgres,socket").provider_trace_backends
        assert parsed == ["postgres", "socket"]

    def test_provider_trace_backends_single(self):
        """A single backend name is exposed as a one-element list."""
        from letta.settings import TelemetrySettings

        parsed = TelemetrySettings(provider_trace_backend="postgres").provider_trace_backends
        assert parsed == ["postgres"]

    def test_provider_trace_backends_with_whitespace(self):
        """Whitespace around the commas does not end up in the backend names."""
        from letta.settings import TelemetrySettings

        raw_value = "postgres , socket , clickhouse"
        parsed = TelemetrySettings(provider_trace_backend=raw_value).provider_trace_backends
        assert parsed == ["postgres", "socket", "clickhouse"]

    def test_socket_backend_enabled(self):
        """``socket_backend_enabled`` is True only when "socket" is among the backends."""
        from letta.settings import TelemetrySettings

        cases = (
            ("postgres", False),
            ("postgres,socket", True),
        )
        for raw_value, expected_flag in cases:
            configured = TelemetrySettings(provider_trace_backend=raw_value)
            assert configured.socket_backend_enabled is expected_flag