* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import
403 lines
15 KiB
Python
403 lines
15 KiB
Python
"""
|
|
Comprehensive tests for provider trace telemetry.
|
|
|
|
Tests verify that provider traces are correctly created with all telemetry context
|
|
(agent_id, agent_tags, run_id, step_id, call_type) across:
|
|
- Agent steps (non-streaming and streaming)
|
|
- Tool calls
|
|
- Summarization calls
|
|
- Different agent architectures (V2, V3)
|
|
"""
|
|
|
|
import asyncio
import os
import threading
import time
import urllib.error
import urllib.request
import uuid

import pytest
from dotenv import load_dotenv
from letta_client import Letta

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
|
|
|
|
|
|
def _run_server():
    """Run the Letta REST server; used as the target of a background thread.

    Variables from a ``.env`` file are loaded into the environment first. The
    server entry point is imported inside the function rather than at module
    level, so it is only imported when a server is actually started.
    """
    load_dotenv()

    from letta.server.rest_api.app import start_server as launch

    launch(debug=True)
|
|
|
|
|
|
def _wait_for_server(url, timeout=30.0, interval=0.5):
    """Poll ``url`` until an HTTP server answers, or raise after ``timeout`` seconds.

    Any HTTP response, even an error status such as 404, shows the server is
    accepting requests. Only connection-level failures (refused, reset, timed
    out) are retried.

    Args:
        url: Base URL to probe.
        timeout: Total seconds to keep retrying before giving up.
        interval: Seconds to sleep between attempts.

    Raises:
        RuntimeError: if no HTTP response arrives before the deadline.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval * 4):
                return
        except urllib.error.HTTPError as err:
            # HTTPError must be caught before OSError (it is a subclass). An error
            # status still proves the server is up; release the response and stop.
            err.close()
            return
        except OSError as err:
            # URLError (e.g. connection refused) and socket timeouts derive from
            # OSError: the server is not ready yet.
            if time.monotonic() >= deadline:
                raise RuntimeError(f"Letta server at {url} was not reachable within {timeout} seconds") from err
            time.sleep(interval)


@pytest.fixture(scope="session")
def server_url():
    """Return the base URL of a Letta server, starting a local one if needed.

    If ``LETTA_SERVER_URL`` is set (non-empty), that server is assumed to be
    running already and its URL is returned as-is. Otherwise a server is started
    in a daemon thread, and the fixture blocks until ``http://localhost:8283``
    answers. This replaces a fixed sleep that could race the server's startup.

    Raises:
        RuntimeError: if the locally started server does not respond in time.
    """
    env_url = os.getenv("LETTA_SERVER_URL")
    if env_url:
        return env_url

    url = "http://localhost:8283"
    thread = threading.Thread(target=_run_server, daemon=True)
    thread.start()
    _wait_for_server(url)
    return url
|
|
|
|
|
|
@pytest.fixture(scope="session")
def client(server_url):
    """Provide one Letta REST client, shared for the session, pointed at ``server_url``."""
    rest_client = Letta(base_url=server_url)
    yield rest_client
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop(request):
    """Provide a single event loop shared by every async test in the session.

    The fixture is session-scoped, so the same loop is reused across all test
    cases and closed only once, at session teardown. ``request`` is accepted
    but not used.

    NOTE(review): redefining ``event_loop`` is how older pytest-asyncio releases
    customise the loop scope. Newer releases deprecate or drop this override;
    confirm against the pinned pytest-asyncio version.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def roll_dice_tool(client, roll_dice_tool_func):
    """Upsert ``roll_dice_tool_func`` as a server tool and hand the tool object to the test.

    ``roll_dice_tool_func`` is a fixture defined outside this file (presumably in conftest).
    """
    yield client.tools.upsert_from_function(func=roll_dice_tool_func)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def weather_tool(client, weather_tool_func):
    """Upsert ``weather_tool_func`` as a server tool and hand the tool object to the test.

    ``weather_tool_func`` is a fixture defined outside this file (presumably in conftest).
    """
    yield client.tools.upsert_from_function(func=weather_tool_func)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def print_tool(client, print_tool_func):
    """Upsert ``print_tool_func`` as a server tool and hand the tool object to the test.

    ``print_tool_func`` is a fixture defined outside this file (presumably in conftest).
    """
    yield client.tools.upsert_from_function(func=print_tool_func)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool):
    """Create a uniquely named agent with the dice and weather tools; delete it after the test."""
    # uuid4().hex[:8] yields the same 8 hex characters as str(uuid4())[:8].
    suffix = uuid.uuid4().hex[:8]
    memory = [
        {"label": "human", "value": "Name: TestUser"},
        {"label": "persona", "value": "Helpful test agent"},
    ]
    created = client.agents.create(
        name=f"test_provider_trace_{suffix}",
        tool_ids=[roll_dice_tool.id, weather_tool.id],
        include_base_tools=True,
        tags=["test", "provider-trace"],
        memory_blocks=memory,
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )
    yield created
    client.agents.delete(created.id)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def agent_state_with_tags(client, weather_tool):
    """Create a uniquely named agent carrying key:value style tags; delete it after the test."""
    suffix = uuid.uuid4().hex[:8]
    memory = [
        {"label": "human", "value": "Name: TagTestUser"},
        {"label": "persona", "value": "Agent with tags"},
    ]
    created = client.agents.create(
        name=f"test_tagged_agent_{suffix}",
        tool_ids=[weather_tool.id],
        include_base_tools=True,
        tags=["env:test", "team:telemetry", "version:v1"],
        memory_blocks=memory,
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )
    yield created
    client.agents.delete(created.id)
