From baa72e43c35e852c793dcb6cb49cdff90870862e Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 12 Sep 2025 19:12:31 -0700 Subject: [PATCH] test: update provider trace tests to use sdk client (#2866) --- tests/test_provider_trace.py | 109 +++++------------------------------ 1 file changed, 13 insertions(+), 96 deletions(-) diff --git a/tests/test_provider_trace.py b/tests/test_provider_trace.py index 871c9ba1..256d95ad 100644 --- a/tests/test_provider_trace.py +++ b/tests/test_provider_trace.py @@ -1,5 +1,4 @@ import asyncio -import json import os import threading import time @@ -9,19 +8,10 @@ import pytest from dotenv import load_dotenv from letta_client import Letta -from letta.agents.letta_agent import LettaAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate -from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode -from letta.services.agent_manager import AgentManager -from letta.services.block_manager import BlockManager -from letta.services.job_manager import JobManager -from letta.services.message_manager import MessageManager -from letta.services.passage_manager import PassageManager -from letta.services.step_manager import StepManager -from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager def _run_server(): @@ -105,105 +95,32 @@ def agent_state(client, roll_dice_tool, weather_tool): @pytest.mark.asyncio @pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) -async def test_provider_trace_experimental_step(message, agent_state, default_user): - experimental_agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - job_manager=JobManager(), - passage_manager=PassageManager(), - step_manager=StepManager(), - telemetry_manager=TelemetryManager(), - actor=default_user, +async def test_provider_trace_experimental_step(client, message, agent_state): + response = client.agents.messages.create( + agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])] ) - - response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])]) tool_step = response.messages[0].step_id reply_step = response.messages[-1].step_id - tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) - reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + tool_telemetry = client.telemetry.retrieve_provider_trace(step_id=tool_step) + reply_telemetry = client.telemetry.retrieve_provider_trace(step_id=reply_step) assert tool_telemetry.request_json assert reply_telemetry.request_json @pytest.mark.asyncio @pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) -async def test_provider_trace_experimental_step_stream(message, agent_state, default_user): - experimental_agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - job_manager=JobManager(), - passage_manager=PassageManager(), - step_manager=StepManager(), - telemetry_manager=TelemetryManager(), - actor=default_user, - ) - stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])]) - - result = StreamingResponseWithStatusCode( - stream, - media_type="text/event-stream", +async def test_provider_trace_experimental_step_stream(client, message, agent_state): + last_message_id = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0] + stream = client.agents.messages.create_stream( + agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])] ) - message_id = None + list(stream) - async def test_send(message) -> None: - nonlocal message_id - if "body" in message and not message_id: - body = message["body"].decode("utf-8").split("data:") - message_id = json.loads(body[1])["id"] - - await result.stream_response(send=test_send) - - messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user) - step_ids = set((message.step_id for message in messages)) + messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message_id) + step_ids = [id for id in set((message.step_id for message in messages)) if id is not None] for step_id in step_ids: - telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user) + telemetry_data = client.telemetry.retrieve_provider_trace(step_id=step_id) assert telemetry_data.request_json assert telemetry_data.response_json - - -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) -async def test_provider_trace_step(client, agent_state, default_user, message): - client.agents.messages.create(agent_id=agent_state.id, messages=[]) - response = client.agents.messages.create( - agent_id=agent_state.id, - messages=[MessageCreate(role="user", content=[TextContent(text=message)])], - ) - tool_step = response.messages[0].step_id - reply_step = response.messages[-1].step_id - - tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) - reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) - assert tool_telemetry.request_json - assert reply_telemetry.request_json - - -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) -async def test_noop_provider_trace(message, agent_state, default_user): - experimental_agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - job_manager=JobManager(), - passage_manager=PassageManager(), - step_manager=StepManager(), - telemetry_manager=NoopTelemetryManager(), - actor=default_user, - ) - - response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])]) - tool_step = response.messages[0].step_id - reply_step = response.messages[-1].step_id - - tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) - reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) - assert tool_telemetry is None - assert reply_telemetry is None