test: update provider trace tests to use sdk client (#2866)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -9,19 +8,10 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
||||
|
||||
|
||||
def _run_server():
|
||||
@@ -105,105 +95,32 @@ def agent_state(client, roll_dice_tool, weather_tool):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_experimental_step(message, agent_state, default_user):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_state.id,
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
block_manager=BlockManager(),
|
||||
job_manager=JobManager(),
|
||||
passage_manager=PassageManager(),
|
||||
step_manager=StepManager(),
|
||||
telemetry_manager=TelemetryManager(),
|
||||
actor=default_user,
|
||||
async def test_provider_trace_experimental_step(client, message, agent_state):
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])]
|
||||
)
|
||||
|
||||
response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
|
||||
tool_step = response.messages[0].step_id
|
||||
reply_step = response.messages[-1].step_id
|
||||
|
||||
tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
|
||||
reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
|
||||
tool_telemetry = client.telemetry.retrieve_provider_trace(step_id=tool_step)
|
||||
reply_telemetry = client.telemetry.retrieve_provider_trace(step_id=reply_step)
|
||||
assert tool_telemetry.request_json
|
||||
assert reply_telemetry.request_json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_state.id,
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
block_manager=BlockManager(),
|
||||
job_manager=JobManager(),
|
||||
passage_manager=PassageManager(),
|
||||
step_manager=StepManager(),
|
||||
telemetry_manager=TelemetryManager(),
|
||||
actor=default_user,
|
||||
)
|
||||
stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])])
|
||||
|
||||
result = StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
async def test_provider_trace_experimental_step_stream(client, message, agent_state):
|
||||
last_message_id = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0]
|
||||
stream = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])]
|
||||
)
|
||||
|
||||
message_id = None
|
||||
list(stream)
|
||||
|
||||
async def test_send(message) -> None:
|
||||
nonlocal message_id
|
||||
if "body" in message and not message_id:
|
||||
body = message["body"].decode("utf-8").split("data:")
|
||||
message_id = json.loads(body[1])["id"]
|
||||
|
||||
await result.stream_response(send=test_send)
|
||||
|
||||
messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user)
|
||||
step_ids = set((message.step_id for message in messages))
|
||||
messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message_id)
|
||||
step_ids = [id for id in set((message.step_id for message in messages)) if id is not None]
|
||||
for step_id in step_ids:
|
||||
telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user)
|
||||
telemetry_data = client.telemetry.retrieve_provider_trace(step_id=step_id)
|
||||
assert telemetry_data.request_json
|
||||
assert telemetry_data.response_json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_step(client, agent_state, default_user, message):
|
||||
client.agents.messages.create(agent_id=agent_state.id, messages=[])
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=[TextContent(text=message)])],
|
||||
)
|
||||
tool_step = response.messages[0].step_id
|
||||
reply_step = response.messages[-1].step_id
|
||||
|
||||
tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
|
||||
reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
|
||||
assert tool_telemetry.request_json
|
||||
assert reply_telemetry.request_json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_noop_provider_trace(message, agent_state, default_user):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_state.id,
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
block_manager=BlockManager(),
|
||||
job_manager=JobManager(),
|
||||
passage_manager=PassageManager(),
|
||||
step_manager=StepManager(),
|
||||
telemetry_manager=NoopTelemetryManager(),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
|
||||
tool_step = response.messages[0].step_id
|
||||
reply_step = response.messages[-1].step_id
|
||||
|
||||
tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
|
||||
reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
|
||||
assert tool_telemetry is None
|
||||
assert reply_telemetry is None
|
||||
|
||||
Reference in New Issue
Block a user