|
|
|
|
|
|
class TestProviderTraceBasicStep:
    """Provider-trace checks for ordinary agent steps, with and without streaming."""

    @staticmethod
    def _assert_has_request_and_response(client, step_id):
        """Fetch the trace for ``step_id`` and require both request and response payloads."""
        trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
        assert trace is not None
        assert trace.request_json is not None
        assert trace.response_json is not None

    @pytest.mark.asyncio
    async def test_non_streaming_step_creates_provider_trace(self, client, agent_state):
        """A blocking ``messages.create`` call must leave a full provider trace for its last step."""
        user_msg = MessageCreate(role="user", content=[TextContent(text="Hello, how are you?")])
        response = client.agents.messages.create(agent_id=agent_state.id, messages=[user_msg])

        assert len(response.messages) > 0
        final_step_id = response.messages[-1].step_id
        assert final_step_id is not None

        self._assert_has_request_and_response(client, final_step_id)

    @pytest.mark.asyncio
    async def test_streaming_step_creates_provider_trace(self, client, agent_state):
        """Every step produced by a streamed request must leave a full provider trace."""
        # Cursor message: only messages created after it are inspected below.
        cursor = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]

        stream = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=[MessageCreate(role="user", content=[TextContent(text="Tell me a joke.")])],
        )
        for _chunk in stream:
            pass  # consume the whole stream before looking at stored messages

        new_messages = client.agents.messages.list(agent_id=agent_state.id, after=cursor.id)
        unique_step_ids = {m.step_id for m in new_messages if m.step_id is not None}

        assert unique_step_ids
        for sid in unique_step_ids:
            self._assert_has_request_and_response(client, sid)
|
|
|
|
|
|
class TestProviderTraceWithToolCalls:
    """Tests for provider traces when tools are called."""

    @pytest.mark.asyncio
    async def test_tool_call_step_has_provider_trace(self, client, agent_state):
        """Verify provider trace exists for steps that invoke tools.

        The prompt is meant to make the agent call the weather tool. The first
        returned message is assumed to come from the step that issued the call,
        and the last from the request's final step. Both steps must have a
        trace with a recorded request.
        """
        response = client.agents.messages.create(
            agent_id=agent_state.id,
            messages=[MessageCreate(role="user", content=[TextContent(text="Get the weather in San Francisco.")])],
        )

        assert len(response.messages) > 0
        tool_call_step_id = response.messages[0].step_id
        final_step_id = response.messages[-1].step_id
        # Fail clearly here instead of querying telemetry with step_id=None.
        assert tool_call_step_id is not None, "first response message has no step_id"
        assert final_step_id is not None, "last response message has no step_id"

        tool_trace = client.telemetry.retrieve_provider_trace(step_id=tool_call_step_id)
        assert tool_trace is not None
        assert tool_trace.request_json is not None

        # When the whole request ran in one step, the final step was already checked above.
        if tool_call_step_id != final_step_id:
            final_trace = client.telemetry.retrieve_provider_trace(step_id=final_step_id)
            assert final_trace is not None
            assert final_trace.request_json is not None

    @pytest.mark.asyncio
    async def test_streaming_tool_call_has_provider_trace(self, client, agent_state):
        """Verify provider trace exists for streaming steps with tool calls."""
        # Cursor message (assumed newest): only messages after it are inspected below.
        last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]

        stream = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=[MessageCreate(role="user", content=[TextContent(text="Roll the dice for me.")])],
        )
        # Consume the whole stream before reading the stored messages.
        list(stream)

        messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id)
        step_ids = {msg.step_id for msg in messages if msg.step_id is not None}

        assert len(step_ids) > 0
        for step_id in step_ids:
            trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
            assert trace is not None
            assert trace.request_json is not None
