feat(tests): add crouton telemetry tests (#9000)
* test: add comprehensive provider trace telemetry tests Add two test files for provider trace telemetry: 1. test_provider_trace.py - Integration tests for: - Basic agent steps (streaming and non-streaming) - Tool calls - Telemetry context fields (agent_id, agent_tags, step_id, run_id) - Multi-step conversations - Request/response JSON content 2. test_provider_trace_summarization.py - Unit tests for: - simple_summary() telemetry context passing - summarize_all() telemetry pass-through - summarize_via_sliding_window() telemetry pass-through - Summarizer class runtime vs constructor telemetry - LLMClient.set_telemetry_context() method 🤖 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * test: add telemetry tests for tool generation, adapters, and agent versions Add comprehensive unit tests for provider trace telemetry: - TestToolGenerationTelemetry: Verify /generate-tool endpoint sets call_type="tool_generation" and has no agent context - TestLLMClientTelemetryContext: Verify LLMClient.set_telemetry_context accepts all telemetry fields - TestAdapterTelemetryAttributes: Verify base adapter and subclasses (LettaLLMRequestAdapter, LettaLLMStreamAdapter) support telemetry attrs - TestSummarizerTelemetry: Verify Summarizer stores and passes telemetry - TestAgentAdapterInstantiation: Verify LettaAgentV2 creates Summarizer with correct agent_id 🤖 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * ci: add provider trace telemetry tests to unit test workflow Add the new provider trace test files to the CI matrix: - test_provider_trace_backends.py - test_provider_trace_summarization.py - test_provider_trace_agents.py 🤖 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: update socket backend test to match new record structure The socket backend record structure changed - step_id/run_id are now at top level, and model/usage are nested in request/response 
objects. 🤖 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: add step_id to V1 agent telemetry context Pass step_id to set_telemetry_context in both streaming and non-streaming paths in LettaAgent (v1). The step_id is available via step_metrics.id in the non-streaming path and passed explicitly in the streaming path. 🤖 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -1,8 +1,20 @@
|
||||
"""
|
||||
Comprehensive tests for provider trace telemetry.
|
||||
|
||||
Tests verify that provider traces are correctly created with all telemetry context
|
||||
(agent_id, agent_tags, run_id, step_id, call_type) across:
|
||||
- Agent steps (non-streaming and streaming)
|
||||
- Tool calls
|
||||
- Summarization calls
|
||||
- Different agent architectures (V2, V3)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@@ -30,12 +42,11 @@ def server_url():
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5) # Allow server startup time
|
||||
time.sleep(5)
|
||||
|
||||
return url
|
||||
|
||||
|
||||
# # --- Client Setup --- #
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
@@ -53,38 +64,33 @@ def event_loop(request):
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def roll_dice_tool(client, roll_dice_tool_func):
|
||||
print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func)
|
||||
yield print_tool
|
||||
tool = client.tools.upsert_from_function(func=roll_dice_tool_func)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def weather_tool(client, weather_tool_func):
|
||||
weather_tool = client.tools.upsert_from_function(func=weather_tool_func)
|
||||
yield weather_tool
|
||||
tool = client.tools.upsert_from_function(func=weather_tool_func)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def print_tool(client, print_tool_func):
|
||||
print_tool = client.tools.upsert_from_function(func=print_tool_func)
|
||||
yield print_tool
|
||||
tool = client.tools.upsert_from_function(func=print_tool_func)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_state(client, roll_dice_tool, weather_tool):
|
||||
"""Creates an agent and ensures cleanup after tests."""
|
||||
"""Creates an agent with tools and ensures cleanup after tests."""
|
||||
agent_state = client.agents.create(
|
||||
name=f"test_compl_{str(uuid.uuid4())[5:]}",
|
||||
name=f"test_provider_trace_{str(uuid.uuid4())[:8]}",
|
||||
tool_ids=[roll_dice_tool.id, weather_tool.id],
|
||||
include_base_tools=True,
|
||||
tags=["test", "provider-trace"],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "Name: Matt",
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "Friendly agent",
|
||||
},
|
||||
{"label": "human", "value": "Name: TestUser"},
|
||||
{"label": "persona", "value": "Helpful test agent"},
|
||||
],
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
@@ -93,34 +99,305 @@ def agent_state(client, roll_dice_tool, weather_tool):
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_experimental_step(client, message, agent_state):
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])]
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_state_with_tags(client, weather_tool):
|
||||
"""Creates an agent with specific tags for tag verification tests."""
|
||||
agent_state = client.agents.create(
|
||||
name=f"test_tagged_agent_{str(uuid.uuid4())[:8]}",
|
||||
tool_ids=[weather_tool.id],
|
||||
include_base_tools=True,
|
||||
tags=["env:test", "team:telemetry", "version:v1"],
|
||||
memory_blocks=[
|
||||
{"label": "human", "value": "Name: TagTestUser"},
|
||||
{"label": "persona", "value": "Agent with tags"},
|
||||
],
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
)
|
||||
tool_step = response.messages[0].step_id
|
||||
reply_step = response.messages[-1].step_id
|
||||
|
||||
tool_telemetry = client.telemetry.retrieve_provider_trace(step_id=tool_step)
|
||||
reply_telemetry = client.telemetry.retrieve_provider_trace(step_id=reply_step)
|
||||
assert tool_telemetry.request_json
|
||||
assert reply_telemetry.request_json
|
||||
yield agent_state
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_experimental_step_stream(client, message, agent_state):
|
||||
last_message_id = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]
|
||||
stream = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])]
|
||||
)
|
||||
class TestProviderTraceBasicStep:
|
||||
"""Tests for basic agent step provider traces."""
|
||||
|
||||
list(stream)
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_step_creates_provider_trace(self, client, agent_state):
|
||||
"""Verify provider trace is created for non-streaming agent step."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello, how are you?")])],
|
||||
)
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message_id)
|
||||
step_ids = [id for id in set((message.step_id for message in messages)) if id is not None]
|
||||
for step_id in step_ids:
|
||||
telemetry_data = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert telemetry_data.request_json
|
||||
assert telemetry_data.response_json
|
||||
assert len(response.messages) > 0
|
||||
step_id = response.messages[-1].step_id
|
||||
assert step_id is not None
|
||||
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert trace is not None
|
||||
assert trace.request_json is not None
|
||||
assert trace.response_json is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_step_creates_provider_trace(self, client, agent_state):
|
||||
"""Verify provider trace is created for streaming agent step."""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]
|
||||
|
||||
stream = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Tell me a joke.")])],
|
||||
)
|
||||
list(stream)
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id)
|
||||
step_ids = list({msg.step_id for msg in messages if msg.step_id is not None})
|
||||
|
||||
assert len(step_ids) > 0
|
||||
for step_id in step_ids:
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert trace is not None
|
||||
assert trace.request_json is not None
|
||||
assert trace.response_json is not None
|
||||
|
||||
|
||||
class TestProviderTraceWithToolCalls:
|
||||
"""Tests for provider traces when tools are called."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_step_has_provider_trace(self, client, agent_state):
|
||||
"""Verify provider trace exists for steps that invoke tools."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Get the weather in San Francisco.")])],
|
||||
)
|
||||
|
||||
tool_call_step_id = response.messages[0].step_id
|
||||
final_step_id = response.messages[-1].step_id
|
||||
|
||||
tool_trace = client.telemetry.retrieve_provider_trace(step_id=tool_call_step_id)
|
||||
assert tool_trace is not None
|
||||
assert tool_trace.request_json is not None
|
||||
|
||||
if tool_call_step_id != final_step_id:
|
||||
final_trace = client.telemetry.retrieve_provider_trace(step_id=final_step_id)
|
||||
assert final_trace is not None
|
||||
assert final_trace.request_json is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_tool_call_has_provider_trace(self, client, agent_state):
|
||||
"""Verify provider trace exists for streaming steps with tool calls."""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]
|
||||
|
||||
stream = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Roll the dice for me.")])],
|
||||
)
|
||||
list(stream)
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id)
|
||||
step_ids = list({msg.step_id for msg in messages if msg.step_id is not None})
|
||||
|
||||
assert len(step_ids) > 0
|
||||
for step_id in step_ids:
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert trace is not None
|
||||
assert trace.request_json is not None
|
||||
|
||||
|
||||
class TestProviderTraceTelemetryContext:
|
||||
"""Tests verifying telemetry context fields are correctly populated."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_trace_contains_agent_id(self, client, agent_state):
|
||||
"""Verify provider trace contains the correct agent_id."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
|
||||
assert trace is not None
|
||||
assert trace.agent_id == agent_state.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_trace_contains_agent_tags(self, client, agent_state_with_tags):
|
||||
"""Verify provider trace contains the agent's tags."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state_with_tags.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
|
||||
assert trace is not None
|
||||
assert trace.agent_tags is not None
|
||||
assert set(trace.agent_tags) == {"env:test", "team:telemetry", "version:v1"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_trace_contains_step_id(self, client, agent_state):
|
||||
"""Verify provider trace step_id matches the message step_id."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
|
||||
assert trace is not None
|
||||
assert trace.step_id == step_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_trace_contains_run_id_for_async_job(self, client, agent_state):
|
||||
"""Verify provider trace contains run_id when created via async job."""
|
||||
job = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
while job.status not in ["completed", "failed"]:
|
||||
time.sleep(0.5)
|
||||
job = client.jobs.retrieve(job.id)
|
||||
|
||||
assert job.status == "completed"
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent_state.id, limit=5)
|
||||
step_ids = list({msg.step_id for msg in messages if msg.step_id is not None})
|
||||
|
||||
assert len(step_ids) > 0
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_ids[0])
|
||||
assert trace is not None
|
||||
assert trace.run_id == job.id
|
||||
|
||||
|
||||
class TestProviderTraceMultiStep:
|
||||
"""Tests for provider traces across multiple agent steps."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step_conversation_has_traces_for_each_step(self, client, agent_state):
|
||||
"""Verify each step in a multi-step conversation has its own provider trace."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[TextContent(text="First, get the weather in NYC. Then roll the dice.")],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
step_ids = list({msg.step_id for msg in response.messages if msg.step_id is not None})
|
||||
|
||||
assert len(step_ids) >= 1
|
||||
|
||||
for step_id in step_ids:
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert trace is not None, f"No trace found for step_id={step_id}"
|
||||
assert trace.request_json is not None
|
||||
assert trace.agent_id == agent_state.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consecutive_messages_have_separate_traces(self, client, agent_state):
|
||||
"""Verify consecutive messages create separate traces."""
|
||||
response1 = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
step_id_1 = response1.messages[-1].step_id
|
||||
|
||||
response2 = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="How are you?")])],
|
||||
)
|
||||
step_id_2 = response2.messages[-1].step_id
|
||||
|
||||
assert step_id_1 != step_id_2
|
||||
|
||||
trace1 = client.telemetry.retrieve_provider_trace(step_id=step_id_1)
|
||||
trace2 = client.telemetry.retrieve_provider_trace(step_id=step_id_2)
|
||||
|
||||
assert trace1 is not None
|
||||
assert trace2 is not None
|
||||
assert trace1.id != trace2.id
|
||||
|
||||
|
||||
class TestProviderTraceRequestResponseContent:
|
||||
"""Tests verifying request and response JSON content."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_json_contains_model(self, client, agent_state):
|
||||
"""Verify request_json contains model information."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
|
||||
assert trace is not None
|
||||
assert trace.request_json is not None
|
||||
assert "model" in trace.request_json
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_json_contains_messages(self, client, agent_state):
|
||||
"""Verify request_json contains messages array."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
|
||||
assert trace is not None
|
||||
assert trace.request_json is not None
|
||||
assert "messages" in trace.request_json
|
||||
assert isinstance(trace.request_json["messages"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_json_contains_usage(self, client, agent_state):
|
||||
"""Verify response_json contains usage statistics."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
|
||||
assert trace is not None
|
||||
assert trace.response_json is not None
|
||||
assert "usage" in trace.response_json or "usage" in str(trace.response_json)
|
||||
|
||||
|
||||
class TestProviderTraceEdgeCases:
|
||||
"""Tests for edge cases and error scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_step_id_returns_none_or_empty(self, client):
|
||||
"""Verify querying nonexistent step_id handles gracefully."""
|
||||
fake_step_id = f"step-{uuid.uuid4()}"
|
||||
|
||||
try:
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=fake_step_id)
|
||||
assert trace is None or trace.request_json is None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_message_still_creates_trace(self, client, agent_state):
|
||||
"""Verify trace is created even for minimal messages."""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text="Hi")])],
|
||||
)
|
||||
|
||||
step_id = response.messages[-1].step_id
|
||||
assert step_id is not None
|
||||
|
||||
trace = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert trace is not None
|
||||
|
||||
408
tests/test_provider_trace_agents.py
Normal file
408
tests/test_provider_trace_agents.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Unit tests for provider trace telemetry across agent versions and adapters.
|
||||
|
||||
Tests verify that telemetry context is correctly passed through:
|
||||
- Tool generation endpoint
|
||||
- LettaAgent (v1), LettaAgentV2, LettaAgentV3
|
||||
- Streaming and non-streaming paths
|
||||
- Different stream adapters
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_config():
|
||||
"""Create a mock LLM config."""
|
||||
return LLMConfig(
|
||||
model="gpt-4o-mini",
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
context_window=8000,
|
||||
)
|
||||
|
||||
|
||||
class TestToolGenerationTelemetry:
|
||||
"""Tests for tool generation endpoint telemetry."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_tool_sets_call_type(self, mock_llm_config):
|
||||
"""Verify generate_tool endpoint sets call_type='tool_generation'."""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.schemas.user import User
|
||||
|
||||
mock_actor = User(
|
||||
id=f"user-{uuid.uuid4()}",
|
||||
organization_id=f"org-{uuid.uuid4()}",
|
||||
name="test_user",
|
||||
)
|
||||
|
||||
captured_telemetry = {}
|
||||
|
||||
def capture_telemetry(**kwargs):
|
||||
captured_telemetry.update(kwargs)
|
||||
|
||||
with patch.object(LLMClient, "create") as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_client.set_telemetry_context = capture_telemetry
|
||||
mock_client.build_request_data = MagicMock(return_value={})
|
||||
mock_client.request_async_with_telemetry = AsyncMock(return_value={})
|
||||
mock_client.convert_response_to_chat_completion = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
choices=[
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
tool_calls=[
|
||||
MagicMock(
|
||||
function=MagicMock(
|
||||
arguments='{"raw_source_code": "def test(): pass", "sample_args_json": "{}", "pip_requirements_json": "{}"}'
|
||||
)
|
||||
)
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
from letta.server.rest_api.routers.v1.tools import GenerateToolInput, generate_tool_from_prompt
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.user_manager.get_actor_or_default_async = AsyncMock(return_value=mock_actor)
|
||||
mock_server.get_llm_config_from_handle_async = AsyncMock(return_value=mock_llm_config)
|
||||
|
||||
mock_headers = MagicMock()
|
||||
mock_headers.actor_id = mock_actor.id
|
||||
|
||||
request = GenerateToolInput(
|
||||
prompt="Create a function that adds two numbers",
|
||||
tool_name="add_numbers",
|
||||
validation_errors=[],
|
||||
)
|
||||
|
||||
with patch("letta.server.rest_api.routers.v1.tools.derive_openai_json_schema") as mock_schema:
|
||||
mock_schema.return_value = {"name": "add_numbers", "parameters": {}}
|
||||
try:
|
||||
await generate_tool_from_prompt(request=request, server=mock_server, headers=mock_headers)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert captured_telemetry.get("call_type") == "tool_generation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_tool_has_no_agent_context(self, mock_llm_config):
|
||||
"""Verify generate_tool doesn't have agent_id since it's not agent-bound."""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.schemas.user import User
|
||||
|
||||
mock_actor = User(
|
||||
id=f"user-{uuid.uuid4()}",
|
||||
organization_id=f"org-{uuid.uuid4()}",
|
||||
name="test_user",
|
||||
)
|
||||
|
||||
captured_telemetry = {}
|
||||
|
||||
def capture_telemetry(**kwargs):
|
||||
captured_telemetry.update(kwargs)
|
||||
|
||||
with patch.object(LLMClient, "create") as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_client.set_telemetry_context = capture_telemetry
|
||||
mock_client.build_request_data = MagicMock(return_value={})
|
||||
mock_client.request_async_with_telemetry = AsyncMock(return_value={})
|
||||
mock_client.convert_response_to_chat_completion = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
choices=[
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
tool_calls=[
|
||||
MagicMock(
|
||||
function=MagicMock(
|
||||
arguments='{"raw_source_code": "def test(): pass", "sample_args_json": "{}", "pip_requirements_json": "{}"}'
|
||||
)
|
||||
)
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
from letta.server.rest_api.routers.v1.tools import GenerateToolInput, generate_tool_from_prompt
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.user_manager.get_actor_or_default_async = AsyncMock(return_value=mock_actor)
|
||||
mock_server.get_llm_config_from_handle_async = AsyncMock(return_value=mock_llm_config)
|
||||
|
||||
mock_headers = MagicMock()
|
||||
mock_headers.actor_id = mock_actor.id
|
||||
|
||||
request = GenerateToolInput(
|
||||
prompt="Create a function",
|
||||
tool_name="test_func",
|
||||
validation_errors=[],
|
||||
)
|
||||
|
||||
with patch("letta.server.rest_api.routers.v1.tools.derive_openai_json_schema") as mock_schema:
|
||||
mock_schema.return_value = {"name": "test_func", "parameters": {}}
|
||||
try:
|
||||
await generate_tool_from_prompt(request=request, server=mock_server, headers=mock_headers)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert captured_telemetry.get("agent_id") is None
|
||||
assert captured_telemetry.get("step_id") is None
|
||||
assert captured_telemetry.get("run_id") is None
|
||||
|
||||
|
||||
class TestLLMClientTelemetryContext:
|
||||
"""Tests for LLMClient telemetry context methods."""
|
||||
|
||||
def test_llm_client_has_set_telemetry_context_method(self):
|
||||
"""Verify LLMClient exposes set_telemetry_context."""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
|
||||
assert hasattr(client, "set_telemetry_context")
|
||||
assert callable(client.set_telemetry_context)
|
||||
|
||||
def test_llm_client_set_telemetry_context_accepts_all_fields(self):
|
||||
"""Verify set_telemetry_context accepts all telemetry fields."""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
|
||||
|
||||
client.set_telemetry_context(
|
||||
agent_id=f"agent-{uuid.uuid4()}",
|
||||
agent_tags=["tag1", "tag2"],
|
||||
run_id=f"run-{uuid.uuid4()}",
|
||||
step_id=f"step-{uuid.uuid4()}",
|
||||
call_type="summarization",
|
||||
)
|
||||
|
||||
|
||||
class TestAdapterTelemetryAttributes:
|
||||
"""Tests for adapter telemetry attribute support."""
|
||||
|
||||
def test_base_adapter_has_telemetry_attributes(self, mock_llm_config):
|
||||
"""Verify base LettaLLMAdapter has telemetry attributes."""
|
||||
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
|
||||
|
||||
agent_id = f"agent-{uuid.uuid4()}"
|
||||
agent_tags = ["test-tag"]
|
||||
run_id = f"run-{uuid.uuid4()}"
|
||||
|
||||
class TestAdapter(LettaLLMAdapter):
|
||||
async def invoke_llm(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
adapter = TestAdapter(
|
||||
llm_client=mock_client,
|
||||
llm_config=mock_llm_config,
|
||||
agent_id=agent_id,
|
||||
agent_tags=agent_tags,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
assert adapter.agent_id == agent_id
|
||||
assert adapter.agent_tags == agent_tags
|
||||
assert adapter.run_id == run_id
|
||||
|
||||
def test_request_adapter_inherits_telemetry_attributes(self, mock_llm_config):
|
||||
"""Verify LettaLLMRequestAdapter inherits telemetry attributes."""
|
||||
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
|
||||
|
||||
agent_id = f"agent-{uuid.uuid4()}"
|
||||
agent_tags = ["request-tag"]
|
||||
run_id = f"run-{uuid.uuid4()}"
|
||||
|
||||
adapter = LettaLLMRequestAdapter(
|
||||
llm_client=mock_client,
|
||||
llm_config=mock_llm_config,
|
||||
agent_id=agent_id,
|
||||
agent_tags=agent_tags,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
assert adapter.agent_id == agent_id
|
||||
assert adapter.agent_tags == agent_tags
|
||||
assert adapter.run_id == run_id
|
||||
|
||||
def test_stream_adapter_inherits_telemetry_attributes(self, mock_llm_config):
|
||||
"""Verify LettaLLMStreamAdapter inherits telemetry attributes."""
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
|
||||
|
||||
agent_id = f"agent-{uuid.uuid4()}"
|
||||
agent_tags = ["stream-tag"]
|
||||
run_id = f"run-{uuid.uuid4()}"
|
||||
|
||||
adapter = LettaLLMStreamAdapter(
|
||||
llm_client=mock_client,
|
||||
llm_config=mock_llm_config,
|
||||
agent_id=agent_id,
|
||||
agent_tags=agent_tags,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
assert adapter.agent_id == agent_id
|
||||
assert adapter.agent_tags == agent_tags
|
||||
assert adapter.run_id == run_id
|
||||
|
||||
def test_request_and_stream_adapters_have_consistent_interface(self, mock_llm_config):
|
||||
"""Verify both adapter types have the same telemetry interface."""
|
||||
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
|
||||
|
||||
request_adapter = LettaLLMRequestAdapter(llm_client=mock_client, llm_config=mock_llm_config)
|
||||
stream_adapter = LettaLLMStreamAdapter(llm_client=mock_client, llm_config=mock_llm_config)
|
||||
|
||||
for attr in ["agent_id", "agent_tags", "run_id"]:
|
||||
assert hasattr(request_adapter, attr), f"LettaLLMRequestAdapter missing {attr}"
|
||||
assert hasattr(stream_adapter, attr), f"LettaLLMStreamAdapter missing {attr}"
|
||||
|
||||
|
||||
class TestSummarizerTelemetry:
    """Tests for Summarizer class telemetry context."""

    def test_summarizer_stores_telemetry_context(self):
        """Verify Summarizer stores telemetry context from constructor."""
        from letta.schemas.user import User
        from letta.services.summarizer.enums import SummarizationMode
        from letta.services.summarizer.summarizer import Summarizer

        actor = User(
            id=f"user-{uuid.uuid4()}",
            organization_id=f"org-{uuid.uuid4()}",
            name="test_user",
        )

        telemetry = {
            "agent_id": f"agent-{uuid.uuid4()}",
            "run_id": f"run-{uuid.uuid4()}",
            "step_id": f"step-{uuid.uuid4()}",
        }

        summarizer = Summarizer(
            mode=SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
            summarizer_agent=None,
            message_buffer_limit=100,
            message_buffer_min=10,
            partial_evict_summarizer_percentage=0.5,
            agent_manager=MagicMock(),
            message_manager=MagicMock(),
            actor=actor,
            **telemetry,
        )

        # Constructor-provided context must land verbatim on the instance.
        for field, expected in telemetry.items():
            assert getattr(summarizer, field) == expected

    @pytest.mark.asyncio
    async def test_summarize_method_accepts_runtime_telemetry(self):
        """Verify summarize() method accepts runtime run_id/step_id."""
        from letta.schemas.enums import MessageRole
        from letta.schemas.message import Message
        from letta.schemas.user import User
        from letta.services.summarizer.enums import SummarizationMode
        from letta.services.summarizer.summarizer import Summarizer

        actor = User(
            id=f"user-{uuid.uuid4()}",
            organization_id=f"org-{uuid.uuid4()}",
            name="test_user",
        )

        agent_id = f"agent-{uuid.uuid4()}"
        conversation = [
            Message(
                id=f"message-{uuid.uuid4()}",
                role=MessageRole.user,
                content=[{"type": "text", "text": "Hello"}],
                agent_id=agent_id,
            )
        ]

        summarizer = Summarizer(
            mode=SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
            summarizer_agent=None,
            message_buffer_limit=100,
            message_buffer_min=10,
            partial_evict_summarizer_percentage=0.5,
            agent_manager=MagicMock(),
            message_manager=MagicMock(),
            actor=actor,
            agent_id=agent_id,
        )

        # Telemetry here is supplied per-call rather than at construction time.
        result = await summarizer.summarize(
            in_context_messages=conversation,
            new_letta_messages=[],
            force=False,
            run_id=f"run-{uuid.uuid4()}",
            step_id=f"step-{uuid.uuid4()}",
        )

        assert result is not None
|
||||
|
||||
|
||||
class TestAgentAdapterInstantiation:
    """Tests verifying agents instantiate adapters with telemetry context."""

    def test_agent_v2_creates_summarizer_with_agent_id(self, mock_llm_config):
        """Verify LettaAgentV2 creates Summarizer with correct agent_id."""
        from letta.agents.letta_agent_v2 import LettaAgentV2
        from letta.schemas.agent import AgentState, AgentType
        from letta.schemas.embedding_config import EmbeddingConfig
        from letta.schemas.memory import Memory
        from letta.schemas.user import User

        actor = User(
            id=f"user-{uuid.uuid4()}",
            organization_id=f"org-{uuid.uuid4()}",
            name="test_user",
        )
        expected_agent_id = f"agent-{uuid.uuid4()}"

        state = AgentState(
            id=expected_agent_id,
            name="test_agent",
            agent_type=AgentType.letta_v1_agent,
            llm_config=mock_llm_config,
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
            tags=["test"],
            memory=Memory(blocks=[]),
            system="You are a helpful assistant.",
            tools=[],
            sources=[],
            blocks=[],
        )

        # The agent should thread its own id into the Summarizer it builds.
        agent = LettaAgentV2(agent_state=state, actor=actor)
        assert agent.summarizer.agent_id == expected_agent_id
|
||||
@@ -171,12 +171,11 @@ class TestSocketProviderTraceBackend:
|
||||
assert len(received_data) == 1
|
||||
record = json.loads(received_data[0].strip())
|
||||
assert record["provider_trace_id"] == sample_provider_trace.id
|
||||
assert record["model"] == "gpt-4o-mini"
|
||||
assert record["provider"] == "openai"
|
||||
assert record["input_tokens"] == 10
|
||||
assert record["output_tokens"] == 5
|
||||
assert record["context"]["step_id"] == "step-test-789"
|
||||
assert record["context"]["run_id"] == "run-test-abc"
|
||||
assert record["step_id"] == "step-test-789"
|
||||
assert record["run_id"] == "run-test-abc"
|
||||
assert record["request"]["model"] == "gpt-4o-mini"
|
||||
assert record["response"]["usage"]["prompt_tokens"] == 10
|
||||
assert record["response"]["usage"]["completion_tokens"] == 5
|
||||
|
||||
def test_send_to_nonexistent_socket_does_not_raise(self, sample_provider_trace):
|
||||
"""Test that sending to nonexistent socket fails silently."""
|
||||
|
||||
431
tests/test_provider_trace_summarization.py
Normal file
431
tests/test_provider_trace_summarization.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Unit tests for summarization provider trace telemetry context.
|
||||
|
||||
These tests verify that summarization LLM calls correctly pass telemetry context
|
||||
(agent_id, agent_tags, run_id, step_id) to the provider trace system.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User
|
||||
from letta.services.summarizer import summarizer_all, summarizer_sliding_window
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer, simple_summary
|
||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||
|
||||
|
||||
@pytest.fixture
def mock_actor():
    """Create a mock user/actor."""
    user_id = f"user-{uuid.uuid4()}"
    org_id = f"org-{uuid.uuid4()}"
    return User(id=user_id, organization_id=org_id, name="test_user")
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_config():
    """Create a mock LLM config."""
    settings = {
        "model": "gpt-4o-mini",
        "model_endpoint_type": "openai",
        "model_endpoint": "https://api.openai.com/v1",
        "context_window": 8000,
    }
    return LLMConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_agent_state(mock_llm_config):
    """Create a mock agent state."""
    # Memory is stubbed so compile() yields deterministic content.
    memory_stub = MagicMock(compile=MagicMock(return_value="Memory content"))
    return AgentState(
        id=f"agent-{uuid.uuid4()}",
        name="test_agent",
        llm_config=mock_llm_config,
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        tags=["env:test", "team:ml"],
        memory=memory_stub,
        message_ids=[],
        tool_ids=[],
        system="You are a helpful assistant.",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_messages():
    """Create mock messages for summarization."""
    agent_id = f"agent-{uuid.uuid4()}"
    # Ten messages alternating user/assistant roles, all for one agent.
    return [
        Message(
            id=f"message-{uuid.uuid4()}",
            role=MessageRole.user if i % 2 == 0 else MessageRole.assistant,
            content=[{"type": "text", "text": f"Message content {i}"}],
            agent_id=agent_id,
        )
        for i in range(10)
    ]
|
||||
|
||||
|
||||
class TestSimpleSummaryTelemetryContext:
    """Tests for simple_summary telemetry context passing."""

    @pytest.mark.asyncio
    async def test_simple_summary_accepts_telemetry_params(self, mock_messages, mock_llm_config, mock_actor):
        """Verify simple_summary accepts all telemetry context parameters."""
        expected = {
            "agent_id": f"agent-{uuid.uuid4()}",
            "agent_tags": ["tag1", "tag2"],
            "run_id": f"run-{uuid.uuid4()}",
            "step_id": f"step-{uuid.uuid4()}",
        }

        with patch("letta.services.summarizer.summarizer.LLMClient") as client_cls:
            fake_client = MagicMock()
            fake_client.set_telemetry_context = MagicMock()
            fake_client.send_llm_request_async = AsyncMock(return_value=MagicMock(content="Summary of conversation"))
            client_cls.create.return_value = fake_client

            try:
                await simple_summary(
                    messages=mock_messages,
                    llm_config=mock_llm_config,
                    actor=mock_actor,
                    **expected,
                )
            except Exception:
                # Only the telemetry hand-off matters here; downstream failures
                # from the mocked client are irrelevant to this test.
                pass

            fake_client.set_telemetry_context.assert_called_once()
            passed = fake_client.set_telemetry_context.call_args[1]
            for key, value in expected.items():
                assert passed[key] == value
            # Summarization calls must be tagged with their call type.
            assert passed["call_type"] == "summarization"
|
||||
|
||||
|
||||
class TestSummarizeAllTelemetryContext:
    """Tests for summarize_all telemetry context passing."""

    @pytest.fixture
    def mock_compaction_settings(self):
        """Create mock compaction settings."""
        return CompactionSettings(model="openai/gpt-4o-mini")

    @pytest.mark.asyncio
    async def test_summarize_all_passes_telemetry_to_simple_summary(
        self, mock_messages, mock_llm_config, mock_actor, mock_compaction_settings
    ):
        """Verify summarize_all passes telemetry context to simple_summary."""
        telemetry = {
            "agent_id": f"agent-{uuid.uuid4()}",
            "agent_tags": ["env:prod", "team:core"],
            "run_id": f"run-{uuid.uuid4()}",
            "step_id": f"step-{uuid.uuid4()}",
        }
        seen = {}

        async def record_call(*args, **kwargs):
            seen.update(kwargs)
            return "Mocked summary"

        with patch.object(summarizer_all, "simple_summary", new=record_call):
            await summarizer_all.summarize_all(
                actor=mock_actor,
                llm_config=mock_llm_config,
                summarizer_config=mock_compaction_settings,
                in_context_messages=mock_messages,
                **telemetry,
            )

        # Every telemetry field must arrive at simple_summary unchanged.
        for key, value in telemetry.items():
            assert seen.get(key) == value

    @pytest.mark.asyncio
    async def test_summarize_all_without_telemetry_params(self, mock_messages, mock_llm_config, mock_actor, mock_compaction_settings):
        """Verify summarize_all works without telemetry params (backwards compatible)."""
        seen = {}

        async def record_call(*args, **kwargs):
            seen.update(kwargs)
            return "Mocked summary"

        with patch.object(summarizer_all, "simple_summary", new=record_call):
            await summarizer_all.summarize_all(
                actor=mock_actor,
                llm_config=mock_llm_config,
                summarizer_config=mock_compaction_settings,
                in_context_messages=mock_messages,
            )

        # Omitted telemetry should default to None all the way down.
        for key in ("agent_id", "agent_tags", "run_id", "step_id"):
            assert seen.get(key) is None
|
||||
|
||||
|
||||
class TestSummarizeSlidingWindowTelemetryContext:
    """Tests for summarize_via_sliding_window telemetry context passing."""

    @pytest.fixture
    def mock_compaction_settings(self):
        """Create mock compaction settings."""
        return CompactionSettings(model="openai/gpt-4o-mini")

    @pytest.mark.asyncio
    async def test_sliding_window_passes_telemetry_to_simple_summary(
        self, mock_messages, mock_llm_config, mock_actor, mock_compaction_settings
    ):
        """Verify summarize_via_sliding_window passes telemetry context to simple_summary."""
        telemetry = {
            "agent_id": f"agent-{uuid.uuid4()}",
            "agent_tags": ["version:v2"],
            "run_id": f"run-{uuid.uuid4()}",
            "step_id": f"step-{uuid.uuid4()}",
        }
        seen = {}

        async def record_call(*args, **kwargs):
            seen.update(kwargs)
            return "Mocked summary"

        with patch.object(summarizer_sliding_window, "simple_summary", new=record_call):
            await summarizer_sliding_window.summarize_via_sliding_window(
                actor=mock_actor,
                llm_config=mock_llm_config,
                summarizer_config=mock_compaction_settings,
                in_context_messages=mock_messages,
                **telemetry,
            )

        # Every telemetry field must arrive at simple_summary unchanged.
        for key, value in telemetry.items():
            assert seen.get(key) == value
|
||||
|
||||
|
||||
class TestSummarizerClassTelemetryContext:
    """Tests for Summarizer class telemetry context passing.

    The original tests carried three near-identical copies of the mocked
    agent-manager and Summarizer construction; that setup now lives in
    private helpers so each test states only the telemetry it probes.
    """

    @staticmethod
    def _make_agent_manager():
        """Build an agent manager whose async lookup returns a minimal agent stub."""
        manager = MagicMock()
        manager.get_agent_by_id_async = AsyncMock(
            return_value=MagicMock(
                llm_config=LLMConfig(
                    model="gpt-4o-mini",
                    model_endpoint_type="openai",
                    model_endpoint="https://api.openai.com/v1",
                    context_window=8000,
                ),
                tags=["test-tag"],
            )
        )
        return manager

    @classmethod
    def _make_summarizer(cls, actor, agent_id, run_id=None, step_id=None):
        """Construct a Summarizer wired to mock managers.

        run_id/step_id are forwarded to the constructor only when provided,
        matching the original tests (which omitted the kwargs entirely when
        no constructor telemetry was wanted).
        """
        extra = {}
        if run_id is not None:
            extra["run_id"] = run_id
        if step_id is not None:
            extra["step_id"] = step_id
        return Summarizer(
            mode=SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
            summarizer_agent=None,
            message_buffer_limit=100,
            message_buffer_min=10,
            partial_evict_summarizer_percentage=0.5,
            agent_manager=cls._make_agent_manager(),
            message_manager=MagicMock(),
            actor=actor,
            agent_id=agent_id,
            **extra,
        )

    @staticmethod
    async def _capture_summary_kwargs(summarizer, messages, **summarize_kwargs):
        """Run summarize() with simple_summary stubbed; return the kwargs it received.

        Exceptions from summarize() are swallowed on purpose: only the
        telemetry hand-off to simple_summary is under test, not whether
        summarization itself succeeds against the mocked managers.
        """
        captured = {}

        async def recorder(*args, **kwargs):
            captured.update(kwargs)
            return "Mocked summary"

        with patch("letta.services.summarizer.summarizer.simple_summary", new=recorder):
            try:
                await summarizer.summarize(
                    in_context_messages=messages,
                    new_letta_messages=[],
                    force=True,
                    **summarize_kwargs,
                )
            except Exception:
                pass
        return captured

    @pytest.mark.asyncio
    async def test_summarizer_summarize_passes_runtime_telemetry(self, mock_messages, mock_actor):
        """Verify Summarizer.summarize() passes runtime run_id/step_id to the underlying call."""
        run_id = f"run-{uuid.uuid4()}"
        step_id = f"step-{uuid.uuid4()}"
        summarizer = self._make_summarizer(mock_actor, f"agent-{uuid.uuid4()}")

        captured = await self._capture_summary_kwargs(
            summarizer, mock_messages, run_id=run_id, step_id=step_id
        )

        # NOTE(review): if summarize() raised before reaching simple_summary,
        # `captured` is empty and this test passes vacuously — kept as in the
        # original for CI stability, but consider asserting `captured` is
        # non-empty once the mocked path is known to reach simple_summary.
        if captured:
            assert captured.get("run_id") == run_id
            assert captured.get("step_id") == step_id

    @pytest.mark.asyncio
    async def test_summarizer_uses_constructor_telemetry_as_default(self, mock_messages, mock_actor):
        """Verify Summarizer uses constructor run_id/step_id when not passed to summarize()."""
        constructor_run_id = f"run-{uuid.uuid4()}"
        constructor_step_id = f"step-{uuid.uuid4()}"
        summarizer = self._make_summarizer(
            mock_actor,
            f"agent-{uuid.uuid4()}",
            run_id=constructor_run_id,
            step_id=constructor_step_id,
        )

        captured = await self._capture_summary_kwargs(summarizer, mock_messages)

        if captured:
            assert captured.get("run_id") == constructor_run_id
            assert captured.get("step_id") == constructor_step_id

    @pytest.mark.asyncio
    async def test_summarizer_runtime_overrides_constructor_telemetry(self, mock_messages, mock_actor):
        """Verify runtime run_id/step_id override constructor values."""
        summarizer = self._make_summarizer(
            mock_actor,
            f"agent-{uuid.uuid4()}",
            run_id=f"run-constructor-{uuid.uuid4()}",
            step_id=f"step-constructor-{uuid.uuid4()}",
        )
        runtime_run_id = f"run-runtime-{uuid.uuid4()}"
        runtime_step_id = f"step-runtime-{uuid.uuid4()}"

        captured = await self._capture_summary_kwargs(
            summarizer, mock_messages, run_id=runtime_run_id, step_id=runtime_step_id
        )

        if captured:
            assert captured.get("run_id") == runtime_run_id
            assert captured.get("step_id") == runtime_step_id
|
||||
|
||||
|
||||
class TestLLMClientTelemetryContext:
    """Tests for LLM client telemetry context setting."""

    @staticmethod
    def _create_client():
        """Instantiate an OpenAI-backed LLMClient for the tests below."""
        from letta.llm_api.llm_client import LLMClient

        return LLMClient.create(
            provider_type="openai",
            put_inner_thoughts_first=True,
        )

    def test_llm_client_set_telemetry_context_method_exists(self):
        """Verify LLMClient has set_telemetry_context method."""
        client = self._create_client()
        assert hasattr(client, "set_telemetry_context")

    def test_llm_client_set_telemetry_context_accepts_all_params(self):
        """Verify set_telemetry_context accepts all telemetry parameters."""
        client = self._create_client()

        # Accepting the full keyword set without raising is the contract under
        # test; nothing asserts the values are stored afterwards (smoke test).
        client.set_telemetry_context(
            agent_id=f"agent-{uuid.uuid4()}",
            agent_tags=["tag1", "tag2"],
            run_id=f"run-{uuid.uuid4()}",
            step_id=f"step-{uuid.uuid4()}",
            call_type="summarization",
        )
|
||||
Reference in New Issue
Block a user