From a78abc610e9248f2375810fe4c92f3e3deff643b Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Mon, 19 May 2025 15:50:56 -0700 Subject: [PATCH] feat: track llm provider traces and tracking steps in async agent loop (#2219) --- ...d_add_support_for_request_and_response_.py | 50 ++++ letta/agent.py | 18 +- letta/agents/helpers.py | 5 + letta/agents/letta_agent.py | 238 ++++++++++++++---- .../anthropic_streaming_interface.py | 2 + .../interfaces/openai_streaming_interface.py | 7 + letta/llm_api/llm_api_tools.py | 27 ++ letta/llm_api/llm_client_base.py | 53 ++-- letta/llm_api/openai_client.py | 8 - letta/memory.py | 1 - letta/orm/__init__.py | 1 + letta/orm/provider_trace.py | 26 ++ letta/schemas/provider_trace.py | 43 ++++ letta/server/rest_api/routers/v1/__init__.py | 2 + letta/server/rest_api/routers/v1/agents.py | 12 +- letta/server/rest_api/routers/v1/telemetry.py | 18 ++ letta/server/rest_api/streaming_response.py | 105 ++++++++ letta/server/rest_api/utils.py | 4 + letta/server/server.py | 2 + letta/services/helpers/noop_helper.py | 10 + letta/services/step_manager.py | 85 +++++++ letta/services/telemetry_manager.py | 52 ++++ letta/settings.py | 1 + tests/conftest.py | 18 ++ tests/integration_test_sleeptime_agent.py | 3 +- tests/test_multi_agent.py | 3 +- tests/test_provider_trace.py | 205 +++++++++++++++ tests/test_server.py | 3 +- 28 files changed, 920 insertions(+), 82 deletions(-) create mode 100644 alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py create mode 100644 letta/orm/provider_trace.py create mode 100644 letta/schemas/provider_trace.py create mode 100644 letta/server/rest_api/routers/v1/telemetry.py create mode 100644 letta/server/rest_api/streaming_response.py create mode 100644 letta/services/helpers/noop_helper.py create mode 100644 letta/services/telemetry_manager.py create mode 100644 tests/test_provider_trace.py diff --git a/alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py b/alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py new file mode 100644 index 00000000..7ce2c0dc --- /dev/null +++ b/alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py @@ -0,0 +1,50 @@ +"""add support for request and response jsons from llm providers + +Revision ID: cc8dc340836d +Revises: 220856bbf43b +Create Date: 2025-05-19 14:25:41.999676 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "cc8dc340836d" +down_revision: Union[str, None] = "220856bbf43b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "provider_traces", + sa.Column("id", sa.String(), nullable=False), + sa.Column("request_json", sa.JSON(), nullable=False), + sa.Column("response_json", sa.JSON(), nullable=False), + sa.Column("step_id", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_step_id", "provider_traces", ["step_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_step_id", table_name="provider_traces") + op.drop_table("provider_traces") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index df755c23..e2ed59d5 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1,12 +1,14 @@ import json import time import traceback +import uuid import warnings from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, Union from openai.types.beta.function_tool import FunctionTool as OpenAITool +from letta.agents.helpers import generate_step_id from letta.constants import ( CLI_WARNING_PREFIX, COMPOSIO_ENTITY_ENV_VAR_KEY, @@ -61,9 +63,10 @@ from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.provider_manager import ProviderManager from letta.services.step_manager import StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager -from letta.settings import summarizer_settings +from letta.settings import settings, summarizer_settings from letta.streaming_interface import StreamingRefreshCLIInterface from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message from letta.tracing import log_event, trace_method @@ -141,6 +144,7 @@ class Agent(BaseAgent): self.agent_manager = AgentManager() self.job_manager = JobManager() self.step_manager = StepManager() + self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager() # State needed for heartbeat pausing @@ -298,6 +302,7 @@ class Agent(BaseAgent): step_count: Optional[int] = None, last_function_failed: bool = False, put_inner_thoughts_first: bool = True, + step_id: Optional[str] = None, ) -> ChatCompletionResponse | None: """Get response from LLM API with robust retry mechanism.""" log_telemetry(self.logger, "_get_ai_reply start") @@ -347,8 +352,9 @@ class Agent(BaseAgent): messages=message_sequence, llm_config=self.agent_state.llm_config, tools=allowed_functions, - stream=stream, force_tool_call=force_tool_call, + telemetry_manager=self.telemetry_manager, + step_id=step_id, ) else: # Fallback to existing flow @@ -365,6 +371,9 @@ class Agent(BaseAgent): stream_interface=self.interface, put_inner_thoughts_first=put_inner_thoughts_first, 
name=self.agent_state.name, + telemetry_manager=self.telemetry_manager, + step_id=step_id, + actor=self.user, ) log_telemetry(self.logger, "_get_ai_reply create finish") @@ -840,6 +849,9 @@ class Agent(BaseAgent): # Extract job_id from metadata if present job_id = metadata.get("job_id") if metadata else None + # Declare step_id for the given step to be used as the step is processing. + step_id = generate_step_id() + # Step 0: update core memory # only pulling latest block data if shared memory is being used current_persisted_memory = Memory( @@ -870,6 +882,7 @@ class Agent(BaseAgent): step_count=step_count, last_function_failed=last_function_failed, put_inner_thoughts_first=put_inner_thoughts_first, + step_id=step_id, ) if not response: # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early @@ -953,6 +966,7 @@ class Agent(BaseAgent): actor=self.user, ), job_id=job_id, + step_id=step_id, ) for message in all_new_messages: message.step_id = step.id diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 5578d1fb..3a525e7a 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -1,3 +1,4 @@ +import uuid import xml.etree.ElementTree as ET from typing import List, Tuple @@ -150,3 +151,7 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]: context = sum_el.text or "" return messages, context + + +def generate_step_id(): + return f"step-{uuid.uuid4()}" diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 0bbdc015..94c71544 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -8,7 +8,7 @@ from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from letta.agents.base_agent import BaseAgent -from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async +from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id from letta.helpers import ToolRulesSolver from letta.helpers.tool_execution_helper import enable_strict_mode from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface @@ -24,7 +24,8 @@ from letta.schemas.letta_message import AssistantMessage from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate -from letta.schemas.openai.chat_completion_response import ToolCall +from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics +from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.utils import create_letta_messages_from_llm_response @@ -32,6 +33,8 @@ from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.step_manager import NoopStepManager, StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager from letta.settings import settings from letta.system import package_function_response @@ -50,6 +53,8 @@ class LettaAgent(BaseAgent): block_manager: BlockManager, 
passage_manager: PassageManager, actor: User, + step_manager: StepManager = NoopStepManager(), + telemetry_manager: TelemetryManager = NoopTelemetryManager(), ): super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor) @@ -57,6 +62,8 @@ class LettaAgent(BaseAgent): # Summarizer settings self.block_manager = block_manager self.passage_manager = passage_manager + self.step_manager = step_manager + self.telemetry_manager = telemetry_manager self.response_messages: List[Message] = [] self.last_function_response = None @@ -68,9 +75,7 @@ class LettaAgent(BaseAgent): @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse: agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) - current_in_context_messages, new_in_context_messages, usage = await self._step( - agent_state=agent_state, input_messages=input_messages, max_steps=max_steps - ) + _, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps) return _create_letta_response( new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage ) @@ -89,14 +94,43 @@ class LettaAgent(BaseAgent): ) usage = LettaUsageStatistics() for _ in range(max_steps): - response = await self._get_ai_reply( + step_id = generate_step_id() + + in_context_messages = current_in_context_messages + new_in_context_messages + if settings.experimental_enable_async_db_engine: + in_context_messages = await self._rebuild_memory_async( + in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories + ) + else: + if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": + logger.info("Skipping memory rebuild") + else: + in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + log_event("agent.stream_no_tokens.messages.refreshed") # [1^] + + request_data = await self._create_llm_request_data_async( llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, + in_context_messages=in_context_messages, agent_state=agent_state, tool_rules_solver=tool_rules_solver, - stream=False, - # TODO: also pass in reasoning content + # TODO: pass in reasoning content ) + log_event("agent.stream_no_tokens.llm_request.created") # [2^] + + try: + response_data = await llm_client.request_async(request_data, agent_state.llm_config) + except Exception as e: + raise llm_client.handle_llm_error(e) + log_event("agent.stream_no_tokens.llm_response.received") # [3^] + + response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) + + # update usage + # TODO: add run_id + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens if not response.choices[0].message.tool_calls: # TODO: make into a real error @@ -109,6 +143,18 @@ class LettaAgent(BaseAgent): ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + log_event("agent.stream_no_tokens.llm_response.processed") # [4^] + + # Log LLM Trace + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + 
provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) # stream step # TODO: improve TTFT @@ -119,13 +165,6 @@ class LettaAgent(BaseAgent): for message in letta_messages: yield f"data: {message.model_dump_json()}\n\n" - # update usage - # TODO: add run_id - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens - if not should_continue: break @@ -140,6 +179,13 @@ class LettaAgent(BaseAgent): async def _step( self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10 ) -> Tuple[List[Message], List[Message], CompletionUsage]: + """ + Carries out an invocation of the agent loop. In each step, the agent + 1. Rebuilds its memory + 2. Generates a request for the LLM + 3. Fetches a response from the LLM + 4. Processes the response + """ current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) @@ -151,14 +197,42 @@ class LettaAgent(BaseAgent): ) usage = LettaUsageStatistics() for _ in range(max_steps): - response = await self._get_ai_reply( + step_id = generate_step_id() + + in_context_messages = current_in_context_messages + new_in_context_messages + if settings.experimental_enable_async_db_engine: + in_context_messages = await self._rebuild_memory_async( + in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories + ) + else: + if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": + logger.info("Skipping memory rebuild") + else: + in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + log_event("agent.step.messages.refreshed") # [1^] + + request_data = await self._create_llm_request_data_async( llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, + in_context_messages=in_context_messages, agent_state=agent_state, tool_rules_solver=tool_rules_solver, - stream=False, - # TODO: also pass in reasoning content + # TODO: pass in reasoning content ) + log_event("agent.step.llm_request.created") # [2^] + + try: + response_data = await llm_client.request_async(request_data, agent_state.llm_config) + except Exception as e: + raise llm_client.handle_llm_error(e) + log_event("agent.step.llm_response.received") # [3^] + + response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) + + # TODO: add run_id + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens if not response.choices[0].message.tool_calls: # TODO: make into a real error @@ -167,17 +241,22 @@ class LettaAgent(BaseAgent): reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons persisted_messages, should_continue = await self._handle_ai_response( - tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning + tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning, step_id=step_id, usage=usage ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + 
log_event("agent.step.llm_response.processed") # [4^] - # update usage - # TODO: add run_id - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens + # Log LLM Trace + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) if not should_continue: break @@ -194,8 +273,12 @@ class LettaAgent(BaseAgent): self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True ) -> AsyncGenerator[str, None]: """ - Main streaming loop that yields partial tokens. - Whenever we detect a tool call, we yield from _handle_ai_response as well. + Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens. + Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent + 1. Rebuilds its memory + 2. Generates a request for the LLM + 3. Fetches a response from the LLM + 4. Processes the response """ agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( @@ -210,13 +293,34 @@ class LettaAgent(BaseAgent): usage = LettaUsageStatistics() for _ in range(max_steps): - stream = await self._get_ai_reply( + step_id = generate_step_id() + + in_context_messages = current_in_context_messages + new_in_context_messages + if settings.experimental_enable_async_db_engine: + in_context_messages = await self._rebuild_memory_async( + in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories + ) + else: + if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": + logger.info("Skipping memory rebuild") + else: + in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + log_event("agent.step.messages.refreshed") # [1^] + + request_data = await self._create_llm_request_data_async( llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, + in_context_messages=in_context_messages, agent_state=agent_state, tool_rules_solver=tool_rules_solver, - stream=True, ) + log_event("agent.stream.llm_request.created") # [2^] + + try: + stream = await llm_client.stream_async(request_data, agent_state.llm_config) + except Exception as e: + raise llm_client.handle_llm_error(e) + log_event("agent.stream.llm_response.received") # [3^] + # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED if agent_state.llm_config.model_endpoint_type == "anthropic": @@ -251,10 +355,39 @@ class LettaAgent(BaseAgent): reasoning_content=reasoning_content, pre_computed_assistant_message_id=interface.letta_assistant_message_id, pre_computed_tool_message_id=interface.letta_tool_message_id, + step_id=step_id, + usage=usage, ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream + # log_event("agent.stream.llm_response.processed") # [4^] + + # Log LLM Trace + # TODO (cliandy): we are piecing together the streamed 
response here. Content here does not match the actual response schema. + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json={ + "content": { + "tool_call": tool_call.model_dump_json(), + "reasoning": [content.model_dump_json() for content in reasoning_content], + }, + "id": interface.message_id, + "model": interface.model, + "role": "assistant", + # "stop_reason": "", + # "stop_sequence": None, + "type": "message", + "usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens}, + }, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + if not use_assistant_message or should_continue: tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] yield f"data: {tool_return.model_dump_json()}\n\n" @@ -277,14 +410,12 @@ class LettaAgent(BaseAgent): yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n" @trace_method - # When raising an error this doesn't show up - async def _get_ai_reply( + async def _create_llm_request_data_async( self, llm_client: LLMClientBase, in_context_messages: List[Message], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, - stream: bool, ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: if settings.experimental_enable_async_db_engine: self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)) @@ -332,15 +463,7 @@ class LettaAgent(BaseAgent): allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] - response = await llm_client.send_llm_request_async( - messages=in_context_messages, - llm_config=agent_state.llm_config, - tools=allowed_tools, - force_tool_call=force_tool_call, - stream=stream, - ) - - return response + return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call) @trace_method async def _handle_ai_response( @@ -351,6 +474,8 @@ class LettaAgent(BaseAgent): reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, pre_computed_tool_message_id: Optional[str] = None, + step_id: str | None = None, + usage: LettaUsageStatistics = None, ) -> Tuple[List[Message], bool]: """ Now that streaming is done, handle the final AI response. @@ -397,7 +522,28 @@ class LettaAgent(BaseAgent): elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name): continue_stepping = True - # 5. Persist to DB + # 5a. Persist Steps to DB + # Following agent loop to persist this before messages + # TODO (cliandy): determine what should match old loop w/provider_id, job_id + # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same. + logged_step = await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=UsageStatistics( + total_tokens=usage.total_tokens, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ), + provider_id=None, + job_id=None, + step_id=step_id, + ) + + # 5b. 
Persist Messages to DB tool_call_messages = create_letta_messages_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, @@ -411,7 +557,9 @@ class LettaAgent(BaseAgent): reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, pre_computed_tool_message_id=pre_computed_tool_message_id, + step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops ) + persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor) self.last_function_response = function_response diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index d8643538..08c7ce0a 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -74,6 +74,7 @@ class AnthropicStreamingInterface: # usage trackers self.input_tokens = 0 self.output_tokens = 0 + self.model = None # reasoning object trackers self.reasoning_messages = [] @@ -311,6 +312,7 @@ class AnthropicStreamingInterface: self.message_id = event.message.id self.input_tokens += event.message.usage.input_tokens self.output_tokens += event.message.usage.output_tokens + self.model = event.message.model elif isinstance(event, BetaRawMessageDeltaEvent): self.output_tokens += event.usage.output_tokens elif isinstance(event, BetaRawMessageStopEvent): diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 168d0521..3d1fabe5 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -40,6 +40,9 @@ class OpenAIStreamingInterface: self.letta_assistant_message_id = Message.generate_id() self.letta_tool_message_id = Message.generate_id() + self.message_id = None + self.model = None + # token counters self.input_tokens = 0 self.output_tokens = 0 @@ -69,6 +72,10 @@ class OpenAIStreamingInterface: prev_message_type = None message_index = 0 async for chunk in stream: + if not self.model or not self.message_id: + self.model = chunk.model + self.message_id = chunk.id + # track usage if chunk.usage: self.input_tokens += len(chunk.usage.prompt_tokens) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index d86abc9b..a1af262f 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -20,15 +20,19 @@ from letta.llm_api.openai import ( build_openai_chat_completions_request, openai_chat_completions_process_stream, openai_chat_completions_request, + prepare_openai_payload, ) from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from letta.orm.user import User from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from letta.schemas.provider_trace import ProviderTraceCreate +from letta.services.telemetry_manager import TelemetryManager from letta.settings import ModelSettings from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface from letta.tracing 
import log_event, trace_method @@ -142,6 +146,9 @@ def create( model_settings: Optional[dict] = None, # TODO: eventually pass from server put_inner_thoughts_first: bool = True, name: Optional[str] = None, + telemetry_manager: Optional[TelemetryManager] = None, + step_id: Optional[str] = None, + actor: Optional[User] = None, ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" from letta.utils import printd @@ -233,6 +240,16 @@ def create( if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.stream_end() + telemetry_manager.create_provider_trace( + actor=actor, + provider_trace_create=ProviderTraceCreate( + request_json=prepare_openai_payload(data), + response_json=response.model_json_schema(), + step_id=step_id, + organization_id=actor.organization_id, + ), + ) + if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) @@ -407,6 +424,16 @@ def create( if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + telemetry_manager.create_provider_trace( + actor=actor, + provider_trace_create=ProviderTraceCreate( + request_json=chat_completion_request.model_json_schema(), + response_json=response.model_json_schema(), + step_id=step_id, + organization_id=actor.organization_id, + ), + ) + return response # elif llm_config.model_endpoint_type == "cohere": diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index f56601ee..6374a85c 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -9,7 +9,9 @@ from letta.errors import LLMError from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionResponse -from letta.tracing import log_event +from letta.schemas.provider_trace import ProviderTraceCreate +from letta.services.telemetry_manager import TelemetryManager +from letta.tracing import log_event, trace_method if TYPE_CHECKING: from letta.orm import User @@ -31,13 +33,15 @@ class LLMClientBase: self.put_inner_thoughts_first = put_inner_thoughts_first self.use_tool_naming = use_tool_naming + @trace_method def send_llm_request( self, messages: List[Message], llm_config: LLMConfig, tools: Optional[List[dict]] = None, # TODO: change to Tool object - stream: bool = False, force_tool_call: Optional[str] = None, + telemetry_manager: Optional["TelemetryManager"] = None, + step_id: Optional[str] = None, ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]: """ Issues a request to the downstream model endpoint and parses response. 
@@ -48,37 +52,51 @@ class LLMClientBase: try: log_event(name="llm_request_sent", attributes=request_data) - if stream: - return self.stream(request_data, llm_config) - else: - response_data = self.request(request_data, llm_config) + response_data = self.request(request_data, llm_config) + if step_id and telemetry_manager: + telemetry_manager.create_provider_trace( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) log_event(name="llm_response_received", attributes=response_data) except Exception as e: raise self.handle_llm_error(e) return self.convert_response_to_chat_completion(response_data, messages, llm_config) + @trace_method async def send_llm_request_async( self, + request_data: dict, messages: List[Message], llm_config: LLMConfig, - tools: Optional[List[dict]] = None, # TODO: change to Tool object - stream: bool = False, - force_tool_call: Optional[str] = None, + telemetry_manager: "TelemetryManager | None" = None, + step_id: str | None = None, ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]: """ Issues a request to the downstream model endpoint. If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over. Otherwise returns a ChatCompletionResponse. """ - request_data = self.build_request_data(messages, llm_config, tools, force_tool_call) try: log_event(name="llm_request_sent", attributes=request_data) - if stream: - return await self.stream_async(request_data, llm_config) - else: - response_data = await self.request_async(request_data, llm_config) + response_data = await self.request_async(request_data, llm_config) + await telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + log_event(name="llm_response_received", attributes=response_data) except Exception as e: raise self.handle_llm_error(e) @@ -133,13 +151,6 @@ class LLMClientBase: """ raise NotImplementedError - @abstractmethod - def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]: - """ - Performs underlying streaming request to llm and returns raw response. - """ - raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}") - @abstractmethod async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: """ diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 150def39..f3353bed 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -256,14 +256,6 @@ class OpenAIClient(LLMClientBase): return chat_completion_response - def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]: - """ - Performs underlying streaming request to OpenAI and returns the stream iterator. - """ - client = OpenAI(**self._prepare_client_kwargs(llm_config)) - response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True) - return response_stream - async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: """ Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator. 
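Usage note (illustrative sketch, not part of the patch): with this change the `stream` flag is gone from `send_llm_request`, and callers instead pass a `telemetry_manager` plus a pre-generated `step_id` so the raw provider request/response JSON can be persisted against that step. The snippet below mirrors the updated call site in `letta/agent.py`; `llm_client`, `message_sequence`, `allowed_functions`, `force_tool_call`, and `agent_state` are assumed to already exist in the caller's scope.

# Sketch only -- the surrounding names above are assumptions, not new API.
from letta.agents.helpers import generate_step_id
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.settings import settings

step_id = generate_step_id()  # "step-<uuid4>", created before the LLM call
telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()

response = llm_client.send_llm_request(
    messages=message_sequence,
    llm_config=agent_state.llm_config,
    tools=allowed_functions,
    force_tool_call=force_tool_call,
    telemetry_manager=telemetry_manager,  # persists request/response JSON keyed by step_id
    step_id=step_id,
)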
diff --git a/letta/memory.py b/letta/memory.py index 939e0874..818f45ca 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -93,7 +93,6 @@ def summarize_messages( response = llm_client.send_llm_request( messages=message_sequence, llm_config=llm_config_no_inner_thoughts, - stream=False, ) else: response = create( diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 348cd19e..de395e28 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -19,6 +19,7 @@ from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.passage import AgentPassage, BasePassage, SourcePassage from letta.orm.provider import Provider +from letta.orm.provider_trace import ProviderTrace from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents diff --git a/letta/orm/provider_trace.py b/letta/orm/provider_trace.py new file mode 100644 index 00000000..c957636e --- /dev/null +++ b/letta/orm/provider_trace.py @@ -0,0 +1,26 @@ +import uuid + +from sqlalchemy import JSON, ForeignKeyConstraint, Index, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace + + +class ProviderTrace(SqlalchemyBase, OrganizationMixin): + """Defines data model for storing provider trace information""" + + __tablename__ = "provider_traces" + __pydantic_model__ = PydanticProviderTrace + __table_args__ = (Index("ix_step_id", "step_id"),) + + id: Mapped[str] = mapped_column( + primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}" + ) + request_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider request") + response_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider response") + step_id: Mapped[str] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with") + + # Relationships + organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") diff --git a/letta/schemas/provider_trace.py b/letta/schemas/provider_trace.py new file mode 100644 index 00000000..bcc151de --- /dev/null +++ b/letta/schemas/provider_trace.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.letta_base import OrmMetadataBase + + +class BaseProviderTrace(OrmMetadataBase): + __id_prefix__ = "provider_trace" + + +class ProviderTraceCreate(BaseModel): + """Request to create a provider trace""" + + request_json: dict[str, Any] = Field(..., description="JSON content of the provider request") + response_json: dict[str, Any] = Field(..., description="JSON content of the provider response") + step_id: str = Field(None, description="ID of the step that this trace is associated with") + organization_id: str = Field(..., description="The unique identifier of the organization.") + + +class ProviderTrace(BaseProviderTrace): + """ + Letta's internal representation of a provider trace. + + Attributes: + id (str): The unique identifier of the provider trace. + request_json (Dict[str, Any]): JSON content of the provider request. 
+ response_json (Dict[str, Any]): JSON content of the provider response. + step_id (str): ID of the step that this trace is associated with. + organization_id (str): The unique identifier of the organization. + created_at (datetime): The timestamp when the object was created. + """ + + id: str = BaseProviderTrace.generate_id_field() + request_json: Dict[str, Any] = Field(..., description="JSON content of the provider request") + response_json: Dict[str, Any] = Field(..., description="JSON content of the provider response") + step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with") + organization_id: str = Field(..., description="The unique identifier of the organization.") + created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 666aeedc..4607f8f9 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -13,6 +13,7 @@ from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_c from letta.server.rest_api.routers.v1.sources import router as sources_router from letta.server.rest_api.routers.v1.steps import router as steps_router from letta.server.rest_api.routers.v1.tags import router as tags_router +from letta.server.rest_api.routers.v1.telemetry import router as telemetry_router from letta.server.rest_api.routers.v1.tools import router as tools_router from letta.server.rest_api.routers.v1.voice import router as voice_router @@ -31,6 +32,7 @@ ROUTERS = [ runs_router, steps_router, tags_router, + telemetry_router, messages_router, voice_router, embeddings_router, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 8098c9b5..e26dbc62 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -9,6 +9,7 @@ from marshmallow import ValidationError from orjson import orjson from pydantic import Field from sqlalchemy.exc import IntegrityError, OperationalError +from starlette.background import BackgroundTask from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent @@ -33,6 +34,7 @@ from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer +from letta.services.telemetry_manager import NoopTelemetryManager from letta.settings import settings # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally @@ -646,6 +648,8 @@ async def send_message( block_manager=server.block_manager, passage_manager=server.passage_manager, actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), ) result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message) @@ -707,14 +711,18 @@ async def send_message_streaming( block_manager=server.block_manager, passage_manager=server.passage_manager, actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), ) + from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode + if 
request.stream_tokens and model_compatible_token_streaming:
-            result = StreamingResponse(
+            result = StreamingResponseWithStatusCode(
                 experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
                 media_type="text/event-stream",
             )
         else:
-            result = StreamingResponse(
+            result = StreamingResponseWithStatusCode(
                 experimental_agent.step_stream_no_tokens(
                     request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
                 ),
diff --git a/letta/server/rest_api/routers/v1/telemetry.py b/letta/server/rest_api/routers/v1/telemetry.py
new file mode 100644
index 00000000..75e8de95
--- /dev/null
+++ b/letta/server/rest_api/routers/v1/telemetry.py
@@ -0,0 +1,18 @@
+from fastapi import APIRouter, Depends, Header
+
+from letta.schemas.provider_trace import ProviderTrace
+from letta.server.rest_api.utils import get_letta_server
+from letta.server.server import SyncServer
+
+router = APIRouter(prefix="/telemetry", tags=["telemetry"])
+
+
+@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace")
+async def retrieve_provider_trace_by_step_id(
+    step_id: str,
+    server: SyncServer = Depends(get_letta_server),
+    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+):
+    return await server.telemetry_manager.get_provider_trace_by_step_id_async(
+        step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
+    )
diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py
new file mode 100644
index 00000000..06e019b3
--- /dev/null
+++ b/letta/server/rest_api/streaming_response.py
@@ -0,0 +1,105 @@
+# Alternative implementation of StreamingResponse that allows for effectively
+# streaming HTTP trailers, as we cannot set status codes after the initial response.
+# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
+
+import json
+from collections.abc import AsyncIterator
+
+from fastapi.responses import StreamingResponse
+from starlette.types import Send
+
+from letta.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class StreamingResponseWithStatusCode(StreamingResponse):
+    """
+    Variation of StreamingResponse that can dynamically decide the HTTP status code,
+    based on the return value of the content iterator (parameter `content`).
+    Expects the content to yield either just str content as per the original `StreamingResponse`
+    or else tuples of (`content`: `str`, `status_code`: `int`).
+ """ + + body_iterator: AsyncIterator[str | bytes] + response_started: bool = False + + async def stream_response(self, send: Send) -> None: + more_body = True + try: + first_chunk = await self.body_iterator.__anext__() + if isinstance(first_chunk, tuple): + first_chunk_content, self.status_code = first_chunk + else: + first_chunk_content = first_chunk + if isinstance(first_chunk_content, str): + first_chunk_content = first_chunk_content.encode(self.charset) + + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + self.response_started = True + await send( + { + "type": "http.response.body", + "body": first_chunk_content, + "more_body": more_body, + } + ) + + async for chunk in self.body_iterator: + if isinstance(chunk, tuple): + content, status_code = chunk + if status_code // 100 != 2: + # An error occurred mid-stream + if not isinstance(content, bytes): + content = content.encode(self.charset) + more_body = False + await send( + { + "type": "http.response.body", + "body": content, + "more_body": more_body, + } + ) + return + else: + content = chunk + + if isinstance(content, str): + content = content.encode(self.charset) + more_body = True + await send( + { + "type": "http.response.body", + "body": content, + "more_body": more_body, + } + ) + + except Exception as exc: + logger.exception("unhandled_streaming_error") + more_body = False + error_resp = {"error": {"message": "Internal Server Error"}} + error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset) + if not self.response_started: + await send( + { + "type": "http.response.start", + "status": 500, + "headers": self.raw_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": error_event, + "more_body": more_body, + } + ) + if more_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index e025a2dd..d04806e3 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response( pre_computed_assistant_message_id: Optional[str] = None, pre_computed_tool_message_id: Optional[str] = None, llm_batch_item_id: Optional[str] = None, + step_id: str | None = None, ) -> List[Message]: messages = [] @@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response( ) messages.append(heartbeat_system_message) + for message in messages: + message.step_id = step_id + return messages diff --git a/letta/server/server.py b/letta/server/server.py index 80eed05d..7b5bc4f6 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.step_manager import StepManager +from letta.services.telemetry_manager import TelemetryManager from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager @@ -213,6 +214,7 @@ class SyncServer(Server): self.identity_manager = IdentityManager() self.group_manager = GroupManager() self.batch_manager = LLMBatchManager() + self.telemetry_manager = TelemetryManager() # A resusable httpx client timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0) diff --git 
a/letta/services/helpers/noop_helper.py b/letta/services/helpers/noop_helper.py new file mode 100644 index 00000000..7f32e628 --- /dev/null +++ b/letta/services/helpers/noop_helper.py @@ -0,0 +1,10 @@ +def singleton(cls): + """Decorator to make a class a Singleton class.""" + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index cf34915d..cdac474b 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -12,6 +12,7 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.step import Step as PydanticStep from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.services.helpers.noop_helper import singleton from letta.tracing import get_trace_id from letta.utils import enforce_types @@ -63,6 +64,7 @@ class StepManager: usage: UsageStatistics, provider_id: Optional[str] = None, job_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> PydanticStep: step_data = { "origin": None, @@ -81,6 +83,8 @@ class StepManager: "tid": None, "trace_id": get_trace_id(), # Get the current trace ID } + if step_id: + step_data["id"] = step_id with db_registry.session() as session: if job_id: self._verify_job_access(session, job_id, actor, access=["write"]) @@ -88,6 +92,46 @@ class StepManager: new_step.create(session) return new_step.to_pydantic() + @enforce_types + async def log_step_async( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + step_id: Optional[str] = None, + ) -> PydanticStep: + step_data = { + "origin": None, + "organization_id": actor.organization_id, + "agent_id": agent_id, + "provider_id": provider_id, + "provider_name": provider_name, + "model": model, + "model_endpoint": model_endpoint, + "context_window_limit": context_window_limit, + "completion_tokens": usage.completion_tokens, + "prompt_tokens": usage.prompt_tokens, + "total_tokens": usage.total_tokens, + "job_id": job_id, + "tags": [], + "tid": None, + "trace_id": get_trace_id(), # Get the current trace ID + } + if step_id: + step_data["id"] = step_id + async with db_registry.async_session() as session: + if job_id: + self._verify_job_access(session, job_id, actor, access=["write"]) + new_step = StepModel(**step_data) + await new_step.create_async(session) + return new_step.to_pydantic() + @enforce_types def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep: with db_registry.session() as session: @@ -147,3 +191,44 @@ class StepManager: if not job: raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") return job + + +@singleton +class NoopStepManager(StepManager): + """ + Noop implementation of StepManager. + Temporarily used for migrations, but allows for different implementations in the future. + Will not allow for writes, but will still allow for reads. 
+ """ + + @enforce_types + def log_step( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + step_id: Optional[str] = None, + ) -> PydanticStep: + return + + @enforce_types + async def log_step_async( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + step_id: Optional[str] = None, + ) -> PydanticStep: + return diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py new file mode 100644 index 00000000..e6ab218c --- /dev/null +++ b/letta/services/telemetry_manager.py @@ -0,0 +1,52 @@ +from sqlalchemy import select + +from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel +from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace +from letta.schemas.provider_trace import ProviderTraceCreate +from letta.schemas.step import Step as PydanticStep +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.services.helpers.noop_helper import singleton +from letta.utils import enforce_types + + +class TelemetryManager: + @enforce_types + async def get_provider_trace_by_step_id_async( + self, + step_id: str, + actor: PydanticUser, + ) -> PydanticProviderTrace: + async with db_registry.async_session() as session: + provider_trace = await ProviderTraceModel.read_async(db_session=session, step_id=step_id, actor=actor) + return provider_trace.to_pydantic() + + @enforce_types + async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + async with db_registry.async_session() as session: + provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) + await provider_trace.create_async(session, actor=actor) + return provider_trace.to_pydantic() + + @enforce_types + def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + with db_registry.session() as session: + provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) + provider_trace.create(session, actor=actor) + return provider_trace.to_pydantic() + + +@singleton +class NoopTelemetryManager(TelemetryManager): + """ + Noop implementation of TelemetryManager. 
+ """ + + async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + return + + async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep: + return + + def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + return diff --git a/letta/settings.py b/letta/settings.py index 562c5d70..a638a594 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -190,6 +190,7 @@ class Settings(BaseSettings): verbose_telemetry_logging: bool = False otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317" disable_tracing: bool = False + llm_api_logging: bool = True # uvicorn settings uvicorn_workers: int = 1 diff --git a/tests/conftest.py b/tests/conftest.py index e44d2fec..cb25bb85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,7 @@ def check_composio_key_set(): yield +# --- Tool Fixtures --- @pytest.fixture def weather_tool_func(): def get_weather(location: str) -> str: @@ -110,6 +111,23 @@ def print_tool_func(): yield print_tool +@pytest.fixture +def roll_dice_tool_func(): + def roll_dice(): + """ + Rolls a 6 sided die. + + Returns: + str: The roll result. + """ + import time + + time.sleep(1) + return "Rolled a 10!" + + yield roll_dice + + @pytest.fixture def dummy_beta_message_batch() -> BetaMessageBatch: return BetaMessageBatch( diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 17dc8430..18d72a79 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -6,7 +6,7 @@ from sqlalchemy import delete from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 -from letta.orm import Provider, Step +from letta.orm import Provider, ProviderTrace, Step from letta.orm.enums import JobType from letta.orm.errors import NoResultFound from letta.schemas.agent import CreateAgent @@ -39,6 +39,7 @@ def org_id(server): # cleanup with db_registry.session() as session: + session.execute(delete(ProviderTrace)) session.execute(delete(Step)) session.execute(delete(Provider)) session.commit() diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 150922c4..f989b434 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy import delete from letta.config import LettaConfig -from letta.orm import Provider, Step +from letta.orm import Provider, ProviderTrace, Step from letta.schemas.agent import CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.group import ( @@ -38,6 +38,7 @@ def org_id(server): # cleanup with db_registry.session() as session: + session.execute(delete(ProviderTrace)) session.execute(delete(Step)) session.execute(delete(Provider)) session.commit() diff --git a/tests/test_provider_trace.py b/tests/test_provider_trace.py new file mode 100644 index 00000000..43e13a34 --- /dev/null +++ b/tests/test_provider_trace.py @@ -0,0 +1,205 @@ +import asyncio +import json +import os +import threading +import time +import uuid + +import pytest +from dotenv import load_dotenv +from letta_client import Letta + +from letta.agents.letta_agent import LettaAgent +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message_content import TextContent +from 
letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode +from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager +from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager +from letta.services.step_manager import StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager + + +def _run_server(): + """Starts the Letta server in a background thread.""" + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + +@pytest.fixture(scope="session") +def server_url(): + """Ensures a server is running and returns its base URL.""" + url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + time.sleep(5) # Allow server startup time + + return url + + +# # --- Client Setup --- # +@pytest.fixture(scope="session") +def client(server_url): + """Creates a REST client for testing.""" + client = Letta(base_url=server_url) + yield client + + +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="function") +def roll_dice_tool(client, roll_dice_tool_func): + print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func) + yield print_tool + + +@pytest.fixture(scope="function") +def weather_tool(client, weather_tool_func): + weather_tool = client.tools.upsert_from_function(func=weather_tool_func) + yield weather_tool + + +@pytest.fixture(scope="function") +def print_tool(client, print_tool_func): + print_tool = client.tools.upsert_from_function(func=print_tool_func) + yield print_tool + + +@pytest.fixture(scope="function") +def agent_state(client, roll_dice_tool, weather_tool): + """Creates an agent and ensures cleanup after tests.""" + agent_state = client.agents.create( + name=f"test_compl_{str(uuid.uuid4())[5:]}", + tool_ids=[roll_dice_tool.id, weather_tool.id], + include_base_tools=True, + memory_blocks=[ + { + "label": "human", + "value": "Name: Matt", + }, + { + "label": "persona", + "value": "Friendly agent", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + yield agent_state + client.agents.delete(agent_state.id) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_provider_trace_experimental_step(message, agent_state, default_user): + experimental_agent = LettaAgent( + agent_id=agent_state.id, + message_manager=MessageManager(), + agent_manager=AgentManager(), + block_manager=BlockManager(), + passage_manager=PassageManager(), + step_manager=StepManager(), + telemetry_manager=TelemetryManager(), + actor=default_user, + ) + + response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])]) + tool_step = response.messages[0].step_id + reply_step = response.messages[-1].step_id + + tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) + reply_telemetry = await 
experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + assert tool_telemetry.request_json + assert reply_telemetry.request_json + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop): + experimental_agent = LettaAgent( + agent_id=agent_state.id, + message_manager=MessageManager(), + agent_manager=AgentManager(), + block_manager=BlockManager(), + passage_manager=PassageManager(), + step_manager=StepManager(), + telemetry_manager=TelemetryManager(), + actor=default_user, + ) + stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])]) + + result = StreamingResponseWithStatusCode( + stream, + media_type="text/event-stream", + ) + + message_id = None + + async def test_send(message) -> None: + nonlocal message_id + if "body" in message and not message_id: + body = message["body"].decode("utf-8").split("data:") + message_id = json.loads(body[1])["id"] + + await result.stream_response(send=test_send) + + messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user) + step_ids = set((message.step_id for message in messages)) + for step_id in step_ids: + telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user) + assert telemetry_data.request_json + assert telemetry_data.response_json + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_provider_trace_step(client, agent_state, default_user, message, event_loop): + client.agents.messages.create(agent_id=agent_state.id, messages=[]) + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text=message)])], + ) + tool_step = response.messages[0].step_id + reply_step = response.messages[-1].step_id + + tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) + reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + assert tool_telemetry.request_json + assert reply_telemetry.request_json + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_noop_provider_trace(message, agent_state, default_user, event_loop): + experimental_agent = LettaAgent( + agent_id=agent_state.id, + message_manager=MessageManager(), + agent_manager=AgentManager(), + block_manager=BlockManager(), + passage_manager=PassageManager(), + step_manager=StepManager(), + telemetry_manager=NoopTelemetryManager(), + actor=default_user, + ) + + response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])]) + tool_step = response.messages[0].step_id + reply_step = response.messages[-1].step_id + + tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) + reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + assert tool_telemetry is None + assert reply_telemetry is None diff --git a/tests/test_server.py b/tests/test_server.py index a3932d81..0b43a12f 100644 --- a/tests/test_server.py +++ 
b/tests/test_server.py @@ -11,7 +11,7 @@ from sqlalchemy import delete import letta.utils as utils from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR -from letta.orm import Provider, Step +from letta.orm import Provider, ProviderTrace, Step from letta.schemas.block import CreateBlock from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage @@ -286,6 +286,7 @@ def org_id(server): # cleanup with db_registry.session() as session: + session.execute(delete(ProviderTrace)) session.execute(delete(Step)) session.execute(delete(Provider)) session.commit()
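
Reviewer note (not part of the patch): a minimal sketch of how the provider traces captured by this change can be read back for a given step once the patch is applied. The TelemetryManager call, its keyword arguments, and the request_json / response_json fields mirror the tests added above; the actor lookup via UserManager().get_user_or_default() and the example step id are assumptions for local experimentation, not something this diff introduces.

import asyncio

from letta.services.telemetry_manager import TelemetryManager
from letta.services.user_manager import UserManager


async def inspect_step_trace(step_id: str) -> None:
    # Resolve an actor to scope the lookup; get_user_or_default() is assumed
    # to be available for local experimentation.
    actor = UserManager().get_user_or_default()

    # Fetch the raw provider payloads recorded for this step. With
    # NoopTelemetryManager (or when settings.llm_api_logging is off) nothing
    # is recorded, as exercised by test_noop_provider_trace above.
    trace = await TelemetryManager().get_provider_trace_by_step_id_async(
        step_id=step_id, actor=actor
    )
    if trace is None:
        print(f"no provider trace recorded for step {step_id}")
        return

    # request_json / response_json hold the verbatim request and response
    # exchanged with the LLM provider for this step.
    print("request keys:", sorted(trace.request_json.keys()))
    print("response keys:", sorted(trace.response_json.keys()))


# Example usage with a hypothetical step id:
# asyncio.run(inspect_step_trace("step-00000000-0000-0000-0000-000000000000"))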