feat: track llm provider traces and tracking steps in async agent loop (#2219)

This commit is contained in:
Andy Li
2025-05-19 15:50:56 -07:00
committed by GitHub
parent 969f0d65c8
commit a78abc610e
28 changed files with 920 additions and 82 deletions

View File

@@ -0,0 +1,50 @@
"""add support for request and response jsons from llm providers
Revision ID: cc8dc340836d
Revises: 220856bbf43b
Create Date: 2025-05-19 14:25:41.999676
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "cc8dc340836d"
down_revision: Union[str, None] = "220856bbf43b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"provider_traces",
sa.Column("id", sa.String(), nullable=False),
sa.Column("request_json", sa.JSON(), nullable=False),
sa.Column("response_json", sa.JSON(), nullable=False),
sa.Column("step_id", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_step_id", "provider_traces", ["step_id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_step_id", table_name="provider_traces")
op.drop_table("provider_traces")
# ### end Alembic commands ###
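
The migration adds one row per recorded LLM call, looked up by step through the ix_step_id index. Below is a minimal, self-contained sketch of the resulting shape and the lookup pattern, using an in-memory SQLite database instead of the project's Alembic setup (columns trimmed to the ones relevant here, IDs illustrative):

import sqlalchemy as sa

metadata = sa.MetaData()
provider_traces = sa.Table(
    "provider_traces",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("request_json", sa.JSON, nullable=False),
    sa.Column("response_json", sa.JSON, nullable=False),
    sa.Column("step_id", sa.String, nullable=True),
    sa.Column("organization_id", sa.String, nullable=False),
    sa.Index("ix_step_id", "step_id"),  # mirrors op.create_index above
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(
        provider_traces.insert(),
        {
            "id": "provider_trace-123",
            "request_json": {"model": "gpt-4o-mini", "messages": []},
            "response_json": {"choices": []},
            "step_id": "step-abc",
            "organization_id": "org-1",
        },
    )
    # Retrieval path used by the telemetry endpoint: one trace per step_id.
    row = conn.execute(
        sa.select(provider_traces).where(provider_traces.c.step_id == "step-abc")
    ).one()
    print(row.request_json, row.response_json)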

View File

@@ -1,12 +1,14 @@
import json
import time
import traceback
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.agents.helpers import generate_step_id
from letta.constants import (
CLI_WARNING_PREFIX,
COMPOSIO_ENTITY_ENV_VAR_KEY,
@@ -61,9 +63,10 @@ from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.settings import summarizer_settings
from letta.settings import settings, summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
from letta.tracing import log_event, trace_method
@@ -141,6 +144,7 @@ class Agent(BaseAgent):
self.agent_manager = AgentManager()
self.job_manager = JobManager()
self.step_manager = StepManager()
self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()
# State needed for heartbeat pausing
@@ -298,6 +302,7 @@ class Agent(BaseAgent):
step_count: Optional[int] = None,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
step_id: Optional[str] = None,
) -> ChatCompletionResponse | None:
"""Get response from LLM API with robust retry mechanism."""
log_telemetry(self.logger, "_get_ai_reply start")
@@ -347,8 +352,9 @@ class Agent(BaseAgent):
messages=message_sequence,
llm_config=self.agent_state.llm_config,
tools=allowed_functions,
stream=stream,
force_tool_call=force_tool_call,
telemetry_manager=self.telemetry_manager,
step_id=step_id,
)
else:
# Fallback to existing flow
@@ -365,6 +371,9 @@ class Agent(BaseAgent):
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
name=self.agent_state.name,
telemetry_manager=self.telemetry_manager,
step_id=step_id,
actor=self.user,
)
log_telemetry(self.logger, "_get_ai_reply create finish")
@@ -840,6 +849,9 @@ class Agent(BaseAgent):
# Extract job_id from metadata if present
job_id = metadata.get("job_id") if metadata else None
# Declare step_id for the given step so it can be used while the step is processing.
step_id = generate_step_id()
# Step 0: update core memory
# only pulling latest block data if shared memory is being used
current_persisted_memory = Memory(
@@ -870,6 +882,7 @@ class Agent(BaseAgent):
step_count=step_count,
last_function_failed=last_function_failed,
put_inner_thoughts_first=put_inner_thoughts_first,
step_id=step_id,
)
if not response:
# EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
@@ -953,6 +966,7 @@ class Agent(BaseAgent):
actor=self.user,
),
job_id=job_id,
step_id=step_id,
)
for message in all_new_messages:
message.step_id = step.id

View File

@@ -1,3 +1,4 @@
import uuid
import xml.etree.ElementTree as ET
from typing import List, Tuple
@@ -150,3 +151,7 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
context = sum_el.text or ""
return messages, context
def generate_step_id():
return f"step-{uuid.uuid4()}"

View File

@@ -8,7 +8,7 @@ from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
from letta.helpers import ToolRulesSolver
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
@@ -24,7 +24,8 @@ from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@@ -32,6 +33,8 @@ from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
@@ -50,6 +53,8 @@ class LettaAgent(BaseAgent):
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
):
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
@@ -57,6 +62,8 @@ class LettaAgent(BaseAgent):
# Summarizer settings
self.block_manager = block_manager
self.passage_manager = passage_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
self.response_messages: List[Message] = []
self.last_function_response = None
@@ -68,9 +75,7 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages, usage = await self._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
_, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
return _create_letta_response(
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
)
@@ -89,14 +94,43 @@ class LettaAgent(BaseAgent):
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
response = await self._get_ai_reply(
step_id = generate_step_id()
in_context_messages = current_in_context_messages + new_in_context_messages
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
log_event("agent.stream_no_tokens.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
in_context_messages=in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=False,
# TODO: also pass in reasoning content
# TODO: pass in reasoning content
)
log_event("agent.stream_no_tokens.llm_request.created") # [2^]
try:
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
except Exception as e:
raise llm_client.handle_llm_error(e)
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
@@ -109,6 +143,18 @@ class LettaAgent(BaseAgent):
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
# stream step
# TODO: improve TTFT
@@ -119,13 +165,6 @@ class LettaAgent(BaseAgent):
for message in letta_messages:
yield f"data: {message.model_dump_json()}\n\n"
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not should_continue:
break
@@ -140,6 +179,13 @@ class LettaAgent(BaseAgent):
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message], CompletionUsage]:
"""
Carries out an invocation of the agent loop. In each step, the agent
1. Rebuilds its memory
2. Generates a request for the LLM
3. Fetches a response from the LLM
4. Processes the response
"""
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
@@ -151,14 +197,42 @@ class LettaAgent(BaseAgent):
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
response = await self._get_ai_reply(
step_id = generate_step_id()
in_context_messages = current_in_context_messages + new_in_context_messages
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
log_event("agent.step.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
in_context_messages=in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=False,
# TODO: also pass in reasoning content
# TODO: pass in reasoning content
)
log_event("agent.step.llm_request.created") # [2^]
try:
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
except Exception as e:
raise llm_client.handle_llm_error(e)
log_event("agent.step.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
@@ -167,17 +241,22 @@ class LettaAgent(BaseAgent):
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
persisted_messages, should_continue = await self._handle_ai_response(
tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning, step_id=step_id, usage=usage
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.step.llm_response.processed") # [4^]
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
if not should_continue:
break
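
Each iteration of the loop above now follows the same shape: mint a step ID, build the request payload, call the provider, convert the response, then hand both raw payloads to the telemetry manager keyed by that step ID. A stand-alone sketch of that pattern follows; call_llm, record_trace, and the in-memory store are stand-ins, not the Letta API:

import uuid
from typing import Any, Dict, List

TRACES: Dict[str, Dict[str, Any]] = {}  # step_id -> {"request": ..., "response": ...}

def generate_step_id() -> str:
    return f"step-{uuid.uuid4()}"

def call_llm(request_json: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for llm_client.request_async(); returns a canned payload.
    return {"choices": [{"message": {"content": "ok"}}], "usage": {"total_tokens": 3}}

def record_trace(step_id: str, request_json: Dict[str, Any], response_json: Dict[str, Any]) -> None:
    # Stand-in for telemetry_manager.create_provider_trace_async().
    TRACES[step_id] = {"request": request_json, "response": response_json}

def run_step(messages: List[Dict[str, Any]]) -> str:
    step_id = generate_step_id()           # minted before the provider call
    request_json = {"messages": messages}  # what build_request_data would produce
    response_json = call_llm(request_json)
    record_trace(step_id, request_json, response_json)
    return step_id

step_id = run_step([{"role": "user", "content": "hi"}])
assert TRACES[step_id]["request"]["messages"]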
@@ -194,8 +273,12 @@ class LettaAgent(BaseAgent):
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
1. Rebuilds its memory
2. Generates a request for the LLM
3. Fetches a response from the LLM
4. Processes the response
"""
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
@@ -210,13 +293,34 @@ class LettaAgent(BaseAgent):
usage = LettaUsageStatistics()
for _ in range(max_steps):
stream = await self._get_ai_reply(
step_id = generate_step_id()
in_context_messages = current_in_context_messages + new_in_context_messages
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
log_event("agent.step.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
in_context_messages=in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=True,
)
log_event("agent.stream.llm_request.created") # [2^]
try:
stream = await llm_client.stream_async(request_data, agent_state.llm_config)
except Exception as e:
raise llm_client.handle_llm_error(e)
log_event("agent.stream.llm_response.received") # [3^]
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
if agent_state.llm_config.model_endpoint_type == "anthropic":
@@ -251,10 +355,39 @@ class LettaAgent(BaseAgent):
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
pre_computed_tool_message_id=interface.letta_tool_message_id,
step_id=step_id,
usage=usage,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
# Log LLM Trace
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
},
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
if not use_assistant_message or should_continue:
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
yield f"data: {tool_return.model_dump_json()}\n\n"
@@ -277,14 +410,12 @@ class LettaAgent(BaseAgent):
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@trace_method
# When raising an error this doesn't show up
async def _get_ai_reply(
async def _create_llm_request_data_async(
self,
llm_client: LLMClientBase,
in_context_messages: List[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
if settings.experimental_enable_async_db_engine:
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
@@ -332,15 +463,7 @@ class LettaAgent(BaseAgent):
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
response = await llm_client.send_llm_request_async(
messages=in_context_messages,
llm_config=agent_state.llm_config,
tools=allowed_tools,
force_tool_call=force_tool_call,
stream=stream,
)
return response
return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)
@trace_method
async def _handle_ai_response(
@@ -351,6 +474,8 @@ class LettaAgent(BaseAgent):
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
step_id: str | None = None,
usage: LettaUsageStatistics | None = None
) -> Tuple[List[Message], bool]:
"""
Now that streaming is done, handle the final AI response.
@@ -397,7 +522,28 @@ class LettaAgent(BaseAgent):
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
continue_stepping = True
# 5. Persist to DB
# 5a. Persist Steps to DB
# Mirrors the existing agent loop: persist the step before the messages
# TODO (cliandy): determine what should match old loop w/provider_id, job_id
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
logged_step = await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=UsageStatistics(
total_tokens=usage.total_tokens,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
),
provider_id=None,
job_id=None,
step_id=step_id,
)
# 5b. Persist Messages to DB
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
@@ -411,7 +557,9 @@ class LettaAgent(BaseAgent):
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
pre_computed_tool_message_id=pre_computed_tool_message_id,
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
)
persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
self.last_function_response = function_response

View File

@@ -74,6 +74,7 @@ class AnthropicStreamingInterface:
# usage trackers
self.input_tokens = 0
self.output_tokens = 0
self.model = None
# reasoning object trackers
self.reasoning_messages = []
@@ -311,6 +312,7 @@ class AnthropicStreamingInterface:
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):

View File

@@ -40,6 +40,9 @@ class OpenAIStreamingInterface:
self.letta_assistant_message_id = Message.generate_id()
self.letta_tool_message_id = Message.generate_id()
self.message_id = None
self.model = None
# token counters
self.input_tokens = 0
self.output_tokens = 0
@@ -69,6 +72,10 @@ class OpenAIStreamingInterface:
prev_message_type = None
message_index = 0
async for chunk in stream:
if not self.model or not self.message_id:
self.model = chunk.model
self.message_id = chunk.id
# track usage
if chunk.usage:
self.input_tokens += len(chunk.usage.prompt_tokens)

View File

@@ -20,15 +20,19 @@ from letta.llm_api.openai import (
build_openai_chat_completions_request,
openai_chat_completions_process_stream,
openai_chat_completions_request,
prepare_openai_payload,
)
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.tracing import log_event, trace_method
@@ -142,6 +146,9 @@ def create(
model_settings: Optional[dict] = None, # TODO: eventually pass from server
put_inner_thoughts_first: bool = True,
name: Optional[str] = None,
telemetry_manager: Optional[TelemetryManager] = None,
step_id: Optional[str] = None,
actor: Optional[User] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from letta.utils import printd
@@ -233,6 +240,16 @@ def create(
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_end()
telemetry_manager.create_provider_trace(
actor=actor,
provider_trace_create=ProviderTraceCreate(
request_json=prepare_openai_payload(data),
response_json=response.model_json_schema(),
step_id=step_id,
organization_id=actor.organization_id,
),
)
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
@@ -407,6 +424,16 @@ def create(
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
telemetry_manager.create_provider_trace(
actor=actor,
provider_trace_create=ProviderTraceCreate(
request_json=chat_completion_request.model_json_schema(),
response_json=response.model_json_schema(),
step_id=step_id,
organization_id=actor.organization_id,
),
)
return response
# elif llm_config.model_endpoint_type == "cohere":

View File

@@ -9,7 +9,9 @@ from letta.errors import LLMError
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.tracing import log_event
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.tracing import log_event, trace_method
if TYPE_CHECKING:
from letta.orm import User
@@ -31,13 +33,15 @@ class LLMClientBase:
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming
@trace_method
def send_llm_request(
self,
messages: List[Message],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # TODO: change to Tool object
stream: bool = False,
force_tool_call: Optional[str] = None,
telemetry_manager: Optional["TelemetryManager"] = None,
step_id: Optional[str] = None,
) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
"""
Issues a request to the downstream model endpoint and parses response.
@@ -48,37 +52,51 @@ class LLMClientBase:
try:
log_event(name="llm_request_sent", attributes=request_data)
if stream:
return self.stream(request_data, llm_config)
else:
response_data = self.request(request_data, llm_config)
response_data = self.request(request_data, llm_config)
if step_id and telemetry_manager:
telemetry_manager.create_provider_trace(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
@trace_method
async def send_llm_request_async(
self,
request_data: dict,
messages: List[Message],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # TODO: change to Tool object
stream: bool = False,
force_tool_call: Optional[str] = None,
telemetry_manager: "TelemetryManager | None" = None,
step_id: str | None = None,
) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
"""
Issues a request to the downstream model endpoint.
If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)
try:
log_event(name="llm_request_sent", attributes=request_data)
if stream:
return await self.stream_async(request_data, llm_config)
else:
response_data = await self.request_async(request_data, llm_config)
response_data = await self.request_async(request_data, llm_config)
await telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
@@ -133,13 +151,6 @@ class LLMClientBase:
"""
raise NotImplementedError
@abstractmethod
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
"""
Performs underlying streaming request to llm and returns raw response.
"""
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
@abstractmethod
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""

View File

@@ -256,14 +256,6 @@ class OpenAIClient(LLMClientBase):
return chat_completion_response
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
"""
Performs underlying streaming request to OpenAI and returns the stream iterator.
"""
client = OpenAI(**self._prepare_client_kwargs(llm_config))
response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
return response_stream
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.

View File

@@ -93,7 +93,6 @@ def summarize_messages(
response = llm_client.send_llm_request(
messages=message_sequence,
llm_config=llm_config_no_inner_thoughts,
stream=False,
)
else:
response = create(

View File

@@ -19,6 +19,7 @@ from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.provider import Provider
from letta.orm.provider_trace import ProviderTrace
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents

View File

@@ -0,0 +1,26 @@
import uuid
from sqlalchemy import JSON, ForeignKeyConstraint, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
class ProviderTrace(SqlalchemyBase, OrganizationMixin):
"""Defines data model for storing provider trace information"""
__tablename__ = "provider_traces"
__pydantic_model__ = PydanticProviderTrace
__table_args__ = (Index("ix_step_id", "step_id"),)
id: Mapped[str] = mapped_column(
primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}"
)
request_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider request")
response_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider response")
step_id: Mapped[str] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with")
# Relationships
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.letta_base import OrmMetadataBase
class BaseProviderTrace(OrmMetadataBase):
__id_prefix__ = "provider_trace"
class ProviderTraceCreate(BaseModel):
"""Request to create a provider trace"""
request_json: dict[str, Any] = Field(..., description="JSON content of the provider request")
response_json: dict[str, Any] = Field(..., description="JSON content of the provider response")
step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")
organization_id: str = Field(..., description="The unique identifier of the organization.")
class ProviderTrace(BaseProviderTrace):
"""
Letta's internal representation of a provider trace.
Attributes:
id (str): The unique identifier of the provider trace.
request_json (Dict[str, Any]): JSON content of the provider request.
response_json (Dict[str, Any]): JSON content of the provider response.
step_id (str): ID of the step that this trace is associated with.
organization_id (str): The unique identifier of the organization.
created_at (datetime): The timestamp when the object was created.
"""
id: str = BaseProviderTrace.generate_id_field()
request_json: Dict[str, Any] = Field(..., description="JSON content of the provider request")
response_json: Dict[str, Any] = Field(..., description="JSON content of the provider response")
step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")
organization_id: str = Field(..., description="The unique identifier of the organization.")
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
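
Callers construct a ProviderTraceCreate with the raw payloads plus the step and organization IDs; the ProviderTrace read model adds the generated id and created_at. A minimal construction example, assuming this branch of letta is importable and with illustrative IDs:

from letta.schemas.provider_trace import ProviderTraceCreate

trace = ProviderTraceCreate(
    request_json={"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]},
    response_json={"choices": [{"message": {"content": "hello"}}]},
    step_id="step-123",
    organization_id="org-123",
)
print(trace.model_dump())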

View File

@@ -13,6 +13,7 @@ from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_c
from letta.server.rest_api.routers.v1.sources import router as sources_router
from letta.server.rest_api.routers.v1.steps import router as steps_router
from letta.server.rest_api.routers.v1.tags import router as tags_router
from letta.server.rest_api.routers.v1.telemetry import router as telemetry_router
from letta.server.rest_api.routers.v1.tools import router as tools_router
from letta.server.rest_api.routers.v1.voice import router as voice_router
@@ -31,6 +32,7 @@ ROUTERS = [
runs_router,
steps_router,
tags_router,
telemetry_router,
messages_router,
voice_router,
embeddings_router,

View File

@@ -9,6 +9,7 @@ from marshmallow import ValidationError
from orjson import orjson
from pydantic import Field
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.background import BackgroundTask
from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
@@ -33,6 +34,7 @@ from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.telemetry_manager import NoopTelemetryManager
from letta.settings import settings
# These can be forward refs, but because FastAPI needs them at runtime they must be imported normally
@@ -646,6 +648,8 @@ async def send_message(
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
@@ -707,14 +711,18 @@ async def send_message_streaming(
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
if request.stream_tokens and model_compatible_token_streaming:
result = StreamingResponse(
result = StreamingResponseWithStatusCode(
experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
media_type="text/event-stream",
)
else:
result = StreamingResponse(
result = StreamingResponseWithStatusCode(
experimental_agent.step_stream_no_tokens(
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
),

View File

@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends, Header
from letta.schemas.provider_trace import ProviderTrace
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace")
async def retrieve_provider_trace_by_step_id(
step_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
return await server.telemetry_manager.get_provider_trace_by_step_id_async(
step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
)
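
The router exposes one read path: fetch the provider trace recorded for a given step. A hedged client-side example with httpx; the user_id header name comes from the route above, while the /v1 prefix, port, and IDs are assumptions about a local deployment:

import httpx

step_id = "step-123"  # a step_id taken from a prior agent response
resp = httpx.get(
    f"http://localhost:8283/v1/telemetry/{step_id}",
    headers={"user_id": "user-123"},  # optional; the server falls back to the default actor
)
resp.raise_for_status()
trace = resp.json()
print(trace["request_json"], trace["response_json"])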

View File

@@ -0,0 +1,105 @@
# Alternative implementation of StreamingResponse that allows for effectively
# streaming HTTP trailers, as we cannot set codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
import json
from collections.abc import AsyncIterator
from fastapi.responses import StreamingResponse
from starlette.types import Send
from letta.log import get_logger
logger = get_logger(__name__)
class StreamingResponseWithStatusCode(StreamingResponse):
"""
Variation of StreamingResponse that can dynamically decide the HTTP status code,
based on the return value of the content iterator (parameter `content`).
Expects the content to yield either just str content as per the original `StreamingResponse`
or else tuples of (`content`: `str`, `status_code`: `int`).
"""
body_iterator: AsyncIterator[str | bytes]
response_started: bool = False
async def stream_response(self, send: Send) -> None:
more_body = True
try:
first_chunk = await self.body_iterator.__anext__()
if isinstance(first_chunk, tuple):
first_chunk_content, self.status_code = first_chunk
else:
first_chunk_content = first_chunk
if isinstance(first_chunk_content, str):
first_chunk_content = first_chunk_content.encode(self.charset)
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
self.response_started = True
await send(
{
"type": "http.response.body",
"body": first_chunk_content,
"more_body": more_body,
}
)
async for chunk in self.body_iterator:
if isinstance(chunk, tuple):
content, status_code = chunk
if status_code // 100 != 2:
# An error occurred mid-stream
if not isinstance(content, bytes):
content = content.encode(self.charset)
more_body = False
await send(
{
"type": "http.response.body",
"body": content,
"more_body": more_body,
}
)
return
else:
content = chunk
if isinstance(content, str):
content = content.encode(self.charset)
more_body = True
await send(
{
"type": "http.response.body",
"body": content,
"more_body": more_body,
}
)
except Exception as exc:
logger.exception("unhandled_streaming_error")
more_body = False
error_resp = {"error": {"message": "Internal Server Error"}}
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
if not self.response_started:
await send(
{
"type": "http.response.start",
"status": 500,
"headers": self.raw_headers,
}
)
await send(
{
"type": "http.response.body",
"body": error_event,
"more_body": more_body,
}
)
if more_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})

View File

@@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response(
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
llm_batch_item_id: Optional[str] = None,
step_id: str | None = None,
) -> List[Message]:
messages = []
@@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response(
)
messages.append(heartbeat_system_message)
for message in messages:
message.step_id = step_id
return messages

View File

@@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
@@ -213,6 +214,7 @@ class SyncServer(Server):
self.identity_manager = IdentityManager()
self.group_manager = GroupManager()
self.batch_manager = LLMBatchManager()
self.telemetry_manager = TelemetryManager()
# A reusable httpx client
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)

View File

@@ -0,0 +1,10 @@
def singleton(cls):
"""Decorator to make a class a Singleton class."""
instances = {}
def get_instance(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
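
Continuing from the decorator above, repeated construction hands back the same instance, which is all the noop managers need:

@singleton
class NoopThing:
    pass

assert NoopThing() is NoopThing()  # both calls resolve to one cached instance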

View File

@@ -12,6 +12,7 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.tracing import get_trace_id
from letta.utils import enforce_types
@@ -63,6 +64,7 @@ class StepManager:
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
step_data = {
"origin": None,
@@ -81,6 +83,8 @@ class StepManager:
"tid": None,
"trace_id": get_trace_id(), # Get the current trace ID
}
if step_id:
step_data["id"] = step_id
with db_registry.session() as session:
if job_id:
self._verify_job_access(session, job_id, actor, access=["write"])
@@ -88,6 +92,46 @@ class StepManager:
new_step.create(session)
return new_step.to_pydantic()
@enforce_types
async def log_step_async(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
step_data = {
"origin": None,
"organization_id": actor.organization_id,
"agent_id": agent_id,
"provider_id": provider_id,
"provider_name": provider_name,
"model": model,
"model_endpoint": model_endpoint,
"context_window_limit": context_window_limit,
"completion_tokens": usage.completion_tokens,
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
"job_id": job_id,
"tags": [],
"tid": None,
"trace_id": get_trace_id(), # Get the current trace ID
}
if step_id:
step_data["id"] = step_id
async with db_registry.async_session() as session:
if job_id:
self._verify_job_access(session, job_id, actor, access=["write"])
new_step = StepModel(**step_data)
await new_step.create_async(session)
return new_step.to_pydantic()
@enforce_types
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
with db_registry.session() as session:
@@ -147,3 +191,44 @@ class StepManager:
if not job:
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
return job
@singleton
class NoopStepManager(StepManager):
"""
Noop implementation of StepManager.
Temporarily used for migrations, but allows for different implementations in the future.
Will not allow for writes, but will still allow for reads.
"""
@enforce_types
def log_step(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
return
@enforce_types
async def log_step_async(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
return

View File

@@ -0,0 +1,52 @@
from sqlalchemy import select
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.utils import enforce_types
class TelemetryManager:
@enforce_types
async def get_provider_trace_by_step_id_async(
self,
step_id: str,
actor: PydanticUser,
) -> PydanticProviderTrace:
async with db_registry.async_session() as session:
provider_trace = await ProviderTraceModel.read_async(db_session=session, step_id=step_id, actor=actor)
return provider_trace.to_pydantic()
@enforce_types
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
async with db_registry.async_session() as session:
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
await provider_trace.create_async(session, actor=actor)
return provider_trace.to_pydantic()
@enforce_types
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
with db_registry.session() as session:
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
provider_trace.create(session, actor=actor)
return provider_trace.to_pydantic()
@singleton
class NoopTelemetryManager(TelemetryManager):
"""
Noop implementation of TelemetryManager.
"""
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
return
async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticProviderTrace:
return
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
return
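
The noop subclasses keep call sites on a single code path; whether telemetry is actually written is decided once at construction time, driven by the new settings.llm_api_logging flag. A small stand-alone sketch of that selection pattern (Recorder and NoopRecorder are stand-ins, not the Letta classes):

class Recorder:
    def record(self, payload: dict) -> None:
        print("persisted", payload)  # stand-in for a DB write

class NoopRecorder(Recorder):
    def record(self, payload: dict) -> None:
        return  # same interface, nothing written

def make_recorder(logging_enabled: bool) -> Recorder:
    # Mirrors: TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()
    return Recorder() if logging_enabled else NoopRecorder()

make_recorder(True).record({"step_id": "step-1"})   # persisted
make_recorder(False).record({"step_id": "step-1"})  # silently dropped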

View File

@@ -190,6 +190,7 @@ class Settings(BaseSettings):
verbose_telemetry_logging: bool = False
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
disable_tracing: bool = False
llm_api_logging: bool = True
# uvicorn settings
uvicorn_workers: int = 1

View File

@@ -63,6 +63,7 @@ def check_composio_key_set():
yield
# --- Tool Fixtures ---
@pytest.fixture
def weather_tool_func():
def get_weather(location: str) -> str:
@@ -110,6 +111,23 @@ def print_tool_func():
yield print_tool
@pytest.fixture
def roll_dice_tool_func():
def roll_dice():
"""
Rolls a 6 sided die.
Returns:
str: The roll result.
"""
import time
time.sleep(1)
return "Rolled a 10!"
yield roll_dice
@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
return BetaMessageBatch(

View File

@@ -6,7 +6,7 @@ from sqlalchemy import delete
from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound
from letta.schemas.agent import CreateAgent
@@ -39,6 +39,7 @@ def org_id(server):
# cleanup
with db_registry.session() as session:
session.execute(delete(ProviderTrace))
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()

View File

@@ -2,7 +2,7 @@ import pytest
from sqlalchemy import delete
from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import (
@@ -38,6 +38,7 @@ def org_id(server):
# cleanup
with db_registry.session() as session:
session.execute(delete(ProviderTrace))
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()

View File

@@ -0,0 +1,205 @@
import asyncio
import json
import os
import threading
import time
import uuid
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from letta.agents.letta_agent import LettaAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(5) # Allow server startup time
return url
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = Letta(base_url=server_url)
yield client
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="function")
def roll_dice_tool(client, roll_dice_tool_func):
print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func)
yield print_tool
@pytest.fixture(scope="function")
def weather_tool(client, weather_tool_func):
weather_tool = client.tools.upsert_from_function(func=weather_tool_func)
yield weather_tool
@pytest.fixture(scope="function")
def print_tool(client, print_tool_func):
print_tool = client.tools.upsert_from_function(func=print_tool_func)
yield print_tool
@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.agents.create(
name=f"test_compl_{str(uuid.uuid4())[5:]}",
tool_ids=[roll_dice_tool.id, weather_tool.id],
include_base_tools=True,
memory_blocks=[
{
"label": "human",
"value": "Name: Matt",
},
{
"label": "persona",
"value": "Friendly agent",
},
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
yield agent_state
client.agents.delete(agent_state.id)
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step(message, agent_state, default_user):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=TelemetryManager(),
actor=default_user,
)
response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
tool_step = response.messages[0].step_id
reply_step = response.messages[-1].step_id
tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
assert tool_telemetry.request_json
assert reply_telemetry.request_json
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=TelemetryManager(),
actor=default_user,
)
stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])])
result = StreamingResponseWithStatusCode(
stream,
media_type="text/event-stream",
)
message_id = None
async def test_send(message) -> None:
nonlocal message_id
if "body" in message and not message_id:
body = message["body"].decode("utf-8").split("data:")
message_id = json.loads(body[1])["id"]
await result.stream_response(send=test_send)
messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user)
step_ids = set((message.step_id for message in messages))
for step_id in step_ids:
telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user)
assert telemetry_data.request_json
assert telemetry_data.response_json
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_step(client, agent_state, default_user, message, event_loop):
client.agents.messages.create(agent_id=agent_state.id, messages=[])
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[MessageCreate(role="user", content=[TextContent(text=message)])],
)
tool_step = response.messages[0].step_id
reply_step = response.messages[-1].step_id
tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
assert tool_telemetry.request_json
assert reply_telemetry.request_json
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_noop_provider_trace(message, agent_state, default_user, event_loop):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=NoopTelemetryManager(),
actor=default_user,
)
response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
tool_step = response.messages[0].step_id
reply_step = response.messages[-1].step_id
tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
assert tool_telemetry is None
assert reply_telemetry is None

View File

@@ -11,7 +11,7 @@ from sqlalchemy import delete
import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
@@ -286,6 +286,7 @@ def org_id(server):
# cleanup
with db_registry.session() as session:
session.execute(delete(ProviderTrace))
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()