Files
letta-server/tests/test_provider_trace_backends.py
Kian Jones 25d54dd896 chore: enable F821, F401, W293 (#9503)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import
2026-02-24 10:55:08 -08:00

402 lines
15 KiB
Python

"""Unit tests for provider trace backends."""
import json
import os
import socket
import tempfile
import threading
from unittest.mock import patch
import pytest
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.user import User
from letta.services.provider_trace_backends.base import ProviderTraceBackend
from letta.services.provider_trace_backends.socket import SocketProviderTraceBackend
@pytest.fixture
def mock_actor():
    """Provide a fixed User instance to pass as the ``actor`` in backend calls."""
    actor_fields = dict(
        id="user-00000000-0000-4000-8000-000000000000",
        organization_id="org-00000000-0000-4000-8000-000000000000",
        name="test_user",
    )
    return User(**actor_fields)
@pytest.fixture
def sample_provider_trace():
    """Provide a ProviderTrace holding a small OpenAI-style request/response pair."""
    request_payload = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    response_payload = {
        "id": "chatcmpl-xyz",
        "model": "gpt-4o-mini",
        "choices": [{"message": {"content": "Hi!"}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
    }
    return ProviderTrace(
        request_json=request_payload,
        response_json=response_payload,
        step_id="step-test-789",
        run_id="run-test-abc",
    )
class TestProviderTraceBackendEnum:
    """Checks the raw string values behind the ProviderTraceBackend enum members."""

    def test_enum_values(self):
        expected_pairs = (
            (ProviderTraceBackend.POSTGRES, "postgres"),
            (ProviderTraceBackend.CLICKHOUSE, "clickhouse"),
            (ProviderTraceBackend.SOCKET, "socket"),
        )
        for member, raw_value in expected_pairs:
            assert member.value == raw_value

    def test_enum_string_comparison(self):
        # Members must compare equal to their plain-string names directly.
        for member, raw_value in (
            (ProviderTraceBackend.POSTGRES, "postgres"),
            (ProviderTraceBackend.SOCKET, "socket"),
        ):
            assert member == raw_value
class TestProviderTrace:
    """Schema-level checks for constructing ProviderTrace and reading back its fields."""

    def test_id_generation(self):
        """A newly built trace gets an auto-generated ID with the provider_trace- prefix."""
        generated = ProviderTrace(
            request_json={"model": "test"},
            response_json={"id": "test"},
            step_id="step-123",
        )
        assert generated.id.startswith("provider_trace-")

    def test_id_uniqueness(self):
        """Two separately constructed traces receive different IDs."""
        first, second = (
            ProviderTrace(request_json={}, response_json={}, step_id=step)
            for step in ("step-1", "step-2")
        )
        assert first.id != second.id

    def test_optional_fields(self):
        """Optional telemetry fields are stored exactly as supplied."""
        telemetry = {
            "agent_id": "agent-456",
            "agent_tags": ["env:dev", "team:ml"],
            "call_type": "summarization",
            "run_id": "run-789",
        }
        trace = ProviderTrace(request_json={}, response_json={}, step_id="step-123", **telemetry)
        for field_name, expected in telemetry.items():
            assert getattr(trace, field_name) == expected

    def test_v2_protocol_fields(self):
        """Test v2 protocol fields (org_id, user_id, compaction_settings, llm_config)."""
        v2_fields = {
            "org_id": "org-123",
            "user_id": "user-123",
            "compaction_settings": {"mode": "sliding_window", "target_message_count": 50},
            "llm_config": {"model": "gpt-4", "temperature": 0.7},
        }
        trace = ProviderTrace(request_json={}, response_json={}, step_id="step-123", **v2_fields)
        for field_name, expected in v2_fields.items():
            assert getattr(trace, field_name) == expected

    def test_v2_fields_mutually_exclusive_by_convention(self):
        """Summarization traces carry compaction_settings; other call types carry llm_config."""
        summary = ProviderTrace(
            request_json={},
            response_json={},
            step_id="step-123",
            call_type="summarization",
            compaction_settings={"mode": "partial_evict"},
            llm_config=None,
        )
        assert summary.call_type == "summarization"
        assert summary.compaction_settings is not None
        assert summary.llm_config is None

        step = ProviderTrace(
            request_json={},
            response_json={},
            step_id="step-456",
            call_type="agent_step",
            compaction_settings=None,
            llm_config={"model": "claude-3"},
        )
        assert step.call_type == "agent_step"
        assert step.compaction_settings is None
        assert step.llm_config is not None
class TestSocketProviderTraceBackend:
    """Tests for SocketProviderTraceBackend."""

    @staticmethod
    def _capture_sent_records(backend, trace):
        """Run ``backend._send_to_crouton(trace)`` with ``_send_async`` stubbed out.

        Returns the list of records the backend passed to ``_send_async``, so a test
        can inspect the built record without touching a real socket.
        """
        captured_records = []

        def capture_record(record):
            captured_records.append(record)

        with patch.object(backend, "_send_async", side_effect=capture_record):
            backend._send_to_crouton(trace)
        return captured_records

    def test_init_default_path(self):
        """Test default socket path."""
        backend = SocketProviderTraceBackend()
        assert backend.socket_path == "/var/run/telemetry/telemetry.sock"

    def test_init_custom_path(self):
        """Test custom socket path."""
        backend = SocketProviderTraceBackend(socket_path="/tmp/custom.sock")
        assert backend.socket_path == "/tmp/custom.sock"

    @pytest.mark.asyncio
    async def test_create_async_returns_provider_trace(self, mock_actor, sample_provider_trace):
        """Test that create_async returns a ProviderTrace."""
        backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
        result = await backend.create_async(
            actor=mock_actor,
            provider_trace=sample_provider_trace,
        )
        assert isinstance(result, ProviderTrace)
        assert result.id == sample_provider_trace.id
        assert result.step_id == sample_provider_trace.step_id

    @pytest.mark.asyncio
    async def test_get_by_step_id_returns_none(self, mock_actor):
        """Test that read operations return None (write-only backend)."""
        backend = SocketProviderTraceBackend()
        result = await backend.get_by_step_id_async(
            step_id="step-123",
            actor=mock_actor,
        )
        assert result is None

    def test_send_to_socket_with_real_socket(self, sample_provider_trace):
        """Test sending data to a real Unix socket."""
        received_data = []
        with tempfile.TemporaryDirectory() as tmpdir:
            socket_path = os.path.join(tmpdir, "test.sock")
            # Minimal one-shot Unix socket server that stands in for the receiver.
            server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server_sock.bind(socket_path)
            server_sock.listen(1)
            server_sock.settimeout(5.0)

            def accept_connection():
                chunks = []
                try:
                    conn, _ = server_sock.accept()
                    with conn:
                        # Short per-read timeout so a sender that keeps the connection
                        # open cannot hold this thread past the join() below.
                        conn.settimeout(1.0)
                        # One recv() on a stream socket may return only part of the
                        # payload, so keep reading until a newline, EOF, or timeout.
                        while True:
                            chunk = conn.recv(65536)
                            if not chunk:
                                break
                            chunks.append(chunk)
                            if b"\n" in chunk:
                                break
                except socket.timeout:
                    pass  # Nothing (more) arrived; the assertions below report a missing record.
                finally:
                    server_sock.close()
                    if chunks:
                        # Decode once after joining so multi-byte characters split
                        # across reads are not corrupted.
                        received_data.append(b"".join(chunks).decode())

            # Start the server in the background; daemon so a stuck accept never blocks exit.
            server_thread = threading.Thread(target=accept_connection, daemon=True)
            server_thread.start()

            # Send data via the backend.
            backend = SocketProviderTraceBackend(socket_path=socket_path)
            backend._send_to_crouton(sample_provider_trace)

            # The send may be asynchronous, so wait for the server to finish reading.
            server_thread.join(timeout=5.0)

        # Verify exactly one record was received and that it round-trips the trace.
        assert len(received_data) == 1
        record = json.loads(received_data[0].strip())
        assert record["provider_trace_id"] == sample_provider_trace.id
        assert record["step_id"] == "step-test-789"
        assert record["run_id"] == "run-test-abc"
        assert record["request"]["model"] == "gpt-4o-mini"
        assert record["response"]["usage"]["prompt_tokens"] == 10
        assert record["response"]["usage"]["completion_tokens"] == 5

    def test_send_to_nonexistent_socket_does_not_raise(self, sample_provider_trace):
        """Test that sending to nonexistent socket fails silently."""
        backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
        # Should not raise
        backend._send_to_crouton(sample_provider_trace)

    def test_record_extracts_usage_from_openai_response(self):
        """OpenAI-style usage (prompt/completion tokens) reaches the socket record."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={
                "usage": {
                    "prompt_tokens": 100,
                    "completion_tokens": 50,
                }
            },
            step_id="step-123",
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")
        captured_records = self._capture_sent_records(backend, trace)
        assert len(captured_records) == 1
        record = captured_records[0]
        assert record["request"]["model"] == "gpt-4"
        assert record["response"]["usage"]["prompt_tokens"] == 100
        assert record["response"]["usage"]["completion_tokens"] == 50

    def test_record_extracts_usage_from_anthropic_response(self):
        """Anthropic-style usage (input/output tokens) reaches the socket record."""
        trace = ProviderTrace(
            request_json={"model": "claude-3"},
            response_json={
                "usage": {
                    "input_tokens": 100,
                    "output_tokens": 50,
                }
            },
            step_id="step-123",
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")
        captured_records = self._capture_sent_records(backend, trace)
        assert len(captured_records) == 1
        record = captured_records[0]
        assert record["request"]["model"] == "claude-3"
        assert record["response"]["usage"]["input_tokens"] == 100
        assert record["response"]["usage"]["output_tokens"] == 50

    def test_record_extracts_error_from_response(self):
        """Test error extraction from response."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={
                "error": {"message": "Rate limit exceeded"},
            },
            step_id="step-123",
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")
        captured_records = self._capture_sent_records(backend, trace)
        assert len(captured_records) == 1
        assert captured_records[0]["error"] == "Rate limit exceeded"
        assert captured_records[0]["response"] is None

    def test_record_includes_v3_protocol_fields(self):
        """Test that v3 protocol fields are included in the socket record."""
        trace = ProviderTrace(
            request_json={"model": "gpt-4"},
            response_json={"id": "test"},
            step_id="step-123",
            org_id="org-456",
            user_id="user-456",
            compaction_settings={"mode": "sliding_window"},
            llm_config={"model": "gpt-4", "temperature": 0.5},
        )
        backend = SocketProviderTraceBackend(socket_path="/fake/path")
        captured_records = self._capture_sent_records(backend, trace)
        assert len(captured_records) == 1
        record = captured_records[0]
        assert record["protocol_version"] == 3
        assert record["org_id"] == "org-456"
        assert record["user_id"] == "user-456"
        assert record["compaction_settings"] == {"mode": "sliding_window"}
        assert record["llm_config"] == {"model": "gpt-4", "temperature": 0.5}
class TestBackendFactory:
    """Tests for the provider trace backend factory."""

    def test_get_postgres_backend(self):
        """Test getting postgres backend."""
        from letta.services.provider_trace_backends.factory import _create_backend

        backend = _create_backend("postgres")
        assert backend.__class__.__name__ == "PostgresProviderTraceBackend"

    def test_get_socket_backend(self):
        """Test getting socket backend."""
        # NOTE(review): this patches the attribute on ``letta.settings``. If the factory
        # binds ``telemetry_settings`` at import time (``from letta.settings import ...``),
        # the patch does not reach it, which is why socket_path is not asserted here.
        # Confirm against the factory before relying on the patched value.
        with patch("letta.settings.telemetry_settings") as mock_settings:
            mock_settings.socket_path = "/tmp/test.sock"
            from letta.services.provider_trace_backends.factory import _create_backend

            backend = _create_backend("socket")
            assert backend.__class__.__name__ == "SocketProviderTraceBackend"

    def test_get_multiple_backends(self):
        """Test getting multiple backends via environment."""
        from letta.services.provider_trace_backends.factory import (
            get_provider_trace_backends,
        )

        # Clear before the call so the result reflects the current environment, and
        # again afterwards so backends built here do not leak to other tests via the cache.
        get_provider_trace_backends.cache_clear()
        try:
            # This test only verifies that the factory works; the actual backend list
            # depends on the env var LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND.
            backends = get_provider_trace_backends()
        finally:
            get_provider_trace_backends.cache_clear()
        assert len(backends) >= 1
        assert all(hasattr(b, "create_async") and hasattr(b, "get_by_step_id_async") for b in backends)

    def test_unknown_backend_defaults_to_postgres(self):
        """Test that unknown backend type defaults to postgres."""
        from letta.services.provider_trace_backends.factory import _create_backend

        backend = _create_backend("unknown_backend")
        assert backend.__class__.__name__ == "PostgresProviderTraceBackend"
class TestTelemetrySettings:
    """Checks how TelemetrySettings interprets its provider_trace_backend string."""

    def test_provider_trace_backends_parsing(self):
        """A comma-separated value becomes an ordered list of backend names."""
        from letta.settings import TelemetrySettings

        parsed = TelemetrySettings(provider_trace_backend="postgres,socket").provider_trace_backends
        assert parsed == ["postgres", "socket"]

    def test_provider_trace_backends_single(self):
        """A value naming one backend yields a one-element list."""
        from letta.settings import TelemetrySettings

        parsed = TelemetrySettings(provider_trace_backend="postgres").provider_trace_backends
        assert parsed == ["postgres"]

    def test_provider_trace_backends_with_whitespace(self):
        """Whitespace around the comma-separated entries is stripped."""
        from letta.settings import TelemetrySettings

        raw_value = "postgres , socket , clickhouse"
        parsed = TelemetrySettings(provider_trace_backend=raw_value).provider_trace_backends
        assert parsed == ["postgres", "socket", "clickhouse"]

    def test_socket_backend_enabled(self):
        """socket_backend_enabled is True only when "socket" is among the configured backends."""
        from letta.settings import TelemetrySettings

        for raw_value, expected in (("postgres", False), ("postgres,socket", True)):
            configured = TelemetrySettings(provider_trace_backend=raw_value)
            assert configured.socket_backend_enabled is expected