* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import
402 lines
15 KiB
Python
402 lines
15 KiB
Python
"""Unit tests for provider trace backends."""
|
|
|
|
import json
|
|
import os
|
|
import socket
|
|
import tempfile
|
|
import threading
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from letta.schemas.provider_trace import ProviderTrace
|
|
from letta.schemas.user import User
|
|
from letta.services.provider_trace_backends.base import ProviderTraceBackend
|
|
from letta.services.provider_trace_backends.socket import SocketProviderTraceBackend
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_actor():
|
|
"""Create a mock user/actor."""
|
|
return User(
|
|
id="user-00000000-0000-4000-8000-000000000000",
|
|
organization_id="org-00000000-0000-4000-8000-000000000000",
|
|
name="test_user",
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_provider_trace():
|
|
"""Create a sample ProviderTrace."""
|
|
return ProviderTrace(
|
|
request_json={
|
|
"model": "gpt-4o-mini",
|
|
"messages": [{"role": "user", "content": "Hello"}],
|
|
},
|
|
response_json={
|
|
"id": "chatcmpl-xyz",
|
|
"model": "gpt-4o-mini",
|
|
"choices": [{"message": {"content": "Hi!"}}],
|
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5},
|
|
},
|
|
step_id="step-test-789",
|
|
run_id="run-test-abc",
|
|
)
|
|
|
|
|
|
class TestProviderTraceBackendEnum:
|
|
"""Tests for ProviderTraceBackend enum."""
|
|
|
|
def test_enum_values(self):
|
|
assert ProviderTraceBackend.POSTGRES.value == "postgres"
|
|
assert ProviderTraceBackend.CLICKHOUSE.value == "clickhouse"
|
|
assert ProviderTraceBackend.SOCKET.value == "socket"
|
|
|
|
def test_enum_string_comparison(self):
|
|
assert ProviderTraceBackend.POSTGRES == "postgres"
|
|
assert ProviderTraceBackend.SOCKET == "socket"
|
|
|
|
|
|
class TestProviderTrace:
|
|
"""Tests for ProviderTrace schema."""
|
|
|
|
def test_id_generation(self):
|
|
"""Test that ID is auto-generated with correct prefix."""
|
|
trace = ProviderTrace(
|
|
request_json={"model": "test"},
|
|
response_json={"id": "test"},
|
|
step_id="step-123",
|
|
)
|
|
assert trace.id.startswith("provider_trace-")
|
|
|
|
def test_id_uniqueness(self):
|
|
"""Test that each instance gets a unique ID."""
|
|
trace1 = ProviderTrace(request_json={}, response_json={}, step_id="step-1")
|
|
trace2 = ProviderTrace(request_json={}, response_json={}, step_id="step-2")
|
|
assert trace1.id != trace2.id
|
|
|
|
def test_optional_fields(self):
|
|
"""Test optional telemetry fields."""
|
|
trace = ProviderTrace(
|
|
request_json={},
|
|
response_json={},
|
|
step_id="step-123",
|
|
agent_id="agent-456",
|
|
agent_tags=["env:dev", "team:ml"],
|
|
call_type="summarization",
|
|
run_id="run-789",
|
|
)
|
|
assert trace.agent_id == "agent-456"
|
|
assert trace.agent_tags == ["env:dev", "team:ml"]
|
|
assert trace.call_type == "summarization"
|
|
assert trace.run_id == "run-789"
|
|
|
|
def test_v2_protocol_fields(self):
|
|
"""Test v2 protocol fields (org_id, user_id, compaction_settings, llm_config)."""
|
|
trace = ProviderTrace(
|
|
request_json={},
|
|
response_json={},
|
|
step_id="step-123",
|
|
org_id="org-123",
|
|
user_id="user-123",
|
|
compaction_settings={"mode": "sliding_window", "target_message_count": 50},
|
|
llm_config={"model": "gpt-4", "temperature": 0.7},
|
|
)
|
|
assert trace.org_id == "org-123"
|
|
assert trace.user_id == "user-123"
|
|
assert trace.compaction_settings == {"mode": "sliding_window", "target_message_count": 50}
|
|
assert trace.llm_config == {"model": "gpt-4", "temperature": 0.7}
|
|
|
|
def test_v2_fields_mutually_exclusive_by_convention(self):
|
|
"""Test that compaction_settings is set for summarization, llm_config for non-summarization."""
|
|
summarization_trace = ProviderTrace(
|
|
request_json={},
|
|
response_json={},
|
|
step_id="step-123",
|
|
call_type="summarization",
|
|
compaction_settings={"mode": "partial_evict"},
|
|
llm_config=None,
|
|
)
|
|
assert summarization_trace.call_type == "summarization"
|
|
assert summarization_trace.compaction_settings is not None
|
|
assert summarization_trace.llm_config is None
|
|
|
|
agent_step_trace = ProviderTrace(
|
|
request_json={},
|
|
response_json={},
|
|
step_id="step-456",
|
|
call_type="agent_step",
|
|
compaction_settings=None,
|
|
llm_config={"model": "claude-3"},
|
|
)
|
|
assert agent_step_trace.call_type == "agent_step"
|
|
assert agent_step_trace.compaction_settings is None
|
|
assert agent_step_trace.llm_config is not None
|
|
|
|
|
|
class TestSocketProviderTraceBackend:
|
|
"""Tests for SocketProviderTraceBackend."""
|
|
|
|
def test_init_default_path(self):
|
|
"""Test default socket path."""
|
|
backend = SocketProviderTraceBackend()
|
|
assert backend.socket_path == "/var/run/telemetry/telemetry.sock"
|
|
|
|
def test_init_custom_path(self):
|
|
"""Test custom socket path."""
|
|
backend = SocketProviderTraceBackend(socket_path="/tmp/custom.sock")
|
|
assert backend.socket_path == "/tmp/custom.sock"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_async_returns_provider_trace(self, mock_actor, sample_provider_trace):
|
|
"""Test that create_async returns a ProviderTrace."""
|
|
backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
|
|
|
|
result = await backend.create_async(
|
|
actor=mock_actor,
|
|
provider_trace=sample_provider_trace,
|
|
)
|
|
|
|
assert isinstance(result, ProviderTrace)
|
|
assert result.id == sample_provider_trace.id
|
|
assert result.step_id == sample_provider_trace.step_id
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_by_step_id_returns_none(self, mock_actor):
|
|
"""Test that read operations return None (write-only backend)."""
|
|
backend = SocketProviderTraceBackend()
|
|
|
|
result = await backend.get_by_step_id_async(
|
|
step_id="step-123",
|
|
actor=mock_actor,
|
|
)
|
|
|
|
assert result is None
|
|
|
|
def test_send_to_socket_with_real_socket(self, sample_provider_trace):
|
|
"""Test sending data to a real Unix socket."""
|
|
received_data = []
|
|
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
socket_path = os.path.join(tmpdir, "test.sock")
|
|
|
|
# Create a simple socket server
|
|
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
server_sock.bind(socket_path)
|
|
server_sock.listen(1)
|
|
server_sock.settimeout(5.0)
|
|
|
|
def accept_connection():
|
|
try:
|
|
conn, _ = server_sock.accept()
|
|
data = conn.recv(65536)
|
|
received_data.append(data.decode())
|
|
conn.close()
|
|
except socket.timeout:
|
|
pass # Expected - test socket has short timeout, data may not arrive
|
|
finally:
|
|
server_sock.close()
|
|
|
|
# Start server in background
|
|
server_thread = threading.Thread(target=accept_connection)
|
|
server_thread.start()
|
|
|
|
# Send data via backend
|
|
backend = SocketProviderTraceBackend(socket_path=socket_path)
|
|
backend._send_to_crouton(sample_provider_trace)
|
|
|
|
# Wait for send to complete
|
|
server_thread.join(timeout=5.0)
|
|
|
|
# Verify data was received
|
|
assert len(received_data) == 1
|
|
record = json.loads(received_data[0].strip())
|
|
assert record["provider_trace_id"] == sample_provider_trace.id
|
|
assert record["step_id"] == "step-test-789"
|
|
assert record["run_id"] == "run-test-abc"
|
|
assert record["request"]["model"] == "gpt-4o-mini"
|
|
assert record["response"]["usage"]["prompt_tokens"] == 10
|
|
assert record["response"]["usage"]["completion_tokens"] == 5
|
|
|
|
def test_send_to_nonexistent_socket_does_not_raise(self, sample_provider_trace):
|
|
"""Test that sending to nonexistent socket fails silently."""
|
|
backend = SocketProviderTraceBackend(socket_path="/nonexistent/path.sock")
|
|
|
|
# Should not raise
|
|
backend._send_to_crouton(sample_provider_trace)
|
|
|
|
def test_record_extracts_usage_from_openai_response(self):
|
|
"""Test usage extraction from OpenAI-style response."""
|
|
trace = ProviderTrace(
|
|
request_json={"model": "gpt-4"},
|
|
response_json={
|
|
"usage": {
|
|
"prompt_tokens": 100,
|
|
"completion_tokens": 50,
|
|
}
|
|
},
|
|
step_id="step-123",
|
|
)
|
|
|
|
backend = SocketProviderTraceBackend(socket_path="/fake/path")
|
|
|
|
# Access internal method to build record
|
|
with patch.object(backend, "_send_async"):
|
|
backend._send_to_crouton(trace)
|
|
|
|
def test_record_extracts_usage_from_anthropic_response(self):
|
|
"""Test usage extraction from Anthropic-style response."""
|
|
trace = ProviderTrace(
|
|
request_json={"model": "claude-3"},
|
|
response_json={
|
|
"usage": {
|
|
"input_tokens": 100,
|
|
"output_tokens": 50,
|
|
}
|
|
},
|
|
step_id="step-123",
|
|
)
|
|
|
|
backend = SocketProviderTraceBackend(socket_path="/fake/path")
|
|
|
|
with patch.object(backend, "_send_async"):
|
|
backend._send_to_crouton(trace)
|
|
|
|
def test_record_extracts_error_from_response(self):
|
|
"""Test error extraction from response."""
|
|
trace = ProviderTrace(
|
|
request_json={"model": "gpt-4"},
|
|
response_json={
|
|
"error": {"message": "Rate limit exceeded"},
|
|
},
|
|
step_id="step-123",
|
|
)
|
|
|
|
backend = SocketProviderTraceBackend(socket_path="/fake/path")
|
|
|
|
# Capture the record sent to _send_async
|
|
captured_records = []
|
|
|
|
def capture_record(record):
|
|
captured_records.append(record)
|
|
|
|
with patch.object(backend, "_send_async", side_effect=capture_record):
|
|
backend._send_to_crouton(trace)
|
|
|
|
assert len(captured_records) == 1
|
|
assert captured_records[0]["error"] == "Rate limit exceeded"
|
|
assert captured_records[0]["response"] is None
|
|
|
|
def test_record_includes_v3_protocol_fields(self):
|
|
"""Test that v3 protocol fields are included in the socket record."""
|
|
trace = ProviderTrace(
|
|
request_json={"model": "gpt-4"},
|
|
response_json={"id": "test"},
|
|
step_id="step-123",
|
|
org_id="org-456",
|
|
user_id="user-456",
|
|
compaction_settings={"mode": "sliding_window"},
|
|
llm_config={"model": "gpt-4", "temperature": 0.5},
|
|
)
|
|
|
|
backend = SocketProviderTraceBackend(socket_path="/fake/path")
|
|
|
|
captured_records = []
|
|
|
|
def capture_record(record):
|
|
captured_records.append(record)
|
|
|
|
with patch.object(backend, "_send_async", side_effect=capture_record):
|
|
backend._send_to_crouton(trace)
|
|
|
|
assert len(captured_records) == 1
|
|
record = captured_records[0]
|
|
assert record["protocol_version"] == 3
|
|
assert record["org_id"] == "org-456"
|
|
assert record["user_id"] == "user-456"
|
|
assert record["compaction_settings"] == {"mode": "sliding_window"}
|
|
assert record["llm_config"] == {"model": "gpt-4", "temperature": 0.5}
|
|
|
|
|
|
class TestBackendFactory:
|
|
"""Tests for backend factory."""
|
|
|
|
def test_get_postgres_backend(self):
|
|
"""Test getting postgres backend."""
|
|
from letta.services.provider_trace_backends.factory import _create_backend
|
|
|
|
backend = _create_backend("postgres")
|
|
assert backend.__class__.__name__ == "PostgresProviderTraceBackend"
|
|
|
|
def test_get_socket_backend(self):
|
|
"""Test getting socket backend."""
|
|
with patch("letta.settings.telemetry_settings") as mock_settings:
|
|
mock_settings.socket_path = "/tmp/test.sock"
|
|
|
|
from letta.services.provider_trace_backends.factory import _create_backend
|
|
|
|
backend = _create_backend("socket")
|
|
assert backend.__class__.__name__ == "SocketProviderTraceBackend"
|
|
|
|
def test_get_multiple_backends(self):
|
|
"""Test getting multiple backends via environment."""
|
|
|
|
from letta.services.provider_trace_backends.factory import (
|
|
get_provider_trace_backends,
|
|
)
|
|
|
|
# Clear cache first
|
|
get_provider_trace_backends.cache_clear()
|
|
|
|
# This test just verifies the factory works - actual backend list
|
|
# depends on env var LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND
|
|
backends = get_provider_trace_backends()
|
|
assert len(backends) >= 1
|
|
assert all(hasattr(b, "create_async") and hasattr(b, "get_by_step_id_async") for b in backends)
|
|
|
|
def test_unknown_backend_defaults_to_postgres(self):
|
|
"""Test that unknown backend type defaults to postgres."""
|
|
from letta.services.provider_trace_backends.factory import _create_backend
|
|
|
|
backend = _create_backend("unknown_backend")
|
|
assert backend.__class__.__name__ == "PostgresProviderTraceBackend"
|
|
|
|
|
|
class TestTelemetrySettings:
|
|
"""Tests for telemetry settings."""
|
|
|
|
def test_provider_trace_backends_parsing(self):
|
|
"""Test parsing comma-separated backend list."""
|
|
from letta.settings import TelemetrySettings
|
|
|
|
# Create a fresh settings object and set the value directly
|
|
settings = TelemetrySettings(provider_trace_backend="postgres,socket")
|
|
backends = settings.provider_trace_backends
|
|
assert backends == ["postgres", "socket"]
|
|
|
|
def test_provider_trace_backends_single(self):
|
|
"""Test single backend."""
|
|
from letta.settings import TelemetrySettings
|
|
|
|
settings = TelemetrySettings(provider_trace_backend="postgres")
|
|
backends = settings.provider_trace_backends
|
|
assert backends == ["postgres"]
|
|
|
|
def test_provider_trace_backends_with_whitespace(self):
|
|
"""Test backend list with whitespace."""
|
|
from letta.settings import TelemetrySettings
|
|
|
|
settings = TelemetrySettings(provider_trace_backend="postgres , socket , clickhouse")
|
|
backends = settings.provider_trace_backends
|
|
assert backends == ["postgres", "socket", "clickhouse"]
|
|
|
|
def test_socket_backend_enabled(self):
|
|
"""Test socket_backend_enabled property."""
|
|
from letta.settings import TelemetrySettings
|
|
|
|
settings1 = TelemetrySettings(provider_trace_backend="postgres")
|
|
assert settings1.socket_backend_enabled is False
|
|
|
|
settings2 = TelemetrySettings(provider_trace_backend="postgres,socket")
|
|
assert settings2.socket_backend_enabled is True
|