|
|
|
|
|
|
class TestProviderTraceTelemetryContext:
    """Tests verifying telemetry context fields are correctly populated."""

    # Maximum seconds to wait for an async run to reach a terminal status.
    ASYNC_JOB_TIMEOUT_SECONDS = 120.0
    # Seconds between status polls of an async run.
    ASYNC_JOB_POLL_INTERVAL_SECONDS = 0.5

    @staticmethod
    def _send_and_trace(client, agent_id, text="Hello"):
        """Send one user message and return ``(step_id, trace)`` for the response's last step."""
        response = client.agents.messages.create(
            agent_id=agent_id,
            messages=[MessageCreate(role="user", content=[TextContent(text=text)])],
        )
        step_id = response.messages[-1].step_id
        return step_id, client.telemetry.retrieve_provider_trace(step_id=step_id)

    @pytest.mark.asyncio
    async def test_provider_trace_contains_agent_id(self, client, agent_state):
        """Verify provider trace contains the correct agent_id."""
        _, trace = self._send_and_trace(client, agent_state.id)

        assert trace is not None
        assert trace.agent_id == agent_state.id

    @pytest.mark.asyncio
    async def test_provider_trace_contains_agent_tags(self, client, agent_state_with_tags):
        """Verify provider trace contains the agent's tags."""
        _, trace = self._send_and_trace(client, agent_state_with_tags.id)

        assert trace is not None
        assert trace.agent_tags is not None
        assert set(trace.agent_tags) == {"env:test", "team:telemetry", "version:v1"}

    @pytest.mark.asyncio
    async def test_provider_trace_contains_step_id(self, client, agent_state):
        """Verify provider trace step_id matches the message step_id."""
        step_id, trace = self._send_and_trace(client, agent_state.id)

        assert trace is not None
        assert trace.step_id == step_id

    @pytest.mark.asyncio
    async def test_provider_trace_contains_run_id_for_async_job(self, client, agent_state):
        """Verify provider trace contains run_id when created via async job.

        The job is polled until it completes or fails, with an upper bound so a
        stuck job fails the test instead of hanging the suite. Every step
        created after the job was submitted must carry the job's id as its
        ``run_id``.
        """
        # Cursor message: only messages created by this job are inspected below.
        last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]

        job = client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
        )

        deadline = time.monotonic() + self.ASYNC_JOB_TIMEOUT_SECONDS
        while job.status not in ("completed", "failed"):
            if time.monotonic() >= deadline:
                pytest.fail(f"job {job.id} still '{job.status}' after {self.ASYNC_JOB_TIMEOUT_SECONDS}s")
            time.sleep(self.ASYNC_JOB_POLL_INTERVAL_SECONDS)
            job = client.jobs.retrieve(job.id)

        assert job.status == "completed"

        messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id)
        step_ids = {msg.step_id for msg in messages if msg.step_id is not None}

        assert len(step_ids) > 0
        for step_id in step_ids:
            trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
            assert trace is not None, f"No trace found for step_id={step_id}"
            assert trace.run_id == job.id
