diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index e21ac9ac..f317bd81 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1034,6 +1034,7 @@ class LettaAgent(BaseAgent): llm_client, tool_rules_solver, run_id=run_id, + step_id=step_id, ) step_progression = StepProgression.STREAM_RECEIVED @@ -1470,6 +1471,7 @@ class LettaAgent(BaseAgent): agent_id=self.agent_id, agent_tags=agent_state.tags, run_id=self.current_run_id, + step_id=step_metrics.id, call_type="agent_step", ) response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config) @@ -1514,6 +1516,7 @@ class LettaAgent(BaseAgent): llm_client: LLMClientBase, tool_rules_solver: ToolRulesSolver, run_id: str | None = None, + step_id: str | None = None, ) -> tuple[dict, AsyncStream[ChatCompletionChunk], list[Message], list[Message], list[str], int] | None: for attempt in range(self.max_summarization_retries + 1): try: @@ -1541,6 +1544,7 @@ class LettaAgent(BaseAgent): agent_id=self.agent_id, agent_tags=agent_state.tags, run_id=self.current_run_id, + step_id=step_id, call_type="agent_step", ) diff --git a/tests/test_provider_trace.py b/tests/test_provider_trace.py index 256d95ad..d2fc4f47 100644 --- a/tests/test_provider_trace.py +++ b/tests/test_provider_trace.py @@ -1,8 +1,20 @@ +""" +Comprehensive tests for provider trace telemetry. 
+ +Tests verify that provider traces are correctly created with all telemetry context +(agent_id, agent_tags, run_id, step_id, call_type) across: +- Agent steps (non-streaming and streaming) +- Tool calls +- Summarization calls +- Different agent architectures (V2, V3) +""" + import asyncio import os import threading import time import uuid +from unittest.mock import patch import pytest from dotenv import load_dotenv @@ -30,12 +42,11 @@ def server_url(): if not os.getenv("LETTA_SERVER_URL"): thread = threading.Thread(target=_run_server, daemon=True) thread.start() - time.sleep(5) # Allow server startup time + time.sleep(5) return url -# # --- Client Setup --- # @pytest.fixture(scope="session") def client(server_url): """Creates a REST client for testing.""" @@ -53,38 +64,33 @@ def event_loop(request): @pytest.fixture(scope="function") def roll_dice_tool(client, roll_dice_tool_func): - print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func) - yield print_tool + tool = client.tools.upsert_from_function(func=roll_dice_tool_func) + yield tool @pytest.fixture(scope="function") def weather_tool(client, weather_tool_func): - weather_tool = client.tools.upsert_from_function(func=weather_tool_func) - yield weather_tool + tool = client.tools.upsert_from_function(func=weather_tool_func) + yield tool @pytest.fixture(scope="function") def print_tool(client, print_tool_func): - print_tool = client.tools.upsert_from_function(func=print_tool_func) - yield print_tool + tool = client.tools.upsert_from_function(func=print_tool_func) + yield tool @pytest.fixture(scope="function") def agent_state(client, roll_dice_tool, weather_tool): - """Creates an agent and ensures cleanup after tests.""" + """Creates an agent with tools and ensures cleanup after tests.""" agent_state = client.agents.create( - name=f"test_compl_{str(uuid.uuid4())[5:]}", + name=f"test_provider_trace_{str(uuid.uuid4())[:8]}", tool_ids=[roll_dice_tool.id, weather_tool.id], include_base_tools=True, + 
tags=["test", "provider-trace"], memory_blocks=[ - { - "label": "human", - "value": "Name: Matt", - }, - { - "label": "persona", - "value": "Friendly agent", - }, + {"label": "human", "value": "Name: TestUser"}, + {"label": "persona", "value": "Helpful test agent"}, ], llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), @@ -93,34 +99,305 @@ def agent_state(client, roll_dice_tool, weather_tool): client.agents.delete(agent_state.id) -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) -async def test_provider_trace_experimental_step(client, message, agent_state): - response = client.agents.messages.create( - agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])] +@pytest.fixture(scope="function") +def agent_state_with_tags(client, weather_tool): + """Creates an agent with specific tags for tag verification tests.""" + agent_state = client.agents.create( + name=f"test_tagged_agent_{str(uuid.uuid4())[:8]}", + tool_ids=[weather_tool.id], + include_base_tools=True, + tags=["env:test", "team:telemetry", "version:v1"], + memory_blocks=[ + {"label": "human", "value": "Name: TagTestUser"}, + {"label": "persona", "value": "Agent with tags"}, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), ) - tool_step = response.messages[0].step_id - reply_step = response.messages[-1].step_id - - tool_telemetry = client.telemetry.retrieve_provider_trace(step_id=tool_step) - reply_telemetry = client.telemetry.retrieve_provider_trace(step_id=reply_step) - assert tool_telemetry.request_json - assert reply_telemetry.request_json + yield agent_state + client.agents.delete(agent_state.id) -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) -async def 
test_provider_trace_experimental_step_stream(client, message, agent_state): - last_message_id = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0] - stream = client.agents.messages.create_stream( - agent_id=agent_state.id, messages=[MessageCreate(role="user", content=[TextContent(text=message)])] - ) +class TestProviderTraceBasicStep: + """Tests for basic agent step provider traces.""" - list(stream) + @pytest.mark.asyncio + async def test_non_streaming_step_creates_provider_trace(self, client, agent_state): + """Verify provider trace is created for non-streaming agent step.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello, how are you?")])], + ) - messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message_id) - step_ids = [id for id in set((message.step_id for message in messages)) if id is not None] - for step_id in step_ids: - telemetry_data = client.telemetry.retrieve_provider_trace(step_id=step_id) - assert telemetry_data.request_json - assert telemetry_data.response_json + assert len(response.messages) > 0 + step_id = response.messages[-1].step_id + assert step_id is not None + + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + assert trace is not None + assert trace.request_json is not None + assert trace.response_json is not None + + @pytest.mark.asyncio + async def test_streaming_step_creates_provider_trace(self, client, agent_state): + """Verify provider trace is created for streaming agent step.""" + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0] + + stream = client.agents.messages.create_stream( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Tell me a joke.")])], + ) + list(stream) + + messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id) + step_ids = list({msg.step_id for msg in messages if 
msg.step_id is not None}) + + assert len(step_ids) > 0 + for step_id in step_ids: + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + assert trace is not None + assert trace.request_json is not None + assert trace.response_json is not None + + +class TestProviderTraceWithToolCalls: + """Tests for provider traces when tools are called.""" + + @pytest.mark.asyncio + async def test_tool_call_step_has_provider_trace(self, client, agent_state): + """Verify provider trace exists for steps that invoke tools.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Get the weather in San Francisco.")])], + ) + + tool_call_step_id = response.messages[0].step_id + final_step_id = response.messages[-1].step_id + + tool_trace = client.telemetry.retrieve_provider_trace(step_id=tool_call_step_id) + assert tool_trace is not None + assert tool_trace.request_json is not None + + if tool_call_step_id != final_step_id: + final_trace = client.telemetry.retrieve_provider_trace(step_id=final_step_id) + assert final_trace is not None + assert final_trace.request_json is not None + + @pytest.mark.asyncio + async def test_streaming_tool_call_has_provider_trace(self, client, agent_state): + """Verify provider trace exists for streaming steps with tool calls.""" + last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)[0] + + stream = client.agents.messages.create_stream( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Roll the dice for me.")])], + ) + list(stream) + + messages = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id) + step_ids = list({msg.step_id for msg in messages if msg.step_id is not None}) + + assert len(step_ids) > 0 + for step_id in step_ids: + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + assert trace is not None + assert trace.request_json is not None + + 
+class TestProviderTraceTelemetryContext: + """Tests verifying telemetry context fields are correctly populated.""" + + @pytest.mark.asyncio + async def test_provider_trace_contains_agent_id(self, client, agent_state): + """Verify provider trace contains the correct agent_id.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + step_id = response.messages[-1].step_id + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + + assert trace is not None + assert trace.agent_id == agent_state.id + + @pytest.mark.asyncio + async def test_provider_trace_contains_agent_tags(self, client, agent_state_with_tags): + """Verify provider trace contains the agent's tags.""" + response = client.agents.messages.create( + agent_id=agent_state_with_tags.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + step_id = response.messages[-1].step_id + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + + assert trace is not None + assert trace.agent_tags is not None + assert set(trace.agent_tags) == {"env:test", "team:telemetry", "version:v1"} + + @pytest.mark.asyncio + async def test_provider_trace_contains_step_id(self, client, agent_state): + """Verify provider trace step_id matches the message step_id.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + step_id = response.messages[-1].step_id + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + + assert trace is not None + assert trace.step_id == step_id + + @pytest.mark.asyncio + async def test_provider_trace_contains_run_id_for_async_job(self, client, agent_state): + """Verify provider trace contains run_id when created via async job.""" + job = client.agents.messages.create_async( + agent_id=agent_state.id, + 
messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + while job.status not in ["completed", "failed"]: + time.sleep(0.5) + job = client.jobs.retrieve(job.id) + + assert job.status == "completed" + + messages = client.agents.messages.list(agent_id=agent_state.id, limit=5) + step_ids = list({msg.step_id for msg in messages if msg.step_id is not None}) + + assert len(step_ids) > 0 + trace = client.telemetry.retrieve_provider_trace(step_id=step_ids[0]) + assert trace is not None + assert trace.run_id == job.id + + +class TestProviderTraceMultiStep: + """Tests for provider traces across multiple agent steps.""" + + @pytest.mark.asyncio + async def test_multi_step_conversation_has_traces_for_each_step(self, client, agent_state): + """Verify each step in a multi-step conversation has its own provider trace.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", + content=[TextContent(text="First, get the weather in NYC. 
Then roll the dice.")], + ) + ], + ) + + step_ids = list({msg.step_id for msg in response.messages if msg.step_id is not None}) + + assert len(step_ids) >= 1 + + for step_id in step_ids: + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + assert trace is not None, f"No trace found for step_id={step_id}" + assert trace.request_json is not None + assert trace.agent_id == agent_state.id + + @pytest.mark.asyncio + async def test_consecutive_messages_have_separate_traces(self, client, agent_state): + """Verify consecutive messages create separate traces.""" + response1 = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + step_id_1 = response1.messages[-1].step_id + + response2 = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="How are you?")])], + ) + step_id_2 = response2.messages[-1].step_id + + assert step_id_1 != step_id_2 + + trace1 = client.telemetry.retrieve_provider_trace(step_id=step_id_1) + trace2 = client.telemetry.retrieve_provider_trace(step_id=step_id_2) + + assert trace1 is not None + assert trace2 is not None + assert trace1.id != trace2.id + + +class TestProviderTraceRequestResponseContent: + """Tests verifying request and response JSON content.""" + + @pytest.mark.asyncio + async def test_request_json_contains_model(self, client, agent_state): + """Verify request_json contains model information.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + step_id = response.messages[-1].step_id + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + + assert trace is not None + assert trace.request_json is not None + assert "model" in trace.request_json + + @pytest.mark.asyncio + async def test_request_json_contains_messages(self, client, agent_state): + 
"""Verify request_json contains messages array.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + step_id = response.messages[-1].step_id + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + + assert trace is not None + assert trace.request_json is not None + assert "messages" in trace.request_json + assert isinstance(trace.request_json["messages"], list) + + @pytest.mark.asyncio + async def test_response_json_contains_usage(self, client, agent_state): + """Verify response_json contains usage statistics.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hello")])], + ) + + step_id = response.messages[-1].step_id + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + + assert trace is not None + assert trace.response_json is not None + assert "usage" in trace.response_json or "usage" in str(trace.response_json) + + +class TestProviderTraceEdgeCases: + """Tests for edge cases and error scenarios.""" + + @pytest.mark.asyncio + async def test_nonexistent_step_id_returns_none_or_empty(self, client): + """Verify querying nonexistent step_id handles gracefully.""" + fake_step_id = f"step-{uuid.uuid4()}" + + try: + trace = client.telemetry.retrieve_provider_trace(step_id=fake_step_id) + assert trace is None or trace.request_json is None + except Exception: + pass + + @pytest.mark.asyncio + async def test_empty_message_still_creates_trace(self, client, agent_state): + """Verify trace is created even for minimal messages.""" + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text="Hi")])], + ) + + step_id = response.messages[-1].step_id + assert step_id is not None + + trace = client.telemetry.retrieve_provider_trace(step_id=step_id) + assert trace is not None diff 
--git a/tests/test_provider_trace_agents.py b/tests/test_provider_trace_agents.py new file mode 100644 index 00000000..830d776c --- /dev/null +++ b/tests/test_provider_trace_agents.py @@ -0,0 +1,408 @@ +""" +Unit tests for provider trace telemetry across agent versions and adapters. + +Tests verify that telemetry context is correctly passed through: +- Tool generation endpoint +- LettaAgent (v1), LettaAgentV2, LettaAgentV3 +- Streaming and non-streaming paths +- Different stream adapters +""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from letta.schemas.llm_config import LLMConfig + + +@pytest.fixture +def mock_llm_config(): + """Create a mock LLM config.""" + return LLMConfig( + model="gpt-4o-mini", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=8000, + ) + + +class TestToolGenerationTelemetry: + """Tests for tool generation endpoint telemetry.""" + + @pytest.mark.asyncio + async def test_generate_tool_sets_call_type(self, mock_llm_config): + """Verify generate_tool endpoint sets call_type='tool_generation'.""" + from letta.llm_api.llm_client import LLMClient + from letta.schemas.user import User + + mock_actor = User( + id=f"user-{uuid.uuid4()}", + organization_id=f"org-{uuid.uuid4()}", + name="test_user", + ) + + captured_telemetry = {} + + def capture_telemetry(**kwargs): + captured_telemetry.update(kwargs) + + with patch.object(LLMClient, "create") as mock_create: + mock_client = MagicMock() + mock_client.set_telemetry_context = capture_telemetry + mock_client.build_request_data = MagicMock(return_value={}) + mock_client.request_async_with_telemetry = AsyncMock(return_value={}) + mock_client.convert_response_to_chat_completion = AsyncMock( + return_value=MagicMock( + choices=[ + MagicMock( + message=MagicMock( + tool_calls=[ + MagicMock( + function=MagicMock( + arguments='{"raw_source_code": "def test(): pass", "sample_args_json": "{}", "pip_requirements_json": 
"{}"}' + ) + ) + ], + content=None, + ) + ) + ] + ) + ) + mock_create.return_value = mock_client + + from letta.server.rest_api.routers.v1.tools import GenerateToolInput, generate_tool_from_prompt + + mock_server = MagicMock() + mock_server.user_manager.get_actor_or_default_async = AsyncMock(return_value=mock_actor) + mock_server.get_llm_config_from_handle_async = AsyncMock(return_value=mock_llm_config) + + mock_headers = MagicMock() + mock_headers.actor_id = mock_actor.id + + request = GenerateToolInput( + prompt="Create a function that adds two numbers", + tool_name="add_numbers", + validation_errors=[], + ) + + with patch("letta.server.rest_api.routers.v1.tools.derive_openai_json_schema") as mock_schema: + mock_schema.return_value = {"name": "add_numbers", "parameters": {}} + try: + await generate_tool_from_prompt(request=request, server=mock_server, headers=mock_headers) + except Exception: + pass + + assert captured_telemetry.get("call_type") == "tool_generation" + + @pytest.mark.asyncio + async def test_generate_tool_has_no_agent_context(self, mock_llm_config): + """Verify generate_tool doesn't have agent_id since it's not agent-bound.""" + from letta.llm_api.llm_client import LLMClient + from letta.schemas.user import User + + mock_actor = User( + id=f"user-{uuid.uuid4()}", + organization_id=f"org-{uuid.uuid4()}", + name="test_user", + ) + + captured_telemetry = {} + + def capture_telemetry(**kwargs): + captured_telemetry.update(kwargs) + + with patch.object(LLMClient, "create") as mock_create: + mock_client = MagicMock() + mock_client.set_telemetry_context = capture_telemetry + mock_client.build_request_data = MagicMock(return_value={}) + mock_client.request_async_with_telemetry = AsyncMock(return_value={}) + mock_client.convert_response_to_chat_completion = AsyncMock( + return_value=MagicMock( + choices=[ + MagicMock( + message=MagicMock( + tool_calls=[ + MagicMock( + function=MagicMock( + arguments='{"raw_source_code": "def test(): pass", 
"sample_args_json": "{}", "pip_requirements_json": "{}"}' + ) + ) + ], + content=None, + ) + ) + ] + ) + ) + mock_create.return_value = mock_client + + from letta.server.rest_api.routers.v1.tools import GenerateToolInput, generate_tool_from_prompt + + mock_server = MagicMock() + mock_server.user_manager.get_actor_or_default_async = AsyncMock(return_value=mock_actor) + mock_server.get_llm_config_from_handle_async = AsyncMock(return_value=mock_llm_config) + + mock_headers = MagicMock() + mock_headers.actor_id = mock_actor.id + + request = GenerateToolInput( + prompt="Create a function", + tool_name="test_func", + validation_errors=[], + ) + + with patch("letta.server.rest_api.routers.v1.tools.derive_openai_json_schema") as mock_schema: + mock_schema.return_value = {"name": "test_func", "parameters": {}} + try: + await generate_tool_from_prompt(request=request, server=mock_server, headers=mock_headers) + except Exception: + pass + + assert captured_telemetry.get("agent_id") is None + assert captured_telemetry.get("step_id") is None + assert captured_telemetry.get("run_id") is None + + +class TestLLMClientTelemetryContext: + """Tests for LLMClient telemetry context methods.""" + + def test_llm_client_has_set_telemetry_context_method(self): + """Verify LLMClient exposes set_telemetry_context.""" + from letta.llm_api.llm_client import LLMClient + + client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True) + assert hasattr(client, "set_telemetry_context") + assert callable(client.set_telemetry_context) + + def test_llm_client_set_telemetry_context_accepts_all_fields(self): + """Verify set_telemetry_context accepts all telemetry fields.""" + from letta.llm_api.llm_client import LLMClient + + client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True) + + client.set_telemetry_context( + agent_id=f"agent-{uuid.uuid4()}", + agent_tags=["tag1", "tag2"], + run_id=f"run-{uuid.uuid4()}", + step_id=f"step-{uuid.uuid4()}", + 
call_type="summarization", + ) + + +class TestAdapterTelemetryAttributes: + """Tests for adapter telemetry attribute support.""" + + def test_base_adapter_has_telemetry_attributes(self, mock_llm_config): + """Verify base LettaLLMAdapter has telemetry attributes.""" + from letta.adapters.letta_llm_adapter import LettaLLMAdapter + from letta.llm_api.llm_client import LLMClient + + mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True) + + agent_id = f"agent-{uuid.uuid4()}" + agent_tags = ["test-tag"] + run_id = f"run-{uuid.uuid4()}" + + class TestAdapter(LettaLLMAdapter): + async def invoke_llm(self, *args, **kwargs): + pass + + adapter = TestAdapter( + llm_client=mock_client, + llm_config=mock_llm_config, + agent_id=agent_id, + agent_tags=agent_tags, + run_id=run_id, + ) + + assert adapter.agent_id == agent_id + assert adapter.agent_tags == agent_tags + assert adapter.run_id == run_id + + def test_request_adapter_inherits_telemetry_attributes(self, mock_llm_config): + """Verify LettaLLMRequestAdapter inherits telemetry attributes.""" + from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter + from letta.llm_api.llm_client import LLMClient + + mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True) + + agent_id = f"agent-{uuid.uuid4()}" + agent_tags = ["request-tag"] + run_id = f"run-{uuid.uuid4()}" + + adapter = LettaLLMRequestAdapter( + llm_client=mock_client, + llm_config=mock_llm_config, + agent_id=agent_id, + agent_tags=agent_tags, + run_id=run_id, + ) + + assert adapter.agent_id == agent_id + assert adapter.agent_tags == agent_tags + assert adapter.run_id == run_id + + def test_stream_adapter_inherits_telemetry_attributes(self, mock_llm_config): + """Verify LettaLLMStreamAdapter inherits telemetry attributes.""" + from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter + from letta.llm_api.llm_client import LLMClient + + mock_client = 
LLMClient.create(provider_type="openai", put_inner_thoughts_first=True) + + agent_id = f"agent-{uuid.uuid4()}" + agent_tags = ["stream-tag"] + run_id = f"run-{uuid.uuid4()}" + + adapter = LettaLLMStreamAdapter( + llm_client=mock_client, + llm_config=mock_llm_config, + agent_id=agent_id, + agent_tags=agent_tags, + run_id=run_id, + ) + + assert adapter.agent_id == agent_id + assert adapter.agent_tags == agent_tags + assert adapter.run_id == run_id + + def test_request_and_stream_adapters_have_consistent_interface(self, mock_llm_config): + """Verify both adapter types have the same telemetry interface.""" + from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter + from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter + from letta.llm_api.llm_client import LLMClient + + mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True) + + request_adapter = LettaLLMRequestAdapter(llm_client=mock_client, llm_config=mock_llm_config) + stream_adapter = LettaLLMStreamAdapter(llm_client=mock_client, llm_config=mock_llm_config) + + for attr in ["agent_id", "agent_tags", "run_id"]: + assert hasattr(request_adapter, attr), f"LettaLLMRequestAdapter missing {attr}" + assert hasattr(stream_adapter, attr), f"LettaLLMStreamAdapter missing {attr}" + + +class TestSummarizerTelemetry: + """Tests for Summarizer class telemetry context.""" + + def test_summarizer_stores_telemetry_context(self): + """Verify Summarizer stores telemetry context from constructor.""" + from letta.schemas.user import User + from letta.services.summarizer.enums import SummarizationMode + from letta.services.summarizer.summarizer import Summarizer + + mock_actor = User( + id=f"user-{uuid.uuid4()}", + organization_id=f"org-{uuid.uuid4()}", + name="test_user", + ) + + agent_id = f"agent-{uuid.uuid4()}" + run_id = f"run-{uuid.uuid4()}" + step_id = f"step-{uuid.uuid4()}" + + summarizer = Summarizer( + mode=SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER, + 
summarizer_agent=None, + message_buffer_limit=100, + message_buffer_min=10, + partial_evict_summarizer_percentage=0.5, + agent_manager=MagicMock(), + message_manager=MagicMock(), + actor=mock_actor, + agent_id=agent_id, + run_id=run_id, + step_id=step_id, + ) + + assert summarizer.agent_id == agent_id + assert summarizer.run_id == run_id + assert summarizer.step_id == step_id + + @pytest.mark.asyncio + async def test_summarize_method_accepts_runtime_telemetry(self): + """Verify summarize() method accepts runtime run_id/step_id.""" + from letta.schemas.enums import MessageRole + from letta.schemas.message import Message + from letta.schemas.user import User + from letta.services.summarizer.enums import SummarizationMode + from letta.services.summarizer.summarizer import Summarizer + + mock_actor = User( + id=f"user-{uuid.uuid4()}", + organization_id=f"org-{uuid.uuid4()}", + name="test_user", + ) + + agent_id = f"agent-{uuid.uuid4()}" + mock_messages = [ + Message( + id=f"message-{uuid.uuid4()}", + role=MessageRole.user, + content=[{"type": "text", "text": "Hello"}], + agent_id=agent_id, + ) + ] + + summarizer = Summarizer( + mode=SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER, + summarizer_agent=None, + message_buffer_limit=100, + message_buffer_min=10, + partial_evict_summarizer_percentage=0.5, + agent_manager=MagicMock(), + message_manager=MagicMock(), + actor=mock_actor, + agent_id=agent_id, + ) + + run_id = f"run-{uuid.uuid4()}" + step_id = f"step-{uuid.uuid4()}" + + result = await summarizer.summarize( + in_context_messages=mock_messages, + new_letta_messages=[], + force=False, + run_id=run_id, + step_id=step_id, + ) + + assert result is not None + + +class TestAgentAdapterInstantiation: + """Tests verifying agents instantiate adapters with telemetry context.""" + + def test_agent_v2_creates_summarizer_with_agent_id(self, mock_llm_config): + """Verify LettaAgentV2 creates Summarizer with correct agent_id.""" + from letta.agents.letta_agent_v2 import 
LettaAgentV2 + from letta.schemas.agent import AgentState, AgentType + from letta.schemas.embedding_config import EmbeddingConfig + from letta.schemas.memory import Memory + from letta.schemas.user import User + + mock_actor = User( + id=f"user-{uuid.uuid4()}", + organization_id=f"org-{uuid.uuid4()}", + name="test_user", + ) + + agent_id = f"agent-{uuid.uuid4()}" + agent_state = AgentState( + id=agent_id, + name="test_agent", + agent_type=AgentType.letta_v1_agent, + llm_config=mock_llm_config, + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["test"], + memory=Memory(blocks=[]), + system="You are a helpful assistant.", + tools=[], + sources=[], + blocks=[], + ) + + agent = LettaAgentV2(agent_state=agent_state, actor=mock_actor) + + assert agent.summarizer.agent_id == agent_id diff --git a/tests/test_provider_trace_backends.py b/tests/test_provider_trace_backends.py index a088d368..3d64e04b 100644 --- a/tests/test_provider_trace_backends.py +++ b/tests/test_provider_trace_backends.py @@ -171,12 +171,11 @@ class TestSocketProviderTraceBackend: assert len(received_data) == 1 record = json.loads(received_data[0].strip()) assert record["provider_trace_id"] == sample_provider_trace.id - assert record["model"] == "gpt-4o-mini" - assert record["provider"] == "openai" - assert record["input_tokens"] == 10 - assert record["output_tokens"] == 5 - assert record["context"]["step_id"] == "step-test-789" - assert record["context"]["run_id"] == "run-test-abc" + assert record["step_id"] == "step-test-789" + assert record["run_id"] == "run-test-abc" + assert record["request"]["model"] == "gpt-4o-mini" + assert record["response"]["usage"]["prompt_tokens"] == 10 + assert record["response"]["usage"]["completion_tokens"] == 5 def test_send_to_nonexistent_socket_does_not_raise(self, sample_provider_trace): """Test that sending to nonexistent socket fails silently.""" diff --git a/tests/test_provider_trace_summarization.py 
b/tests/test_provider_trace_summarization.py new file mode 100644 index 00000000..3f114736 --- /dev/null +++ b/tests/test_provider_trace_summarization.py @@ -0,0 +1,431 @@ +""" +Unit tests for summarization provider trace telemetry context. + +These tests verify that summarization LLM calls correctly pass telemetry context +(agent_id, agent_tags, run_id, step_id) to the provider trace system. +""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from letta.schemas.agent import AgentState +from letta.schemas.block import Block +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message +from letta.schemas.user import User +from letta.services.summarizer import summarizer_all, summarizer_sliding_window +from letta.services.summarizer.enums import SummarizationMode +from letta.services.summarizer.summarizer import Summarizer, simple_summary +from letta.services.summarizer.summarizer_config import CompactionSettings + + +@pytest.fixture +def mock_actor(): + """Create a mock user/actor.""" + return User( + id=f"user-{uuid.uuid4()}", + organization_id=f"org-{uuid.uuid4()}", + name="test_user", + ) + + +@pytest.fixture +def mock_llm_config(): + """Create a mock LLM config.""" + return LLMConfig( + model="gpt-4o-mini", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=8000, + ) + + +@pytest.fixture +def mock_agent_state(mock_llm_config): + """Create a mock agent state.""" + agent_id = f"agent-{uuid.uuid4()}" + return AgentState( + id=agent_id, + name="test_agent", + llm_config=mock_llm_config, + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["env:test", "team:ml"], + memory=MagicMock( + compile=MagicMock(return_value="Memory content"), + ), + message_ids=[], + tool_ids=[], + system="You are a helpful assistant.", + ) + + 
+@pytest.fixture
+def mock_messages():
+    """Create ten mock messages for summarization.
+
+    Roles alternate, starting with ``user`` at index 0; every message carries
+    the same freshly generated agent id.
+    """
+    agent_id = f"agent-{uuid.uuid4()}"
+    return [
+        Message(
+            id=f"message-{uuid.uuid4()}",
+            role=MessageRole.user if i % 2 == 0 else MessageRole.assistant,
+            content=[{"type": "text", "text": f"Message content {i}"}],
+            agent_id=agent_id,
+        )
+        for i in range(10)
+    ]
+
+
+class TestSimpleSummaryTelemetryContext:
+    """Tests for simple_summary telemetry context passing."""
+
+    @pytest.mark.asyncio
+    async def test_simple_summary_accepts_telemetry_params(self, mock_messages, mock_llm_config, mock_actor):
+        """Verify simple_summary forwards all telemetry context parameters to the LLM client."""
+        agent_id = f"agent-{uuid.uuid4()}"
+        agent_tags = ["tag1", "tag2"]
+        run_id = f"run-{uuid.uuid4()}"
+        step_id = f"step-{uuid.uuid4()}"
+
+        with patch("letta.services.summarizer.summarizer.LLMClient") as mock_client_class:
+            mock_client = MagicMock()
+            mock_client.set_telemetry_context = MagicMock()
+            mock_client.send_llm_request_async = AsyncMock(return_value=MagicMock(content="Summary of conversation"))
+            mock_client_class.create.return_value = mock_client
+
+            try:
+                await simple_summary(
+                    messages=mock_messages,
+                    llm_config=mock_llm_config,
+                    actor=mock_actor,
+                    agent_id=agent_id,
+                    agent_tags=agent_tags,
+                    run_id=run_id,
+                    step_id=step_id,
+                )
+            except Exception:
+                # Deliberate best-effort: the LLM client is a MagicMock, so simple_summary
+                # may raise somewhere along the way (presumably while handling the mocked
+                # response - TODO confirm). The assertions below are the actual check.
+                pass
+
+            mock_client.set_telemetry_context.assert_called_once()
+            call_kwargs = mock_client.set_telemetry_context.call_args.kwargs
+            assert call_kwargs["agent_id"] == agent_id
+            assert call_kwargs["agent_tags"] == agent_tags
+            assert call_kwargs["run_id"] == run_id
+            assert call_kwargs["step_id"] == step_id
+            assert call_kwargs["call_type"] == "summarization"
+
+
+class TestSummarizeAllTelemetryContext:
+    """Tests for summarize_all telemetry context passing."""
+
+    @pytest.fixture
+    def mock_compaction_settings(self):
+        """Create mock compaction settings."""
+        return CompactionSettings(model="openai/gpt-4o-mini")
+
+    @pytest.mark.asyncio
+    async def test_summarize_all_passes_telemetry_to_simple_summary(
+        self, mock_messages, mock_llm_config, mock_actor, mock_compaction_settings
+    ):
+        """Verify summarize_all passes telemetry context to simple_summary."""
+        agent_id = f"agent-{uuid.uuid4()}"
+        agent_tags = ["env:prod", "team:core"]
+        run_id = f"run-{uuid.uuid4()}"
+        step_id = f"step-{uuid.uuid4()}"
+
+        captured_kwargs = {}
+
+        async def capture_simple_summary(*args, **kwargs):
+            captured_kwargs.update(kwargs)
+            return "Mocked summary"
+
+        with patch.object(summarizer_all, "simple_summary", new=capture_simple_summary):
+            await summarizer_all.summarize_all(
+                actor=mock_actor,
+                llm_config=mock_llm_config,
+                summarizer_config=mock_compaction_settings,
+                in_context_messages=mock_messages,
+                agent_id=agent_id,
+                agent_tags=agent_tags,
+                run_id=run_id,
+                step_id=step_id,
+            )
+
+        assert captured_kwargs.get("agent_id") == agent_id
+        assert captured_kwargs.get("agent_tags") == agent_tags
+        assert captured_kwargs.get("run_id") == run_id
+        assert captured_kwargs.get("step_id") == step_id
+
+    @pytest.mark.asyncio
+    async def test_summarize_all_without_telemetry_params(self, mock_messages, mock_llm_config, mock_actor, mock_compaction_settings):
+        """Verify summarize_all works without telemetry params (backwards compatible)."""
+        captured_kwargs = {}
+
+        async def capture_simple_summary(*args, **kwargs):
+            captured_kwargs.update(kwargs)
+            return "Mocked summary"
+
+        with patch.object(summarizer_all, "simple_summary", new=capture_simple_summary):
+            await summarizer_all.summarize_all(
+                actor=mock_actor,
+                llm_config=mock_llm_config,
+                summarizer_config=mock_compaction_settings,
+                in_context_messages=mock_messages,
+            )
+
+        assert captured_kwargs.get("agent_id") is None
+        assert captured_kwargs.get("agent_tags") is None
+        assert captured_kwargs.get("run_id") is None
+        assert captured_kwargs.get("step_id") is None
+
+
+class TestSummarizeSlidingWindowTelemetryContext:
+    """Tests for summarize_via_sliding_window telemetry context passing."""
+
+    @pytest.fixture
+    def mock_compaction_settings(self):
+        """Create mock compaction settings."""
+        return CompactionSettings(model="openai/gpt-4o-mini")
+
+    @pytest.mark.asyncio
+    async def test_sliding_window_passes_telemetry_to_simple_summary(
+        self, mock_messages, mock_llm_config, mock_actor, mock_compaction_settings
+    ):
+        """Verify summarize_via_sliding_window passes telemetry context to simple_summary."""
+        agent_id = f"agent-{uuid.uuid4()}"
+        agent_tags = ["version:v2"]
+        run_id = f"run-{uuid.uuid4()}"
+        step_id = f"step-{uuid.uuid4()}"
+
+        captured_kwargs = {}
+
+        async def capture_simple_summary(*args, **kwargs):
+            captured_kwargs.update(kwargs)
+            return "Mocked summary"
+
+        with patch.object(summarizer_sliding_window, "simple_summary", new=capture_simple_summary):
+            await summarizer_sliding_window.summarize_via_sliding_window(
+                actor=mock_actor,
+                llm_config=mock_llm_config,
+                summarizer_config=mock_compaction_settings,
+                in_context_messages=mock_messages,
+                agent_id=agent_id,
+                agent_tags=agent_tags,
+                run_id=run_id,
+                step_id=step_id,
+            )
+
+        assert captured_kwargs.get("agent_id") == agent_id
+        assert captured_kwargs.get("agent_tags") == agent_tags
+        assert captured_kwargs.get("run_id") == run_id
+        assert captured_kwargs.get("step_id") == step_id
+
+
+def _build_partial_evict_summarizer(actor, agent_id, **constructor_telemetry):
+    """Build a PARTIAL_EVICT_MESSAGE_BUFFER ``Summarizer`` backed by mocked managers.
+
+    The mocked agent manager resolves ``get_agent_by_id_async`` to an object that
+    exposes ``llm_config`` and ``tags``. Any extra keyword arguments (the tests pass
+    ``run_id`` / ``step_id``) are forwarded to the ``Summarizer`` constructor.
+    """
+    mock_agent_manager = MagicMock()
+    mock_agent_manager.get_agent_by_id_async = AsyncMock(
+        return_value=MagicMock(
+            llm_config=LLMConfig(
+                model="gpt-4o-mini",
+                model_endpoint_type="openai",
+                model_endpoint="https://api.openai.com/v1",
+                context_window=8000,
+            ),
+            tags=["test-tag"],
+        )
+    )
+    return Summarizer(
+        mode=SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
+        summarizer_agent=None,
+        message_buffer_limit=100,
+        message_buffer_min=10,
+        partial_evict_summarizer_percentage=0.5,
+        agent_manager=mock_agent_manager,
+        message_manager=MagicMock(),
+        actor=actor,
+        agent_id=agent_id,
+        **constructor_telemetry,
+    )
+
+
+async def _summarize_and_capture(summarizer, messages, **runtime_telemetry):
+    """Run ``summarizer.summarize(force=True)`` with ``simple_summary`` patched and return its kwargs.
+
+    Keyword arguments in ``runtime_telemetry`` are passed straight to ``summarize()``.
+    Fails the calling test if the patched ``simple_summary`` is never invoked, so a
+    telemetry assertion can no longer pass vacuously.
+    """
+    captured_kwargs = {}
+    call_count = 0
+
+    async def capture_simple_summary(*args, **kwargs):
+        nonlocal call_count
+        call_count += 1
+        captured_kwargs.update(kwargs)
+        return "Mocked summary"
+
+    with patch("letta.services.summarizer.summarizer.simple_summary", new=capture_simple_summary):
+        try:
+            await summarizer.summarize(
+                in_context_messages=messages,
+                new_letta_messages=[],
+                force=True,
+                **runtime_telemetry,
+            )
+        except Exception:
+            # Deliberate best-effort: the summarizer runs against MagicMock managers, so
+            # summarize() may raise at some point (presumably after simple_summary has
+            # been called - TODO confirm). Only the kwargs recorded above matter.
+            pass
+
+    assert call_count > 0, "Summarizer.summarize() never called simple_summary; telemetry could not be verified"
+    return captured_kwargs
+
+
+class TestSummarizerClassTelemetryContext:
+    """Tests for Summarizer class telemetry context passing."""
+
+    @pytest.mark.asyncio
+    async def test_summarizer_summarize_passes_runtime_telemetry(self, mock_messages, mock_actor):
+        """Verify Summarizer.summarize() passes runtime run_id/step_id to the underlying call."""
+        run_id = f"run-{uuid.uuid4()}"
+        step_id = f"step-{uuid.uuid4()}"
+        agent_id = f"agent-{uuid.uuid4()}"
+
+        summarizer = _build_partial_evict_summarizer(mock_actor, agent_id)
+
+        captured_kwargs = await _summarize_and_capture(summarizer, mock_messages, run_id=run_id, step_id=step_id)
+
+        assert captured_kwargs.get("run_id") == run_id
+        assert captured_kwargs.get("step_id") == step_id
+
+    @pytest.mark.asyncio
+    async def test_summarizer_uses_constructor_telemetry_as_default(self, mock_messages, mock_actor):
+        """Verify Summarizer uses constructor run_id/step_id when not passed to summarize()."""
+        constructor_run_id = f"run-{uuid.uuid4()}"
+        constructor_step_id = f"step-{uuid.uuid4()}"
+        agent_id = f"agent-{uuid.uuid4()}"
+
+        summarizer = _build_partial_evict_summarizer(
+            mock_actor,
+            agent_id,
+            run_id=constructor_run_id,
+            step_id=constructor_step_id,
+        )
+
+        captured_kwargs = await _summarize_and_capture(summarizer, mock_messages)
+
+        assert captured_kwargs.get("run_id") == constructor_run_id
+        assert captured_kwargs.get("step_id") == constructor_step_id
+
+    @pytest.mark.asyncio
+    async def test_summarizer_runtime_overrides_constructor_telemetry(self, mock_messages, mock_actor):
+        """Verify runtime run_id/step_id override constructor values."""
+        constructor_run_id = f"run-constructor-{uuid.uuid4()}"
+        constructor_step_id = f"step-constructor-{uuid.uuid4()}"
+        runtime_run_id = f"run-runtime-{uuid.uuid4()}"
+        runtime_step_id = f"step-runtime-{uuid.uuid4()}"
+        agent_id = f"agent-{uuid.uuid4()}"
+
+        summarizer = _build_partial_evict_summarizer(
+            mock_actor,
+            agent_id,
+            run_id=constructor_run_id,
+            step_id=constructor_step_id,
+        )
+
+        captured_kwargs = await _summarize_and_capture(
+            summarizer,
+            mock_messages,
+            run_id=runtime_run_id,
+            step_id=runtime_step_id,
+        )
+
+        assert captured_kwargs.get("run_id") == runtime_run_id
+        assert captured_kwargs.get("step_id") == runtime_step_id
+
+
+class TestLLMClientTelemetryContext:
+    """Tests for LLM client telemetry context setting."""
+
+    def test_llm_client_set_telemetry_context_method_exists(self):
+        """Verify LLMClient has set_telemetry_context method."""
+        from letta.llm_api.llm_client import LLMClient
+
+        client = LLMClient.create(
+            provider_type="openai",
+            put_inner_thoughts_first=True,
+        )
+        assert hasattr(client, "set_telemetry_context")
+
+    def test_llm_client_set_telemetry_context_accepts_all_params(self):
+        """Verify set_telemetry_context accepts all telemetry parameters."""
+        from letta.llm_api.llm_client import LLMClient
+
+        client = LLMClient.create(
+            provider_type="openai",
+            put_inner_thoughts_first=True,
+        )
+
+        agent_id = f"agent-{uuid.uuid4()}"
+        agent_tags = ["tag1", "tag2"]
+        run_id = f"run-{uuid.uuid4()}"
+        step_id = f"step-{uuid.uuid4()}"
+        call_type = "summarization"
+
+        client.set_telemetry_context(
+            agent_id=agent_id,
+            agent_tags=agent_tags,
+            run_id=run_id,
+            step_id=step_id,
+            call_type=call_type,
+        )