feat: track llm provider traces and tracking steps in async agent loop (#2219)
@@ -0,0 +1,50 @@
"""add support for request and response jsons from llm providers

Revision ID: cc8dc340836d
Revises: 220856bbf43b
Create Date: 2025-05-19 14:25:41.999676

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cc8dc340836d"
down_revision: Union[str, None] = "220856bbf43b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "provider_traces",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("request_json", sa.JSON(), nullable=False),
        sa.Column("response_json", sa.JSON(), nullable=False),
        sa.Column("step_id", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_step_id", "provider_traces", ["step_id"], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("ix_step_id", table_name="provider_traces")
    op.drop_table("provider_traces")
    # ### end Alembic commands ###
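Note (not part of the diff): a minimal sketch of applying this revision programmatically with Alembic, assuming the repository's standard alembic.ini is on disk; running `alembic upgrade cc8dc340836d` from the CLI is equivalent.

from alembic import command
from alembic.config import Config

# Load the project's Alembic configuration (the path is an assumption).
alembic_cfg = Config("alembic.ini")

# Upgrade the schema up to this revision, creating provider_traces and its ix_step_id index.
command.upgrade(alembic_cfg, "cc8dc340836d")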
@@ -1,12 +1,14 @@
import json
import time
import traceback
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

from openai.types.beta.function_tool import FunctionTool as OpenAITool

from letta.agents.helpers import generate_step_id
from letta.constants import (
    CLI_WARNING_PREFIX,
    COMPOSIO_ENTITY_ENV_VAR_KEY,
@@ -61,9 +63,10 @@ from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.settings import summarizer_settings
from letta.settings import settings, summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
from letta.tracing import log_event, trace_method
@@ -141,6 +144,7 @@ class Agent(BaseAgent):
self.agent_manager = AgentManager()
self.job_manager = JobManager()
self.step_manager = StepManager()
self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()

# State needed for heartbeat pausing
@@ -298,6 +302,7 @@ class Agent(BaseAgent):
    step_count: Optional[int] = None,
    last_function_failed: bool = False,
    put_inner_thoughts_first: bool = True,
    step_id: Optional[str] = None,
) -> ChatCompletionResponse | None:
    """Get response from LLM API with robust retry mechanism."""
    log_telemetry(self.logger, "_get_ai_reply start")
@@ -347,8 +352,9 @@
    messages=message_sequence,
    llm_config=self.agent_state.llm_config,
    tools=allowed_functions,
    stream=stream,
    force_tool_call=force_tool_call,
    telemetry_manager=self.telemetry_manager,
    step_id=step_id,
)
else:
    # Fallback to existing flow
@@ -365,6 +371,9 @@
    stream_interface=self.interface,
    put_inner_thoughts_first=put_inner_thoughts_first,
    name=self.agent_state.name,
    telemetry_manager=self.telemetry_manager,
    step_id=step_id,
    actor=self.user,
)
log_telemetry(self.logger, "_get_ai_reply create finish")
@@ -840,6 +849,9 @@ class Agent(BaseAgent):
# Extract job_id from metadata if present
job_id = metadata.get("job_id") if metadata else None

# Declare step_id for the given step to be used as the step is processing.
step_id = generate_step_id()

# Step 0: update core memory
# only pulling latest block data if shared memory is being used
current_persisted_memory = Memory(
@@ -870,6 +882,7 @@
    step_count=step_count,
    last_function_failed=last_function_failed,
    put_inner_thoughts_first=put_inner_thoughts_first,
    step_id=step_id,
)
if not response:
    # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
@@ -953,6 +966,7 @@
        actor=self.user,
    ),
    job_id=job_id,
    step_id=step_id,
)
for message in all_new_messages:
    message.step_id = step.id
@@ -1,3 +1,4 @@
import uuid
import xml.etree.ElementTree as ET
from typing import List, Tuple

@@ -150,3 +151,7 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
    context = sum_el.text or ""

    return messages, context


def generate_step_id():
    return f"step-{uuid.uuid4()}"
@@ -8,7 +8,7 @@ from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
from letta.helpers import ToolRulesSolver
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
@@ -24,7 +24,8 @@ from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@@ -32,6 +33,8 @@ from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
@@ -50,6 +53,8 @@ class LettaAgent(BaseAgent):
    block_manager: BlockManager,
    passage_manager: PassageManager,
    actor: User,
    step_manager: StepManager = NoopStepManager(),
    telemetry_manager: TelemetryManager = NoopTelemetryManager(),
):
    super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)

@@ -57,6 +62,8 @@ class LettaAgent(BaseAgent):
# Summarizer settings
self.block_manager = block_manager
self.passage_manager = passage_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
self.response_messages: List[Message] = []

self.last_function_response = None
@@ -68,9 +75,7 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
    agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
    current_in_context_messages, new_in_context_messages, usage = await self._step(
        agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
    )
    _, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
    return _create_letta_response(
        new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
    )
@@ -89,14 +94,43 @@ class LettaAgent(BaseAgent):
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
    response = await self._get_ai_reply(
    step_id = generate_step_id()

    in_context_messages = current_in_context_messages + new_in_context_messages
    if settings.experimental_enable_async_db_engine:
        in_context_messages = await self._rebuild_memory_async(
            in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
        )
    else:
        if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
            logger.info("Skipping memory rebuild")
        else:
            in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
    log_event("agent.stream_no_tokens.messages.refreshed")  # [1^]

    request_data = await self._create_llm_request_data_async(
        llm_client=llm_client,
        in_context_messages=current_in_context_messages + new_in_context_messages,
        in_context_messages=in_context_messages,
        agent_state=agent_state,
        tool_rules_solver=tool_rules_solver,
        stream=False,
        # TODO: also pass in reasoning content
        # TODO: pass in reasoning content
    )
    log_event("agent.stream_no_tokens.llm_request.created")  # [2^]

    try:
        response_data = await llm_client.request_async(request_data, agent_state.llm_config)
    except Exception as e:
        raise llm_client.handle_llm_error(e)
    log_event("agent.stream_no_tokens.llm_response.received")  # [3^]

    response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)

    # update usage
    # TODO: add run_id
    usage.step_count += 1
    usage.completion_tokens += response.usage.completion_tokens
    usage.prompt_tokens += response.usage.prompt_tokens
    usage.total_tokens += response.usage.total_tokens

    if not response.choices[0].message.tool_calls:
        # TODO: make into a real error
@@ -109,6 +143,18 @@ class LettaAgent(BaseAgent):
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.stream_no_tokens.llm_response.processed")  # [4^]

# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
    actor=self.actor,
    provider_trace_create=ProviderTraceCreate(
        request_json=request_data,
        response_json=response_data,
        step_id=step_id,
        organization_id=self.actor.organization_id,
    ),
)

# stream step
# TODO: improve TTFT
@@ -119,13 +165,6 @@ class LettaAgent(BaseAgent):
for message in letta_messages:
    yield f"data: {message.model_dump_json()}\n\n"

# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens

if not should_continue:
    break
@@ -140,6 +179,13 @@
async def _step(
    self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message], CompletionUsage]:
    """
    Carries out an invocation of the agent loop. In each step, the agent
    1. Rebuilds its memory
    2. Generates a request for the LLM
    3. Fetches a response from the LLM
    4. Processes the response
    """
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
        input_messages, agent_state, self.message_manager, self.actor
    )
@@ -151,14 +197,42 @@
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
    response = await self._get_ai_reply(
    step_id = generate_step_id()

    in_context_messages = current_in_context_messages + new_in_context_messages
    if settings.experimental_enable_async_db_engine:
        in_context_messages = await self._rebuild_memory_async(
            in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
        )
    else:
        if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
            logger.info("Skipping memory rebuild")
        else:
            in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
    log_event("agent.step.messages.refreshed")  # [1^]

    request_data = await self._create_llm_request_data_async(
        llm_client=llm_client,
        in_context_messages=current_in_context_messages + new_in_context_messages,
        in_context_messages=in_context_messages,
        agent_state=agent_state,
        tool_rules_solver=tool_rules_solver,
        stream=False,
        # TODO: also pass in reasoning content
        # TODO: pass in reasoning content
    )
    log_event("agent.step.llm_request.created")  # [2^]

    try:
        response_data = await llm_client.request_async(request_data, agent_state.llm_config)
    except Exception as e:
        raise llm_client.handle_llm_error(e)
    log_event("agent.step.llm_response.received")  # [3^]

    response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)

    # TODO: add run_id
    usage.step_count += 1
    usage.completion_tokens += response.usage.completion_tokens
    usage.prompt_tokens += response.usage.prompt_tokens
    usage.total_tokens += response.usage.total_tokens

    if not response.choices[0].message.tool_calls:
        # TODO: make into a real error
@@ -167,17 +241,22 @@
reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons

persisted_messages, should_continue = await self._handle_ai_response(
    tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
    tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning, step_id=step_id, usage=usage
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.step.llm_response.processed")  # [4^]

# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
    actor=self.actor,
    provider_trace_create=ProviderTraceCreate(
        request_json=request_data,
        response_json=response_data,
        step_id=step_id,
        organization_id=self.actor.organization_id,
    ),
)

if not should_continue:
    break
@@ -194,8 +273,12 @@
    self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
) -> AsyncGenerator[str, None]:
    """
    Main streaming loop that yields partial tokens.
    Whenever we detect a tool call, we yield from _handle_ai_response as well.
    Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
    Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
    1. Rebuilds its memory
    2. Generates a request for the LLM
    3. Fetches a response from the LLM
    4. Processes the response
    """
    agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
@@ -210,13 +293,34 @@
usage = LettaUsageStatistics()

for _ in range(max_steps):
    stream = await self._get_ai_reply(
    step_id = generate_step_id()

    in_context_messages = current_in_context_messages + new_in_context_messages
    if settings.experimental_enable_async_db_engine:
        in_context_messages = await self._rebuild_memory_async(
            in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
        )
    else:
        if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
            logger.info("Skipping memory rebuild")
        else:
            in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
    log_event("agent.step.messages.refreshed")  # [1^]

    request_data = await self._create_llm_request_data_async(
        llm_client=llm_client,
        in_context_messages=current_in_context_messages + new_in_context_messages,
        in_context_messages=in_context_messages,
        agent_state=agent_state,
        tool_rules_solver=tool_rules_solver,
        stream=True,
    )
    log_event("agent.stream.llm_request.created")  # [2^]

    try:
        stream = await llm_client.stream_async(request_data, agent_state.llm_config)
    except Exception as e:
        raise llm_client.handle_llm_error(e)
    log_event("agent.stream.llm_response.received")  # [3^]

    # TODO: THIS IS INCREDIBLY UGLY
    # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
    if agent_state.llm_config.model_endpoint_type == "anthropic":
@@ -251,10 +355,39 @@
    reasoning_content=reasoning_content,
    pre_computed_assistant_message_id=interface.letta_assistant_message_id,
    pre_computed_tool_message_id=interface.letta_tool_message_id,
    step_id=step_id,
    usage=usage,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)

# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]

# Log LLM Trace
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
await self.telemetry_manager.create_provider_trace_async(
    actor=self.actor,
    provider_trace_create=ProviderTraceCreate(
        request_json=request_data,
        response_json={
            "content": {
                "tool_call": tool_call.model_dump_json(),
                "reasoning": [content.model_dump_json() for content in reasoning_content],
            },
            "id": interface.message_id,
            "model": interface.model,
            "role": "assistant",
            # "stop_reason": "",
            # "stop_sequence": None,
            "type": "message",
            "usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
        },
        step_id=step_id,
        organization_id=self.actor.organization_id,
    ),
)

if not use_assistant_message or should_continue:
    tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
    yield f"data: {tool_return.model_dump_json()}\n\n"
@@ -277,14 +410,12 @@
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"

@trace_method
# When raising an error this doesn't show up
async def _get_ai_reply(
async def _create_llm_request_data_async(
    self,
    llm_client: LLMClientBase,
    in_context_messages: List[Message],
    agent_state: AgentState,
    tool_rules_solver: ToolRulesSolver,
    stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
    if settings.experimental_enable_async_db_engine:
        self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
@@ -332,15 +463,7 @@
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]

response = await llm_client.send_llm_request_async(
    messages=in_context_messages,
    llm_config=agent_state.llm_config,
    tools=allowed_tools,
    force_tool_call=force_tool_call,
    stream=stream,
)

return response
return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)

@trace_method
async def _handle_ai_response(
@@ -351,6 +474,8 @@ class LettaAgent(BaseAgent):
    reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
    pre_computed_assistant_message_id: Optional[str] = None,
    pre_computed_tool_message_id: Optional[str] = None,
    step_id: str | None = None,
    usage: LettaUsageStatistics = None,
) -> Tuple[List[Message], bool]:
    """
    Now that streaming is done, handle the final AI response.
@@ -397,7 +522,28 @@
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
    continue_stepping = True

# 5. Persist to DB
# 5a. Persist Steps to DB
# Following agent loop to persist this before messages
# TODO (cliandy): determine what should match old loop w/provider_id, job_id
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
logged_step = await self.step_manager.log_step_async(
    actor=self.actor,
    agent_id=agent_state.id,
    provider_name=agent_state.llm_config.model_endpoint_type,
    model=agent_state.llm_config.model,
    model_endpoint=agent_state.llm_config.model_endpoint,
    context_window_limit=agent_state.llm_config.context_window,
    usage=UsageStatistics(
        total_tokens=usage.total_tokens,
        prompt_tokens=usage.prompt_tokens,
        completion_tokens=usage.completion_tokens,
    ),
    provider_id=None,
    job_id=None,
    step_id=step_id,
)

# 5b. Persist Messages to DB
tool_call_messages = create_letta_messages_from_llm_response(
    agent_id=agent_state.id,
    model=agent_state.llm_config.model,
@@ -411,7 +557,9 @@
    reasoning_content=reasoning_content,
    pre_computed_assistant_message_id=pre_computed_assistant_message_id,
    pre_computed_tool_message_id=pre_computed_tool_message_id,
    step_id=logged_step.id if logged_step else None,  # TODO (cliandy): eventually move over other agent loops
)

persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
self.last_function_response = function_response
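Illustrative sketch (not part of the diff) of the per-step telemetry flow the loop above implements: generate a step_id up front, send the already-built request, then persist the raw request/response pair keyed by that step. Names mirror the classes touched in this PR; llm_client, llm_config, telemetry_manager, and actor are assumed to be wired as in the agent loop.

from letta.agents.helpers import generate_step_id
from letta.schemas.provider_trace import ProviderTraceCreate


async def run_one_step(llm_client, llm_config, request_data, telemetry_manager, actor):
    # A fresh step id ties the step row, the persisted messages,
    # and the provider trace together.
    step_id = generate_step_id()

    # Send the already-built request payload to the provider.
    response_data = await llm_client.request_async(request_data, llm_config)

    # Persist the raw request/response pair for this step
    # (a no-op when NoopTelemetryManager is injected).
    await telemetry_manager.create_provider_trace_async(
        actor=actor,
        provider_trace_create=ProviderTraceCreate(
            request_json=request_data,
            response_json=response_data,
            step_id=step_id,
            organization_id=actor.organization_id,
        ),
    )
    return step_id, response_data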
@@ -74,6 +74,7 @@ class AnthropicStreamingInterface:
# usage trackers
self.input_tokens = 0
self.output_tokens = 0
self.model = None

# reasoning object trackers
self.reasoning_messages = []
@@ -311,6 +312,7 @@ class AnthropicStreamingInterface:
    self.message_id = event.message.id
    self.input_tokens += event.message.usage.input_tokens
    self.output_tokens += event.message.usage.output_tokens
    self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
    self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
@@ -40,6 +40,9 @@ class OpenAIStreamingInterface:
self.letta_assistant_message_id = Message.generate_id()
self.letta_tool_message_id = Message.generate_id()

self.message_id = None
self.model = None

# token counters
self.input_tokens = 0
self.output_tokens = 0
@@ -69,6 +72,10 @@ class OpenAIStreamingInterface:
prev_message_type = None
message_index = 0
async for chunk in stream:
    if not self.model or not self.message_id:
        self.model = chunk.model
        self.message_id = chunk.id

    # track usage
    if chunk.usage:
        self.input_tokens += len(chunk.usage.prompt_tokens)
@@ -20,15 +20,19 @@ from letta.llm_api.openai import (
    build_openai_chat_completions_request,
    openai_chat_completions_process_stream,
    openai_chat_completions_request,
    prepare_openai_payload,
)
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.tracing import log_event, trace_method
@@ -142,6 +146,9 @@ def create(
    model_settings: Optional[dict] = None,  # TODO: eventually pass from server
    put_inner_thoughts_first: bool = True,
    name: Optional[str] = None,
    telemetry_manager: Optional[TelemetryManager] = None,
    step_id: Optional[str] = None,
    actor: Optional[User] = None,
) -> ChatCompletionResponse:
    """Return response to chat completion with backoff"""
    from letta.utils import printd
@@ -233,6 +240,16 @@ def create(
if isinstance(stream_interface, AgentChunkStreamingInterface):
    stream_interface.stream_end()

telemetry_manager.create_provider_trace(
    actor=actor,
    provider_trace_create=ProviderTraceCreate(
        request_json=prepare_openai_payload(data),
        response_json=response.model_json_schema(),
        step_id=step_id,
        organization_id=actor.organization_id,
    ),
)

if llm_config.put_inner_thoughts_in_kwargs:
    response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
@@ -407,6 +424,16 @@ def create(
if llm_config.put_inner_thoughts_in_kwargs:
    response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)

telemetry_manager.create_provider_trace(
    actor=actor,
    provider_trace_create=ProviderTraceCreate(
        request_json=chat_completion_request.model_json_schema(),
        response_json=response.model_json_schema(),
        step_id=step_id,
        organization_id=actor.organization_id,
    ),
)

return response

# elif llm_config.model_endpoint_type == "cohere":
@@ -9,7 +9,9 @@ from letta.errors import LLMError
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.tracing import log_event
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.tracing import log_event, trace_method

if TYPE_CHECKING:
    from letta.orm import User
@@ -31,13 +33,15 @@ class LLMClientBase:
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming

@trace_method
def send_llm_request(
    self,
    messages: List[Message],
    llm_config: LLMConfig,
    tools: Optional[List[dict]] = None,  # TODO: change to Tool object
    stream: bool = False,
    force_tool_call: Optional[str] = None,
    telemetry_manager: Optional["TelemetryManager"] = None,
    step_id: Optional[str] = None,
) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
    """
    Issues a request to the downstream model endpoint and parses response.
@@ -48,37 +52,51 @@ class LLMClientBase:
try:
    log_event(name="llm_request_sent", attributes=request_data)
    if stream:
        return self.stream(request_data, llm_config)
    else:
        response_data = self.request(request_data, llm_config)
    response_data = self.request(request_data, llm_config)
    if step_id and telemetry_manager:
        telemetry_manager.create_provider_trace(
            actor=self.actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=request_data,
                response_json=response_data,
                step_id=step_id,
                organization_id=self.actor.organization_id,
            ),
        )
    log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
    raise self.handle_llm_error(e)

return self.convert_response_to_chat_completion(response_data, messages, llm_config)

@trace_method
async def send_llm_request_async(
    self,
    request_data: dict,
    messages: List[Message],
    llm_config: LLMConfig,
    tools: Optional[List[dict]] = None,  # TODO: change to Tool object
    stream: bool = False,
    force_tool_call: Optional[str] = None,
    telemetry_manager: "TelemetryManager | None" = None,
    step_id: str | None = None,
) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
    """
    Issues a request to the downstream model endpoint.
    If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
    Otherwise returns a ChatCompletionResponse.
    """
    request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)

    try:
        log_event(name="llm_request_sent", attributes=request_data)
        if stream:
            return await self.stream_async(request_data, llm_config)
        else:
            response_data = await self.request_async(request_data, llm_config)
        response_data = await self.request_async(request_data, llm_config)
        await telemetry_manager.create_provider_trace_async(
            actor=self.actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=request_data,
                response_json=response_data,
                step_id=step_id,
                organization_id=self.actor.organization_id,
            ),
        )

        log_event(name="llm_response_received", attributes=response_data)
    except Exception as e:
        raise self.handle_llm_error(e)
@@ -133,13 +151,6 @@ class LLMClientBase:
    """
    raise NotImplementedError

@abstractmethod
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
    """
    Performs underlying streaming request to llm and returns raw response.
    """
    raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")

@abstractmethod
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
    """
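For the synchronous path above, callers opt in to trace logging by passing both a telemetry_manager and a step_id; a hedged usage sketch follows (the client instance, message list, tools, and step_id are assumed to be prepared by the caller).

# Hypothetical call site; `client` is any LLMClientBase subclass and
# `messages` a prepared List[Message].
response = client.send_llm_request(
    messages=messages,
    llm_config=agent_state.llm_config,
    tools=allowed_tools,
    stream=False,
    telemetry_manager=telemetry_manager,  # omit or pass None to skip trace logging
    step_id=step_id,                      # the trace is only written when both are set
)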
@@ -256,14 +256,6 @@ class OpenAIClient(LLMClientBase):

    return chat_completion_response

def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
    """
    Performs underlying streaming request to OpenAI and returns the stream iterator.
    """
    client = OpenAI(**self._prepare_client_kwargs(llm_config))
    response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
    return response_stream

async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
    """
    Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
@@ -93,7 +93,6 @@ def summarize_messages(
    response = llm_client.send_llm_request(
        messages=message_sequence,
        llm_config=llm_config_no_inner_thoughts,
        stream=False,
    )
else:
    response = create(
@@ -19,6 +19,7 @@ from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.provider import Provider
from letta.orm.provider_trace import ProviderTrace
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents
letta/orm/provider_trace.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import uuid

from sqlalchemy import JSON, ForeignKeyConstraint, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace


class ProviderTrace(SqlalchemyBase, OrganizationMixin):
    """Defines data model for storing provider trace information"""

    __tablename__ = "provider_traces"
    __pydantic_model__ = PydanticProviderTrace
    __table_args__ = (Index("ix_step_id", "step_id"),)

    id: Mapped[str] = mapped_column(
        primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}"
    )
    request_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider request")
    response_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider response")
    step_id: Mapped[str] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with")

    # Relationships
    organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
letta/schemas/provider_trace.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.letta_base import OrmMetadataBase


class BaseProviderTrace(OrmMetadataBase):
    __id_prefix__ = "provider_trace"


class ProviderTraceCreate(BaseModel):
    """Request to create a provider trace"""

    request_json: dict[str, Any] = Field(..., description="JSON content of the provider request")
    response_json: dict[str, Any] = Field(..., description="JSON content of the provider response")
    step_id: str = Field(None, description="ID of the step that this trace is associated with")
    organization_id: str = Field(..., description="The unique identifier of the organization.")


class ProviderTrace(BaseProviderTrace):
    """
    Letta's internal representation of a provider trace.

    Attributes:
        id (str): The unique identifier of the provider trace.
        request_json (Dict[str, Any]): JSON content of the provider request.
        response_json (Dict[str, Any]): JSON content of the provider response.
        step_id (str): ID of the step that this trace is associated with.
        organization_id (str): The unique identifier of the organization.
        created_at (datetime): The timestamp when the object was created.
    """

    id: str = BaseProviderTrace.generate_id_field()
    request_json: Dict[str, Any] = Field(..., description="JSON content of the provider request")
    response_json: Dict[str, Any] = Field(..., description="JSON content of the provider response")
    step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")
    organization_id: str = Field(..., description="The unique identifier of the organization.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
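Example (illustrative, all values invented) of the payload shape ProviderTraceCreate expects; request_json and response_json are stored verbatim in the JSON columns created by the migration above.

from letta.schemas.provider_trace import ProviderTraceCreate

trace = ProviderTraceCreate(
    request_json={"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]},
    response_json={"id": "chatcmpl-123", "choices": [{"message": {"role": "assistant", "content": "hello"}}]},
    step_id="step-00000000-0000-0000-0000-000000000000",  # placeholder step id
    organization_id="org-00000000-0000-0000-0000-000000000000",  # placeholder org id
)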
@@ -13,6 +13,7 @@ from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_c
from letta.server.rest_api.routers.v1.sources import router as sources_router
from letta.server.rest_api.routers.v1.steps import router as steps_router
from letta.server.rest_api.routers.v1.tags import router as tags_router
from letta.server.rest_api.routers.v1.telemetry import router as telemetry_router
from letta.server.rest_api.routers.v1.tools import router as tools_router
from letta.server.rest_api.routers.v1.voice import router as voice_router

@@ -31,6 +32,7 @@ ROUTERS = [
    runs_router,
    steps_router,
    tags_router,
    telemetry_router,
    messages_router,
    voice_router,
    embeddings_router,
@@ -9,6 +9,7 @@ from marshmallow import ValidationError
from orjson import orjson
from pydantic import Field
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.background import BackgroundTask
from starlette.responses import Response, StreamingResponse

from letta.agents.letta_agent import LettaAgent
@@ -33,6 +34,7 @@ from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.telemetry_manager import NoopTelemetryManager
from letta.settings import settings

# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
@@ -646,6 +648,8 @@ async def send_message(
    block_manager=server.block_manager,
    passage_manager=server.passage_manager,
    actor=actor,
    step_manager=server.step_manager,
    telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)

result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
@@ -707,14 +711,18 @@ async def send_message_streaming(
    block_manager=server.block_manager,
    passage_manager=server.passage_manager,
    actor=actor,
    step_manager=server.step_manager,
    telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode

if request.stream_tokens and model_compatible_token_streaming:
    result = StreamingResponse(
    result = StreamingResponseWithStatusCode(
        experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
        media_type="text/event-stream",
    )
else:
    result = StreamingResponse(
    result = StreamingResponseWithStatusCode(
        experimental_agent.step_stream_no_tokens(
            request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
        ),
letta/server/rest_api/routers/v1/telemetry.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends, Header

from letta.schemas.provider_trace import ProviderTrace
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer

router = APIRouter(prefix="/telemetry", tags=["telemetry"])


@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace")
async def retrieve_provider_trace_by_step_id(
    step_id: str,
    server: SyncServer = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    return await server.telemetry_manager.get_provider_trace_by_step_id_async(
        step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    )
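Illustrative client call for the new endpoint (not part of the diff). It assumes the router is mounted under the server's v1 API prefix and uses the default local port from the tests below; the step id and user id values are placeholders.

import httpx

step_id = "step-00000000-0000-0000-0000-000000000000"  # placeholder
resp = httpx.get(
    f"http://localhost:8283/v1/telemetry/{step_id}",
    headers={"user_id": "user-00000000-0000-0000-0000-000000000000"},  # optional actor header
)
resp.raise_for_status()
provider_trace = resp.json()  # contains request_json, response_json, step_id, organization_id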
letta/server/rest_api/streaming_response.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# Alternative implementation of StreamingResponse that allows for effectively
# stremaing HTTP trailers, as we cannot set codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361

import json
from collections.abc import AsyncIterator

from fastapi.responses import StreamingResponse
from starlette.types import Send

from letta.log import get_logger

logger = get_logger(__name__)


class StreamingResponseWithStatusCode(StreamingResponse):
    """
    Variation of StreamingResponse that can dynamically decide the HTTP status code,
    based on the return value of the content iterator (parameter `content`).
    Expects the content to yield either just str content as per the original `StreamingResponse`
    or else tuples of (`content`: `str`, `status_code`: `int`).
    """

    body_iterator: AsyncIterator[str | bytes]
    response_started: bool = False

    async def stream_response(self, send: Send) -> None:
        more_body = True
        try:
            first_chunk = await self.body_iterator.__anext__()
            if isinstance(first_chunk, tuple):
                first_chunk_content, self.status_code = first_chunk
            else:
                first_chunk_content = first_chunk
            if isinstance(first_chunk_content, str):
                first_chunk_content = first_chunk_content.encode(self.charset)

            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            self.response_started = True
            await send(
                {
                    "type": "http.response.body",
                    "body": first_chunk_content,
                    "more_body": more_body,
                }
            )

            async for chunk in self.body_iterator:
                if isinstance(chunk, tuple):
                    content, status_code = chunk
                    if status_code // 100 != 2:
                        # An error occurred mid-stream
                        if not isinstance(content, bytes):
                            content = content.encode(self.charset)
                        more_body = False
                        await send(
                            {
                                "type": "http.response.body",
                                "body": content,
                                "more_body": more_body,
                            }
                        )
                        return
                else:
                    content = chunk

                if isinstance(content, str):
                    content = content.encode(self.charset)
                more_body = True
                await send(
                    {
                        "type": "http.response.body",
                        "body": content,
                        "more_body": more_body,
                    }
                )

        except Exception as exc:
            logger.exception("unhandled_streaming_error")
            more_body = False
            error_resp = {"error": {"message": "Internal Server Error"}}
            error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
            if not self.response_started:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 500,
                        "headers": self.raw_headers,
                    }
                )
            await send(
                {
                    "type": "http.response.body",
                    "body": error_event,
                    "more_body": more_body,
                }
            )
        if more_body:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
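A minimal sketch (not part of the diff) of the producer contract this class expects: an async generator that yields plain SSE strings while healthy, and a (content, status_code) tuple to surface a failure. If the tuple is the very first chunk, its status becomes the response status; mid-stream, a non-2xx tuple simply terminates the body.

from collections.abc import AsyncIterator


async def sse_with_error_code() -> AsyncIterator[str | tuple[str, int]]:
    # Normal chunks: plain strings, streamed with the initial 200 status.
    yield "data: {\"message_type\": \"assistant_message\"}\n\n"
    # On failure, yield a (body, status) tuple; the response class stops the
    # stream and, if nothing was sent yet, uses the status for the response.
    yield ("event: error\ndata: {\"error\": \"upstream failed\"}\n\n", 502)


# Hypothetical FastAPI wiring:
# return StreamingResponseWithStatusCode(sse_with_error_code(), media_type="text/event-stream")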
@@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response(
    pre_computed_assistant_message_id: Optional[str] = None,
    pre_computed_tool_message_id: Optional[str] = None,
    llm_batch_item_id: Optional[str] = None,
    step_id: str | None = None,
) -> List[Message]:
    messages = []

@@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response(
    )
    messages.append(heartbeat_system_message)

    for message in messages:
        message.step_id = step_id

    return messages
@@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
@@ -213,6 +214,7 @@ class SyncServer(Server):
self.identity_manager = IdentityManager()
self.group_manager = GroupManager()
self.batch_manager = LLMBatchManager()
self.telemetry_manager = TelemetryManager()

# A resusable httpx client
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
letta/services/helpers/noop_helper.py (new file, 10 lines)
@@ -0,0 +1,10 @@
def singleton(cls):
    """Decorator to make a class a Singleton class."""
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
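Usage sketch for the decorator above (the class name is hypothetical): every call site that "constructs" the decorated class receives the same instance, which is why the Noop managers below can be used as default argument values without creating a new object per agent.

@singleton
class ExampleCache:  # hypothetical class for illustration
    def __init__(self):
        self.entries = {}


a = ExampleCache()
b = ExampleCache()
assert a is b  # both names point at the single shared instance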
@@ -12,6 +12,7 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.tracing import get_trace_id
from letta.utils import enforce_types

@@ -63,6 +64,7 @@ class StepManager:
    usage: UsageStatistics,
    provider_id: Optional[str] = None,
    job_id: Optional[str] = None,
    step_id: Optional[str] = None,
) -> PydanticStep:
    step_data = {
        "origin": None,
@@ -81,6 +83,8 @@ class StepManager:
        "tid": None,
        "trace_id": get_trace_id(),  # Get the current trace ID
    }
    if step_id:
        step_data["id"] = step_id
    with db_registry.session() as session:
        if job_id:
            self._verify_job_access(session, job_id, actor, access=["write"])
@@ -88,6 +92,46 @@ class StepManager:
        new_step.create(session)
        return new_step.to_pydantic()

@enforce_types
async def log_step_async(
    self,
    actor: PydanticUser,
    agent_id: str,
    provider_name: str,
    model: str,
    model_endpoint: Optional[str],
    context_window_limit: int,
    usage: UsageStatistics,
    provider_id: Optional[str] = None,
    job_id: Optional[str] = None,
    step_id: Optional[str] = None,
) -> PydanticStep:
    step_data = {
        "origin": None,
        "organization_id": actor.organization_id,
        "agent_id": agent_id,
        "provider_id": provider_id,
        "provider_name": provider_name,
        "model": model,
        "model_endpoint": model_endpoint,
        "context_window_limit": context_window_limit,
        "completion_tokens": usage.completion_tokens,
        "prompt_tokens": usage.prompt_tokens,
        "total_tokens": usage.total_tokens,
        "job_id": job_id,
        "tags": [],
        "tid": None,
        "trace_id": get_trace_id(),  # Get the current trace ID
    }
    if step_id:
        step_data["id"] = step_id
    async with db_registry.async_session() as session:
        if job_id:
            self._verify_job_access(session, job_id, actor, access=["write"])
        new_step = StepModel(**step_data)
        await new_step.create_async(session)
        return new_step.to_pydantic()

@enforce_types
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
    with db_registry.session() as session:
@@ -147,3 +191,44 @@ class StepManager:
    if not job:
        raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
    return job


@singleton
class NoopStepManager(StepManager):
    """
    Noop implementation of StepManager.
    Temporarily used for migrations, but allows for different implementations in the future.
    Will not allow for writes, but will still allow for reads.
    """

    @enforce_types
    def log_step(
        self,
        actor: PydanticUser,
        agent_id: str,
        provider_name: str,
        model: str,
        model_endpoint: Optional[str],
        context_window_limit: int,
        usage: UsageStatistics,
        provider_id: Optional[str] = None,
        job_id: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> PydanticStep:
        return

    @enforce_types
    async def log_step_async(
        self,
        actor: PydanticUser,
        agent_id: str,
        provider_name: str,
        model: str,
        model_endpoint: Optional[str],
        context_window_limit: int,
        usage: UsageStatistics,
        provider_id: Optional[str] = None,
        job_id: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> PydanticStep:
        return
letta/services/telemetry_manager.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from sqlalchemy import select

from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.utils import enforce_types


class TelemetryManager:
    @enforce_types
    async def get_provider_trace_by_step_id_async(
        self,
        step_id: str,
        actor: PydanticUser,
    ) -> PydanticProviderTrace:
        async with db_registry.async_session() as session:
            provider_trace = await ProviderTraceModel.read_async(db_session=session, step_id=step_id, actor=actor)
            return provider_trace.to_pydantic()

    @enforce_types
    async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        async with db_registry.async_session() as session:
            provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
            await provider_trace.create_async(session, actor=actor)
            return provider_trace.to_pydantic()

    @enforce_types
    def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        with db_registry.session() as session:
            provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
            provider_trace.create(session, actor=actor)
            return provider_trace.to_pydantic()


@singleton
class NoopTelemetryManager(TelemetryManager):
    """
    Noop implementation of TelemetryManager.
    """

    async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        return

    async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
        return

    def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        return
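Round-trip sketch (not part of the diff) using the manager above; the actor object and a prepared ProviderTraceCreate are assumed to exist.

from letta.services.telemetry_manager import TelemetryManager


async def record_and_fetch(actor, provider_trace_create):
    manager = TelemetryManager()

    # Write the trace row for the step referenced by provider_trace_create.step_id.
    created = await manager.create_provider_trace_async(
        actor=actor, provider_trace_create=provider_trace_create
    )

    # Read it back by step id (the same call the /telemetry/{step_id} route makes).
    fetched = await manager.get_provider_trace_by_step_id_async(
        step_id=provider_trace_create.step_id, actor=actor
    )
    return created, fetched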
@@ -190,6 +190,7 @@ class Settings(BaseSettings):
verbose_telemetry_logging: bool = False
otel_exporter_otlp_endpoint: Optional[str] = None  # otel default: "http://localhost:4317"
disable_tracing: bool = False
llm_api_logging: bool = True

# uvicorn settings
uvicorn_workers: int = 1
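The new flag is consulted at wiring time rather than per call; a sketch of the gating pattern used in agent.py and the agents router above:

from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.settings import settings

# When llm_api_logging is disabled, the no-op manager is injected so the
# agent loops skip provider-trace writes without extra branching.
telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()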
@@ -63,6 +63,7 @@ def check_composio_key_set():
    yield


# --- Tool Fixtures ---
@pytest.fixture
def weather_tool_func():
    def get_weather(location: str) -> str:
@@ -110,6 +111,23 @@ def print_tool_func():
    yield print_tool


@pytest.fixture
def roll_dice_tool_func():
    def roll_dice():
        """
        Rolls a 6 sided die.

        Returns:
            str: The roll result.
        """
        import time

        time.sleep(1)
        return "Rolled a 10!"

    yield roll_dice


@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    return BetaMessageBatch(
@@ -6,7 +6,7 @@ from sqlalchemy import delete
from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound
from letta.schemas.agent import CreateAgent
@@ -39,6 +39,7 @@ def org_id(server):

    # cleanup
    with db_registry.session() as session:
        session.execute(delete(ProviderTrace))
        session.execute(delete(Step))
        session.execute(delete(Provider))
        session.commit()
@@ -2,7 +2,7 @@ import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Provider, Step
|
||||
from letta.orm import Provider, ProviderTrace, Step
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.group import (
|
||||
@@ -38,6 +38,7 @@ def org_id(server):
|
||||
|
||||
# cleanup
|
||||
with db_registry.session() as session:
|
||||
session.execute(delete(ProviderTrace))
|
||||
session.execute(delete(Step))
|
||||
session.execute(delete(Provider))
|
||||
session.commit()
|
||||
|
||||
tests/test_provider_trace.py (new file, 205 lines)
@@ -0,0 +1,205 @@
import asyncio
import json
import os
import threading
import time
import uuid

import pytest
from dotenv import load_dotenv
from letta_client import Letta

from letta.agents.letta_agent import LettaAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager


def _run_server():
    """Starts the Letta server in a background thread."""
    load_dotenv()
    from letta.server.rest_api.app import start_server

    start_server(debug=True)


@pytest.fixture(scope="session")
def server_url():
    """Ensures a server is running and returns its base URL."""
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        time.sleep(5)  # Allow server startup time

    return url


# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
    """Creates a REST client for testing."""
    client = Letta(base_url=server_url)
    yield client


@pytest.fixture(scope="session")
def event_loop(request):
    """Create an instance of the default event loop for each test case."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="function")
def roll_dice_tool(client, roll_dice_tool_func):
    print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func)
    yield print_tool


@pytest.fixture(scope="function")
def weather_tool(client, weather_tool_func):
    weather_tool = client.tools.upsert_from_function(func=weather_tool_func)
    yield weather_tool


@pytest.fixture(scope="function")
def print_tool(client, print_tool_func):
    print_tool = client.tools.upsert_from_function(func=print_tool_func)
    yield print_tool


@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool):
    """Creates an agent and ensures cleanup after tests."""
    agent_state = client.agents.create(
        name=f"test_compl_{str(uuid.uuid4())[5:]}",
        tool_ids=[roll_dice_tool.id, weather_tool.id],
        include_base_tools=True,
        memory_blocks=[
            {
                "label": "human",
                "value": "Name: Matt",
            },
            {
                "label": "persona",
                "value": "Friendly agent",
            },
        ],
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )
    yield agent_state
    client.agents.delete(agent_state.id)


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step(message, agent_state, default_user):
    experimental_agent = LettaAgent(
        agent_id=agent_state.id,
        message_manager=MessageManager(),
        agent_manager=AgentManager(),
        block_manager=BlockManager(),
        passage_manager=PassageManager(),
        step_manager=StepManager(),
        telemetry_manager=TelemetryManager(),
        actor=default_user,
    )

    response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
    tool_step = response.messages[0].step_id
    reply_step = response.messages[-1].step_id

    tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
    reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
    assert tool_telemetry.request_json
    assert reply_telemetry.request_json


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop):
    experimental_agent = LettaAgent(
        agent_id=agent_state.id,
        message_manager=MessageManager(),
        agent_manager=AgentManager(),
        block_manager=BlockManager(),
        passage_manager=PassageManager(),
        step_manager=StepManager(),
        telemetry_manager=TelemetryManager(),
        actor=default_user,
    )
    stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])])

    result = StreamingResponseWithStatusCode(
        stream,
        media_type="text/event-stream",
    )

    message_id = None

    async def test_send(message) -> None:
        nonlocal message_id
        if "body" in message and not message_id:
            body = message["body"].decode("utf-8").split("data:")
            message_id = json.loads(body[1])["id"]

    await result.stream_response(send=test_send)

    messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user)
    step_ids = set(message.step_id for message in messages)
    for step_id in step_ids:
        telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user)
        assert telemetry_data.request_json
        assert telemetry_data.response_json


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_step(client, agent_state, default_user, message, event_loop):
    client.agents.messages.create(agent_id=agent_state.id, messages=[])
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=[MessageCreate(role="user", content=[TextContent(text=message)])],
    )
    tool_step = response.messages[0].step_id
    reply_step = response.messages[-1].step_id

    tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
    reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
    assert tool_telemetry.request_json
    assert reply_telemetry.request_json


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_noop_provider_trace(message, agent_state, default_user, event_loop):
    experimental_agent = LettaAgent(
        agent_id=agent_state.id,
        message_manager=MessageManager(),
        agent_manager=AgentManager(),
        block_manager=BlockManager(),
        passage_manager=PassageManager(),
        step_manager=StepManager(),
        telemetry_manager=NoopTelemetryManager(),
        actor=default_user,
    )

    response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
    tool_step = response.messages[0].step_id
    reply_step = response.messages[-1].step_id

    tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
    reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
    assert tool_telemetry is None
    assert reply_telemetry is None
@@ -11,7 +11,7 @@ from sqlalchemy import delete

import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
@@ -286,6 +286,7 @@ def org_id(server):

    # cleanup
    with db_registry.session() as session:
        session.execute(delete(ProviderTrace))
        session.execute(delete(Step))
        session.execute(delete(Provider))
        session.commit()