Files
letta-server/tests/test_provider_trace_backends.py
Kian Jones 00b36bc591 fix: resolve crouton telemetry failures (#9269)
Two issues were causing telemetry failures:
1. Startup race - memgpt-server sent telemetry before crouton had created its socket
2. Oversized payloads - large context windows (1M+ tokens) exceeded crouton's buffer

Changes:
- Increase crouton buffer to 128MB max with lazy allocation (64KB initial)
- Bump crouton resources (512Mi limit, 128Mi request)
- Add retry with exponential backoff in socket backend
- Move crouton to initContainers with restartPolicy: Always for deterministic startup

🐙 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
2026-02-24 10:52:06 -08:00

404 lines
15 KiB
Python

"""Unit tests for provider trace backends."""
import asyncio
import json
import os
import socket
import tempfile
import threading
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.user import User
from letta.services.provider_trace_backends.base import ProviderTraceBackend, ProviderTraceBackendClient
from letta.services.provider_trace_backends.socket import SocketProviderTraceBackend
@pytest.fixture
def mock_actor():
    """Provide a ``User`` that acts as the caller for backend operations.

    The user and organization IDs are fixed strings so every run sees the
    same actor.
    """
    user_id = "user-00000000-0000-4000-8000-000000000000"
    org_id = "org-00000000-0000-4000-8000-000000000000"
    return User(id=user_id, organization_id=org_id, name="test_user")
@pytest.fixture
def sample_provider_trace():
    """Provide a ``ProviderTrace`` shaped like a small chat-completion exchange.

    Several of these values (``step_id``, ``run_id``, the request model and the
    usage token counts) are asserted on by the socket backend tests.
    """
    request_payload = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    response_payload = {
        "id": "chatcmpl-xyz",
        "model": "gpt-4o-mini",
        "choices": [{"message": {"content": "Hi!"}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
    }
    return ProviderTrace(
        request_json=request_payload,
        response_json=response_payload,
        step_id="step-test-789",
        run_id="run-test-abc",
    )
class TestProviderTraceBackendEnum:
    """Checks on the ``ProviderTraceBackend`` enum."""

    def test_enum_values(self):
        # Each member's ``.value`` is its lowercase backend name.
        expected_values = (
            (ProviderTraceBackend.POSTGRES, "postgres"),
            (ProviderTraceBackend.CLICKHOUSE, "clickhouse"),
            (ProviderTraceBackend.SOCKET, "socket"),
        )
        for member, raw in expected_values:
            assert member.value == raw

    def test_enum_string_comparison(self):
        # Members must also compare equal to the plain string itself, not just via ``.value``.
        for member, raw in (
            (ProviderTraceBackend.POSTGRES, "postgres"),
            (ProviderTraceBackend.SOCKET, "socket"),
        ):
            assert member == raw
class TestProviderTrace:
    """Schema-level tests for the ``ProviderTrace`` model."""

    def test_id_generation(self):
        """A new trace is assigned an ID carrying the ``provider_trace-`` prefix."""
        created = ProviderTrace(
            request_json={"model": "test"},
            response_json={"id": "test"},
            step_id="step-123",
        )
        assert created.id.startswith("provider_trace-")

    def test_id_uniqueness(self):
        """Two separately constructed traces never share an ID."""
        first, second = (
            ProviderTrace(request_json={}, response_json={}, step_id=step)
            for step in ("step-1", "step-2")
        )
        assert first.id != second.id

    def test_optional_fields(self):
        """Optional telemetry fields are stored exactly as supplied."""
        extras = {
            "agent_id": "agent-456",
            "agent_tags": ["env:dev", "team:ml"],
            "call_type": "summarization",
            "run_id": "run-789",
        }
        trace = ProviderTrace(request_json={}, response_json={}, step_id="step-123", **extras)
        for field_name, expected in extras.items():
            assert getattr(trace, field_name) == expected

    def test_v2_protocol_fields(self):
        """v2 protocol fields (org_id, user_id, compaction_settings, llm_config) are stored as supplied."""
        v2_fields = {
            "org_id": "org-123",
            "user_id": "user-123",
            "compaction_settings": {"mode": "sliding_window", "target_message_count": 50},
            "llm_config": {"model": "gpt-4", "temperature": 0.7},
        }
        trace = ProviderTrace(request_json={}, response_json={}, step_id="step-123", **v2_fields)
        for field_name, expected in v2_fields.items():
            assert getattr(trace, field_name) == expected

    def test_v2_fields_mutually_exclusive_by_convention(self):
        """Summarization traces carry compaction_settings; other call types carry llm_config.

        As the test name says, this is a convention: both traces below are
        built explicitly in the expected shape, and the test checks they keep it.
        """
        summary = ProviderTrace(
            request_json={},
            response_json={},
            step_id="step-123",
            call_type="summarization",
            compaction_settings={"mode": "partial_evict"},
            llm_config=None,
        )
        assert summary.call_type == "summarization"
        assert summary.compaction_settings is not None
        assert summary.llm_config is None

        agent_step = ProviderTrace(
            request_json={},
            response_json={},
            step_id="step-456",
            call_type="agent_step",
            compaction_settings=None,
            llm_config={"model": "claude-3"},
        )
        assert agent_step.call_type == "agent_step"
        assert agent_step.compaction_settings is None
        assert agent_step.llm_config is not None
class TestSocketProviderTraceBackend:
    """Tests for SocketProviderTraceBackend."""

    @staticmethod
    def _capture_records(backend, trace):
        """Run ``backend._send_to_crouton(trace)`` with ``_send_async`` stubbed out.

        Nothing is written to a socket. Returns the list of records that
        ``_send_to_crouton`` handed to ``_send_async``, in call order.
        """
        captured = []

        def _record(record):
            captured.append(record)

        with patch.object(backend, "_send_async", side_effect=_record):
            backend._send_to_crouton(trace)
        return captured

    def test_init_default_path(self):
        """Test default socket path."""
        backend = SocketProviderTraceBackend()
        assert backend.socket_path == "/var/run/telemetry/telemetry.sock"

    def test_init_custom_path(self):
        """Test custom socket path."""
        backend = SocketProviderTraceBackend(socket_path="/tmp/custom.sock")
        assert backend.socket_path == "/tmp/custom.sock"

    @pytest.mark.asyncio
    async def test_create_async_returns_provider_trace(self, mock_actor, sample_provider_trace):
        """create_async returns the trace even when the socket is unreachable."""
        backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
        result = await backend.create_async(
            actor=mock_actor,
            provider_trace=sample_provider_trace,
        )
        assert isinstance(result, ProviderTrace)
        assert result.id == sample_provider_trace.id
        assert result.step_id == sample_provider_trace.step_id

    @pytest.mark.asyncio
    async def test_get_by_step_id_returns_none(self, mock_actor):
        """Test that read operations return None (write-only backend)."""
        backend = SocketProviderTraceBackend()
        result = await backend.get_by_step_id_async(
            step_id="step-123",
            actor=mock_actor,
        )
        assert result is None

    def test_send_to_socket_with_real_socket(self, sample_provider_trace):
        """A record sent to a real Unix socket arrives as JSON with the trace's data."""
        received_data = []

        with tempfile.TemporaryDirectory() as tmpdir:
            socket_path = os.path.join(tmpdir, "test.sock")

            # Minimal single-connection server standing in for crouton.
            server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server_sock.bind(socket_path)
            server_sock.listen(1)
            server_sock.settimeout(5.0)

            def accept_connection():
                try:
                    conn, _ = server_sock.accept()
                except socket.timeout:
                    # No client connected in time; the assertions below report the failure.
                    return
                finally:
                    # Only one connection is expected, so stop listening either way.
                    server_sock.close()

                chunks = []
                with conn:
                    conn.settimeout(5.0)
                    try:
                        # A stream socket may split one send across several recv() calls,
                        # so keep reading until a newline (records appear newline-terminated,
                        # hence the .strip() below — verify) or until the client closes.
                        while True:
                            chunk = conn.recv(65536)
                            if not chunk:
                                break
                            chunks.append(chunk)
                            if b"\n" in chunk:
                                break
                    except socket.timeout:
                        pass  # Keep whatever arrived; the assertions below judge it.
                if chunks:
                    received_data.append(b"".join(chunks).decode())

            server_thread = threading.Thread(target=accept_connection, daemon=True)
            server_thread.start()

            backend = SocketProviderTraceBackend(socket_path=socket_path)
            backend._send_to_crouton(sample_provider_trace)

            # _send_to_crouton hands the record to _send_async, so wait for the
            # server side instead of assuming delivery finished on return.
            server_thread.join(timeout=5.0)

        assert len(received_data) == 1
        record = json.loads(received_data[0].strip())
        assert record["provider_trace_id"] == sample_provider_trace.id
        assert record["step_id"] == "step-test-789"
        assert record["run_id"] == "run-test-abc"
        assert record["request"]["model"] == "gpt-4o-mini"
        assert record["response"]["usage"]["prompt_tokens"] == 10
        assert record["response"]["usage"]["completion_tokens"] == 5

    def test_send_to_nonexistent_socket_does_not_raise(self, sample_provider_trace):
        """Test that sending to nonexistent socket fails silently."""
        backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
        # Should not raise
        backend._send_to_crouton(sample_provider_trace)

    def test_record_extracts_usage_from_openai_response(self):
        """OpenAI-style usage (prompt/completion tokens) reaches the socket record."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={
                "usage": {
                    "prompt_tokens": 100,
                    "completion_tokens": 50,
                }
            },
            step_id="step-123",
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")

        records = self._capture_records(backend, trace)

        assert len(records) == 1
        record = records[0]
        assert record["step_id"] == "step-123"
        assert record["request"]["model"] == "gpt-4"
        # NOTE(review): if the record also carries normalized usage fields, assert on them too.
        assert record["response"]["usage"]["prompt_tokens"] == 100
        assert record["response"]["usage"]["completion_tokens"] == 50

    def test_record_extracts_usage_from_anthropic_response(self):
        """Anthropic-style usage (input/output tokens) reaches the socket record."""
        trace = ProviderTrace(
            request_json={"model": "claude-3"},
            response_json={
                "usage": {
                    "input_tokens": 100,
                    "output_tokens": 50,
                }
            },
            step_id="step-123",
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")

        records = self._capture_records(backend, trace)

        assert len(records) == 1
        record = records[0]
        assert record["step_id"] == "step-123"
        assert record["request"]["model"] == "claude-3"
        # NOTE(review): if the record also carries normalized usage fields, assert on them too.
        assert record["response"]["usage"]["input_tokens"] == 100
        assert record["response"]["usage"]["output_tokens"] == 50

    def test_record_extracts_error_from_response(self):
        """An error response is surfaced as ``error`` with ``response`` set to None."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={
                "error": {"message": "Rate limit exceeded"},
            },
            step_id="step-123",
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")

        records = self._capture_records(backend, trace)

        assert len(records) == 1
        assert records[0]["error"] == "Rate limit exceeded"
        assert records[0]["response"] is None

    def test_record_includes_v3_protocol_fields(self):
        """Test that v3 protocol fields are included in the socket record."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={"id": "test"},
            step_id="step-123",
            org_id="org-456",
            user_id="user-456",
            compaction_settings={"mode": "sliding_window"},
            llm_config={"model": "gpt-4", "temperature": 0.5},
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")

        records = self._capture_records(backend, trace)

        assert len(records) == 1
        record = records[0]
        assert record["protocol_version"] == 3
        assert record["org_id"] == "org-456"
        assert record["user_id"] == "user-456"
        assert record["compaction_settings"] == {"mode": "sliding_window"}
        assert record["llm_config"] == {"model": "gpt-4", "temperature": 0.5}
class TestBackendFactory:
    """Tests for the provider trace backend factory."""

    def test_get_postgres_backend(self):
        """``_create_backend("postgres")`` builds the Postgres backend."""
        from letta.services.provider_trace_backends.factory import _create_backend

        backend = _create_backend("postgres")
        assert backend.__class__.__name__ == "PostgresProviderTraceBackend"

    def test_get_socket_backend(self):
        """``_create_backend("socket")`` builds the socket backend."""
        # NOTE(review): if the factory binds ``telemetry_settings`` at import time
        # (``from letta.settings import telemetry_settings``), this patch does not
        # reach it; the assertion below only checks the backend class either way.
        with patch("letta.settings.telemetry_settings") as mock_settings:
            mock_settings.socket_path = "/tmp/test.sock"
            from letta.services.provider_trace_backends.factory import _create_backend

            backend = _create_backend("socket")
            assert backend.__class__.__name__ == "SocketProviderTraceBackend"

    def test_get_multiple_backends(self):
        """Smoke-test ``get_provider_trace_backends`` under the current environment.

        The returned list depends on the env var LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND,
        so only its shape is checked: at least one backend, each exposing
        ``create_async`` and ``get_by_step_id_async``.
        """
        from letta.services.provider_trace_backends.factory import (
            get_provider_trace_backends,
        )

        # The factory result is cached: clear it before so this test builds a fresh
        # list, and after so later tests don't inherit the list built here.
        get_provider_trace_backends.cache_clear()
        try:
            backends = get_provider_trace_backends()
            assert len(backends) >= 1
            assert all(hasattr(b, "create_async") and hasattr(b, "get_by_step_id_async") for b in backends)
        finally:
            get_provider_trace_backends.cache_clear()

    def test_unknown_backend_defaults_to_postgres(self):
        """Test that unknown backend type defaults to postgres."""
        from letta.services.provider_trace_backends.factory import _create_backend

        backend = _create_backend("unknown_backend")
        assert backend.__class__.__name__ == "PostgresProviderTraceBackend"
class TestTelemetrySettings:
    """Parsing of the ``provider_trace_backend`` option on ``TelemetrySettings``."""

    def test_provider_trace_backends_parsing(self):
        """A comma-separated value is split into an ordered list of backend names."""
        from letta.settings import TelemetrySettings

        parsed = TelemetrySettings(provider_trace_backend="postgres,socket").provider_trace_backends
        assert parsed == ["postgres", "socket"]

    def test_provider_trace_backends_single(self):
        """A single backend name yields a one-element list."""
        from letta.settings import TelemetrySettings

        parsed = TelemetrySettings(provider_trace_backend="postgres").provider_trace_backends
        assert parsed == ["postgres"]

    def test_provider_trace_backends_with_whitespace(self):
        """Whitespace around each comma-separated name is stripped."""
        from letta.settings import TelemetrySettings

        parsed = TelemetrySettings(provider_trace_backend="postgres , socket , clickhouse").provider_trace_backends
        assert parsed == ["postgres", "socket", "clickhouse"]

    def test_socket_backend_enabled(self):
        """``socket_backend_enabled`` is True only when "socket" is among the backends."""
        from letta.settings import TelemetrySettings

        cases = (
            ("postgres", False),
            ("postgres,socket", True),
        )
        for raw_value, expected in cases:
            settings = TelemetrySettings(provider_trace_backend=raw_value)
            assert settings.socket_backend_enabled is expected