|
|
|
|
|
|
class TestProviderTraceMultiStep:
    """Provider-trace checks spanning several steps or several requests."""

    @pytest.mark.asyncio
    async def test_multi_step_conversation_has_traces_for_each_step(self, client, agent_state):
        """Each distinct step behind one multi-tool request must own a provider trace for this agent."""
        prompt = MessageCreate(
            role="user",
            content=[TextContent(text="First, get the weather in NYC. Then roll the dice.")],
        )
        response = client.agents.messages.create(agent_id=agent_state.id, messages=[prompt])

        distinct_steps = {m.step_id for m in response.messages if m.step_id is not None}

        assert len(distinct_steps) >= 1

        for sid in distinct_steps:
            trace = client.telemetry.retrieve_provider_trace(step_id=sid)
            assert trace is not None, f"No trace found for step_id={sid}"
            assert trace.request_json is not None
            assert trace.agent_id == agent_state.id

    @pytest.mark.asyncio
    async def test_consecutive_messages_have_separate_traces(self, client, agent_state):
        """Two back-to-back requests must end on different steps with different traces."""
        last_step_ids = []
        for text in ("Hello", "How are you?"):
            reply = client.agents.messages.create(
                agent_id=agent_state.id,
                messages=[MessageCreate(role="user", content=[TextContent(text=text)])],
            )
            last_step_ids.append(reply.messages[-1].step_id)
        first_step, second_step = last_step_ids

        assert first_step != second_step

        first_trace = client.telemetry.retrieve_provider_trace(step_id=first_step)
        second_trace = client.telemetry.retrieve_provider_trace(step_id=second_step)

        assert first_trace is not None
        assert second_trace is not None
        assert first_trace.id != second_trace.id
|
|
|
|
|
|
class TestProviderTraceRequestResponseContent:
    """Checks on the contents of the recorded request and response payloads."""

    @staticmethod
    def _trace_for_hello(client, agent_id):
        """Send "Hello" to the agent and return the provider trace of the reply's last step."""
        reply = client.agents.messages.create(
            agent_id=agent_id,
            messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
        )
        return client.telemetry.retrieve_provider_trace(step_id=reply.messages[-1].step_id)

    @pytest.mark.asyncio
    async def test_request_json_contains_model(self, client, agent_state):
        """The recorded request payload must include a ``model`` entry."""
        trace = self._trace_for_hello(client, agent_state.id)

        assert trace is not None
        assert trace.request_json is not None
        assert "model" in trace.request_json

    @pytest.mark.asyncio
    async def test_request_json_contains_messages(self, client, agent_state):
        """The recorded request payload must include a ``messages`` list."""
        trace = self._trace_for_hello(client, agent_state.id)

        assert trace is not None
        request_payload = trace.request_json
        assert request_payload is not None
        assert "messages" in request_payload
        assert isinstance(request_payload["messages"], list)

    @pytest.mark.asyncio
    async def test_response_json_contains_usage(self, client, agent_state):
        """The recorded response payload must mention ``usage`` somewhere."""
        trace = self._trace_for_hello(client, agent_state.id)

        assert trace is not None
        response_payload = trace.response_json
        assert response_payload is not None
        # Accept "usage" as a top-level key, or anywhere in the payload's string form.
        assert "usage" in response_payload or "usage" in str(response_payload)
|
|
|
|
|
|
class TestProviderTraceEdgeCases:
    """Tests for edge cases and error scenarios."""

    @pytest.mark.asyncio
    async def test_nonexistent_step_id_returns_none_or_empty(self, client):
        """Verify querying a nonexistent step_id is handled gracefully.

        Either outcome is accepted: the API raises an error (presumably a
        not-found response), or it returns no trace / a trace without a
        request payload.
        """
        fake_step_id = f"step-{uuid.uuid4()}"

        # Only the API call is inside ``try``. AssertionError is a subclass of
        # Exception, so an assertion placed inside this handler would be swallowed
        # and the test could never fail.
        try:
            trace = client.telemetry.retrieve_provider_trace(step_id=fake_step_id)
        except Exception:
            # An error for an unknown step is an acceptable way to report "no trace".
            return

        assert trace is None or trace.request_json is None

    @pytest.mark.asyncio
    async def test_empty_message_still_creates_trace(self, client, agent_state):
        """Verify a trace is created even for a minimal (very short) user message."""
        response = client.agents.messages.create(
            agent_id=agent_state.id,
            messages=[MessageCreate(role="user", content=[TextContent(text="Hi")])],
        )

        step_id = response.messages[-1].step_id
        assert step_id is not None

        trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
        assert trace is not None
|