chore: sync 0.12.0 version (#3023)
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com> Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
1
.github/workflows/core-unit-sqlite-test.yaml
vendored
1
.github/workflows/core-unit-sqlite-test.yaml
vendored
@@ -28,7 +28,6 @@ jobs:
|
||||
install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra sqlite'
|
||||
timeout-minutes: 15
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
matrix-strategy: |
|
||||
{
|
||||
"fail-fast": false,
|
||||
|
||||
38
.github/workflows/docker-integration-tests.yaml
vendored
38
.github/workflows/docker-integration-tests.yaml
vendored
@@ -10,6 +10,29 @@ jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
env:
|
||||
# Database configuration - these will be used by dev-compose.yaml
|
||||
LETTA_PG_DB: letta
|
||||
LETTA_PG_USER: letta
|
||||
LETTA_PG_PASSWORD: letta
|
||||
LETTA_PG_HOST: pgvector_db # Internal Docker service name
|
||||
LETTA_PG_PORT: 5432
|
||||
# Server configuration for tests
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
LETTA_SERVER_URL: http://localhost:8283
|
||||
# API keys
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
# Additional API keys that dev-compose.yaml expects (optional)
|
||||
GROQ_API_KEY: ""
|
||||
ANTHROPIC_API_KEY: ""
|
||||
OLLAMA_BASE_URL: ""
|
||||
AZURE_API_KEY: ""
|
||||
AZURE_BASE_URL: ""
|
||||
AZURE_API_VERSION: ""
|
||||
GEMINI_API_KEY: ""
|
||||
VLLM_API_BASE: ""
|
||||
OPENLLM_AUTH_TYPE: ""
|
||||
OPENLLM_API_KEY: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -32,14 +55,8 @@ jobs:
|
||||
chmod -R 755 /home/runner/.letta/logs
|
||||
|
||||
- name: Build and run docker dev server
|
||||
env:
|
||||
LETTA_PG_DB: letta
|
||||
LETTA_PG_USER: letta
|
||||
LETTA_PG_PASSWORD: letta
|
||||
LETTA_PG_PORT: 8888
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
run: |
|
||||
# dev-compose.yaml will use the environment variables we set above
|
||||
docker compose -f dev-compose.yaml up --build -d
|
||||
|
||||
- name: Wait for service
|
||||
@@ -47,13 +64,6 @@ jobs:
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
LETTA_PG_DB: letta
|
||||
LETTA_PG_USER: letta
|
||||
LETTA_PG_PASSWORD: letta
|
||||
LETTA_PG_PORT: 8888
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
LETTA_SERVER_URL: http://localhost:8283
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
uv sync --extra dev --extra postgres --extra sqlite
|
||||
|
||||
@@ -15,8 +15,13 @@ letta_config = LettaConfig.load()
|
||||
config = context.config
|
||||
|
||||
if settings.database_engine is DatabaseChoice.POSTGRES:
|
||||
config.set_main_option("sqlalchemy.url", settings.letta_pg_uri)
|
||||
print("Using database: ", settings.letta_pg_uri)
|
||||
# Convert PostgreSQL URI to sync format for alembic using common utility
|
||||
from letta.database_utils import get_database_uri_for_context
|
||||
|
||||
sync_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "alembic")
|
||||
|
||||
config.set_main_option("sqlalchemy.url", sync_pg_uri)
|
||||
print("Using database: ", sync_pg_uri)
|
||||
else:
|
||||
config.set_main_option("sqlalchemy.url", "sqlite:///" + os.path.join(letta_config.recall_storage_path, "sqlite.db"))
|
||||
|
||||
|
||||
55
alembic/versions/c734cfc0d595_add_runs_metrics_table.py
Normal file
55
alembic/versions/c734cfc0d595_add_runs_metrics_table.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""add runs_metrics table
|
||||
|
||||
Revision ID: c734cfc0d595
|
||||
Revises: 038e68cdf0df
|
||||
Create Date: 2025-10-08 14:35:23.302204
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c734cfc0d595"
|
||||
down_revision: Union[str, None] = "038e68cdf0df"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"run_metrics",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("run_start_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("run_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("num_steps", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("project_id", sa.String(), nullable=True),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("base_template_id", sa.String(), nullable=True),
|
||||
sa.Column("template_id", sa.String(), nullable=True),
|
||||
sa.Column("deployment_id", sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["id"], ["runs.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("run_metrics")
|
||||
# ### end Alembic commands ###
|
||||
@@ -31,8 +31,9 @@ services:
|
||||
- LETTA_PG_DB=${LETTA_PG_DB:-letta}
|
||||
- LETTA_PG_USER=${LETTA_PG_USER:-letta}
|
||||
- LETTA_PG_PASSWORD=${LETTA_PG_PASSWORD:-letta}
|
||||
- LETTA_PG_HOST=pgvector_db
|
||||
- LETTA_PG_PORT=5432
|
||||
- LETTA_PG_HOST=${LETTA_PG_HOST:-pgvector_db}
|
||||
- LETTA_PG_PORT=${LETTA_PG_PORT:-5432}
|
||||
- LETTA_PG_URI=${LETTA_PG_URI:-postgresql://${LETTA_PG_USER:-letta}:${LETTA_PG_PASSWORD:-letta}@${LETTA_PG_HOST:-pgvector_db}:${LETTA_PG_PORT:-5432}/${LETTA_PG_DB:-letta}}
|
||||
- LETTA_DEBUG=True
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- GROQ_API_KEY=${GROQ_API_KEY}
|
||||
|
||||
@@ -10698,6 +10698,47 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/runs/{run_id}/metrics": {
|
||||
"get": {
|
||||
"tags": ["runs"],
|
||||
"summary": "Retrieve Metrics For Run",
|
||||
"description": "Get run metrics by run ID.",
|
||||
"operationId": "retrieve_metrics_for_run",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "run_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"title": "Run Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RunMetrics"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/runs/{run_id}/steps": {
|
||||
"get": {
|
||||
"tags": ["runs"],
|
||||
@@ -31576,6 +31617,103 @@
|
||||
"title": "Run",
|
||||
"description": "Representation of a run - a conversation or processing session for an agent.\nRuns track when agents process messages and maintain the relationship between agents, steps, and messages.\n\nParameters:\n id (str): The unique identifier of the run (prefixed with 'run-').\n status (JobStatus): The current status of the run.\n created_at (datetime): The timestamp when the run was created.\n completed_at (datetime): The timestamp when the run was completed.\n agent_id (str): The unique identifier of the agent associated with the run.\n stop_reason (StopReasonType): The reason why the run was stopped.\n background (bool): Whether the run was created in background mode.\n metadata (dict): Additional metadata for the run.\n request_config (LettaRequestConfig): The request configuration for the run."
|
||||
},
|
||||
"RunMetrics": {
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"title": "Id",
|
||||
"description": "The id of the run this metric belongs to (matches runs.id)."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Agent Id",
|
||||
"description": "The unique identifier of the agent."
|
||||
},
|
||||
"project_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Project Id",
|
||||
"description": "The project that the run belongs to (cloud only)."
|
||||
},
|
||||
"run_start_ns": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Run Start Ns",
|
||||
"description": "The timestamp of the start of the run in nanoseconds."
|
||||
},
|
||||
"run_ns": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Run Ns",
|
||||
"description": "Total time for the run in nanoseconds."
|
||||
},
|
||||
"num_steps": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Num Steps",
|
||||
"description": "The number of steps in the run."
|
||||
},
|
||||
"template_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Template Id",
|
||||
"description": "The template ID that the run belongs to (cloud only)."
|
||||
},
|
||||
"base_template_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Base Template Id",
|
||||
"description": "The base template ID that the run belongs to (cloud only)."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["id"],
|
||||
"title": "RunMetrics"
|
||||
},
|
||||
"RunStatus": {
|
||||
"type": "string",
|
||||
"enum": ["created", "running", "completed", "failed", "cancelled"],
|
||||
|
||||
@@ -595,6 +595,27 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# -1. no tool call, no content
|
||||
if tool_call is None and (content is None or len(content) == 0):
|
||||
# Edge case is when there's also no content - basically, the LLM "no-op'd"
|
||||
# If RequiredBeforeExitToolRule exists and not all required tools have been called,
|
||||
# inject a rule-violation heartbeat to keep looping and inform the model.
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||
if uncalled:
|
||||
# TODO: we may need to change this to not have a "heartbeat" prefix for v3?
|
||||
heartbeat_reason = (
|
||||
f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop."
|
||||
)
|
||||
from letta.server.rest_api.utils import create_heartbeat_system_message
|
||||
|
||||
heartbeat_msg = create_heartbeat_system_message(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_call_success=True,
|
||||
timezone=agent_state.timezone,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
run_id=run_id,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + [heartbeat_msg]
|
||||
continue_stepping, stop_reason = True, None
|
||||
else:
|
||||
# In this case, we actually do not want to persist the no-op message
|
||||
continue_stepping, heartbeat_reason, stop_reason = False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
messages_to_persist = initial_messages or []
|
||||
@@ -627,7 +648,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
run_id=run_id,
|
||||
is_approval_response=is_approval or is_denial,
|
||||
force_set_request_heartbeat=False,
|
||||
add_heartbeat_on_continue=False,
|
||||
# If we're continuing due to a required-before-exit rule, include a heartbeat to guide the model
|
||||
add_heartbeat_on_continue=bool(heartbeat_reason),
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + assistant_message
|
||||
|
||||
@@ -843,7 +865,13 @@ class LettaAgentV3(LettaAgentV2):
|
||||
stop_reason: LettaStopReason | None = None
|
||||
|
||||
if tool_call_name is None:
|
||||
# No tool call? End loop
|
||||
# No tool call – if there are required-before-exit tools uncalled, keep stepping
|
||||
# and provide explicit feedback to the model; otherwise end the loop.
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||
if uncalled and not is_final_step:
|
||||
reason = f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop."
|
||||
return True, reason, None
|
||||
# No required tools remaining → end turn
|
||||
return False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
else:
|
||||
if tool_rule_violated:
|
||||
|
||||
161
letta/database_utils.py
Normal file
161
letta/database_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Database URI utilities for consistent database connection handling across the application.
|
||||
|
||||
This module provides utilities for parsing and converting database URIs to ensure
|
||||
consistent behavior between the main application, alembic migrations, and other
|
||||
database-related components.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
|
||||
def parse_database_uri(uri: str) -> dict[str, Optional[str]]:
|
||||
"""
|
||||
Parse a database URI into its components.
|
||||
|
||||
Args:
|
||||
uri: Database URI (e.g., postgresql://user:pass@host:port/db)
|
||||
|
||||
Returns:
|
||||
Dictionary with parsed components: scheme, driver, user, password, host, port, database
|
||||
"""
|
||||
parsed = urlparse(uri)
|
||||
|
||||
# Extract driver from scheme (e.g., postgresql+asyncpg -> asyncpg)
|
||||
scheme_parts = parsed.scheme.split("+")
|
||||
base_scheme = scheme_parts[0] if scheme_parts else ""
|
||||
driver = scheme_parts[1] if len(scheme_parts) > 1 else None
|
||||
|
||||
return {
|
||||
"scheme": base_scheme,
|
||||
"driver": driver,
|
||||
"user": parsed.username,
|
||||
"password": parsed.password,
|
||||
"host": parsed.hostname,
|
||||
"port": str(parsed.port) if parsed.port else None,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"query": parsed.query,
|
||||
"fragment": parsed.fragment,
|
||||
}
|
||||
|
||||
|
||||
def build_database_uri(
|
||||
scheme: str = "postgresql",
|
||||
driver: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
database: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
fragment: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a database URI from components.
|
||||
|
||||
Args:
|
||||
scheme: Base scheme (e.g., "postgresql")
|
||||
driver: Driver name (e.g., "asyncpg", "pg8000")
|
||||
user: Username
|
||||
password: Password
|
||||
host: Hostname
|
||||
port: Port number
|
||||
database: Database name
|
||||
query: Query string
|
||||
fragment: Fragment
|
||||
|
||||
Returns:
|
||||
Complete database URI
|
||||
"""
|
||||
# Combine scheme and driver
|
||||
full_scheme = f"{scheme}+{driver}" if driver else scheme
|
||||
|
||||
# Build netloc (user:password@host:port)
|
||||
netloc_parts = []
|
||||
if user:
|
||||
if password:
|
||||
netloc_parts.append(f"{user}:{password}")
|
||||
else:
|
||||
netloc_parts.append(user)
|
||||
|
||||
if host:
|
||||
if port:
|
||||
netloc_parts.append(f"{host}:{port}")
|
||||
else:
|
||||
netloc_parts.append(host)
|
||||
|
||||
netloc = "@".join(netloc_parts) if netloc_parts else ""
|
||||
|
||||
# Build path
|
||||
path = f"/{database}" if database else ""
|
||||
|
||||
# Build the URI
|
||||
return urlunparse((full_scheme, netloc, path, "", query or "", fragment or ""))
|
||||
|
||||
|
||||
def convert_to_async_uri(uri: str) -> str:
|
||||
"""
|
||||
Convert a database URI to use the asyncpg driver for async operations.
|
||||
|
||||
Args:
|
||||
uri: Original database URI
|
||||
|
||||
Returns:
|
||||
URI with asyncpg driver and ssl parameter adjustments
|
||||
"""
|
||||
components = parse_database_uri(uri)
|
||||
|
||||
# Convert to asyncpg driver
|
||||
components["driver"] = "asyncpg"
|
||||
|
||||
# Build the new URI
|
||||
new_uri = build_database_uri(**components)
|
||||
|
||||
# Replace sslmode= with ssl= for asyncpg compatibility
|
||||
new_uri = new_uri.replace("sslmode=", "ssl=")
|
||||
|
||||
return new_uri
|
||||
|
||||
|
||||
def convert_to_sync_uri(uri: str) -> str:
|
||||
"""
|
||||
Convert a database URI to use the pg8000 driver for sync operations (alembic).
|
||||
|
||||
Args:
|
||||
uri: Original database URI
|
||||
|
||||
Returns:
|
||||
URI with pg8000 driver and sslmode parameter adjustments
|
||||
"""
|
||||
components = parse_database_uri(uri)
|
||||
|
||||
# Convert to pg8000 driver
|
||||
components["driver"] = "pg8000"
|
||||
|
||||
# Build the new URI
|
||||
new_uri = build_database_uri(**components)
|
||||
|
||||
# Replace ssl= with sslmode= for pg8000 compatibility
|
||||
new_uri = new_uri.replace("ssl=", "sslmode=")
|
||||
|
||||
return new_uri
|
||||
|
||||
|
||||
def get_database_uri_for_context(uri: str, context: str = "async") -> str:
|
||||
"""
|
||||
Get the appropriate database URI for a specific context.
|
||||
|
||||
Args:
|
||||
uri: Original database URI
|
||||
context: Context type ("async" for asyncpg, "sync" for pg8000, "alembic" for pg8000)
|
||||
|
||||
Returns:
|
||||
URI formatted for the specified context
|
||||
"""
|
||||
if context in ["async"]:
|
||||
return convert_to_async_uri(uri)
|
||||
elif context in ["sync", "alembic"]:
|
||||
return convert_to_sync_uri(uri)
|
||||
else:
|
||||
raise ValueError(f"Unknown context: {context}. Must be 'async', 'sync', or 'alembic'")
|
||||
@@ -325,6 +325,7 @@ class AnthropicClient(LLMClientBase):
|
||||
data["system"] = self._add_cache_control_to_system_message(system_content)
|
||||
data["messages"] = PydanticMessage.to_anthropic_dicts_from_list(
|
||||
messages=messages[1:],
|
||||
current_model=llm_config.model,
|
||||
inner_thoughts_xml_tag=inner_thoughts_xml_tag,
|
||||
put_inner_thoughts_in_kwargs=put_kwargs,
|
||||
# if react, use native content + strip heartbeats
|
||||
|
||||
@@ -311,6 +311,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
contents = self.add_dummy_model_messages(
|
||||
PydanticMessage.to_google_dicts_from_list(
|
||||
messages,
|
||||
current_model=llm_config.model,
|
||||
put_inner_thoughts_in_kwargs=False if agent_type == AgentType.letta_v1_agent else True,
|
||||
native_content=True if agent_type == AgentType.letta_v1_agent else False,
|
||||
),
|
||||
|
||||
@@ -27,6 +27,7 @@ from letta.orm.prompt import Prompt
|
||||
from letta.orm.provider import Provider
|
||||
from letta.orm.provider_trace import ProviderTrace
|
||||
from letta.orm.run import Run
|
||||
from letta.orm.run_metrics import RunMetrics
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
|
||||
82
letta/orm/run_metrics.py
Normal file
82
letta/orm/run_metrics.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import BigInteger, ForeignKey, Integer, String
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin, ProjectMixin, TemplateMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.run import Run
|
||||
from letta.orm.step import Step
|
||||
|
||||
|
||||
class RunMetrics(SqlalchemyBase, ProjectMixin, AgentMixin, OrganizationMixin, TemplateMixin):
|
||||
"""Tracks performance metrics for agent steps."""
|
||||
|
||||
__tablename__ = "run_metrics"
|
||||
__pydantic_model__ = PydanticRunMetrics
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
ForeignKey("runs.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
doc="The unique identifier of the run this metric belongs to (also serves as PK)",
|
||||
)
|
||||
run_start_ns: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
doc="The timestamp of the start of the run in nanoseconds",
|
||||
)
|
||||
run_ns: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
doc="Total time for the run in nanoseconds",
|
||||
)
|
||||
num_steps: Mapped[Optional[int]] = mapped_column(
|
||||
Integer,
|
||||
nullable=True,
|
||||
doc="The number of steps in the run",
|
||||
)
|
||||
run: Mapped[Optional["Run"]] = relationship("Run", foreign_keys=[id])
|
||||
agent: Mapped[Optional["Agent"]] = relationship("Agent")
|
||||
|
||||
def create(
|
||||
self,
|
||||
db_session: Session,
|
||||
actor: Optional[User] = None,
|
||||
no_commit: bool = False,
|
||||
) -> "RunMetrics":
|
||||
"""Override create to handle SQLite timestamp issues"""
|
||||
# For SQLite, explicitly set timestamps as server_default may not work
|
||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||
now = datetime.now(timezone.utc)
|
||||
if not self.created_at:
|
||||
self.created_at = now
|
||||
if not self.updated_at:
|
||||
self.updated_at = now
|
||||
|
||||
return super().create(db_session, actor=actor, no_commit=no_commit)
|
||||
|
||||
async def create_async(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
actor: Optional[User] = None,
|
||||
no_commit: bool = False,
|
||||
no_refresh: bool = False,
|
||||
) -> "RunMetrics":
|
||||
"""Override create_async to handle SQLite timestamp issues"""
|
||||
# For SQLite, explicitly set timestamps as server_default may not work
|
||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||
now = datetime.now(timezone.utc)
|
||||
if not self.created_at:
|
||||
self.created_at = now
|
||||
if not self.updated_at:
|
||||
self.updated_at = now
|
||||
|
||||
return await super().create_async(db_session, actor=actor, no_commit=no_commit, no_refresh=no_refresh)
|
||||
@@ -965,7 +965,13 @@ class Message(BaseMessage):
|
||||
}
|
||||
|
||||
elif self.role == "assistant" or self.role == "approval":
|
||||
try:
|
||||
assert self.tool_calls is not None or text_content is not None, vars(self)
|
||||
except AssertionError as e:
|
||||
# relax check if this message only contains reasoning content
|
||||
if self.content is not None and len(self.content) > 0 and isinstance(self.content[0], ReasoningContent):
|
||||
return None
|
||||
raise e
|
||||
|
||||
# if native content, then put it directly inside the content
|
||||
if native_content:
|
||||
@@ -1040,6 +1046,7 @@ class Message(BaseMessage):
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
use_developer_message: bool = False,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = [
|
||||
m.to_openai_dict(
|
||||
max_tool_id_length=max_tool_id_length,
|
||||
@@ -1149,6 +1156,7 @@ class Message(BaseMessage):
|
||||
messages: List[Message],
|
||||
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = []
|
||||
for message in messages:
|
||||
result.extend(message.to_openai_responses_dicts(max_tool_id_length=max_tool_id_length))
|
||||
@@ -1156,6 +1164,7 @@ class Message(BaseMessage):
|
||||
|
||||
def to_anthropic_dict(
|
||||
self,
|
||||
current_model: str,
|
||||
inner_thoughts_xml_tag="thinking",
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
@@ -1242,6 +1251,7 @@ class Message(BaseMessage):
|
||||
for content_part in self.content:
|
||||
# TextContent, ImageContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent
|
||||
if isinstance(content_part, ReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
@@ -1250,6 +1260,7 @@ class Message(BaseMessage):
|
||||
}
|
||||
)
|
||||
elif isinstance(content_part, RedactedReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
{
|
||||
"type": "redacted_thinking",
|
||||
@@ -1272,6 +1283,7 @@ class Message(BaseMessage):
|
||||
if self.content is not None and len(self.content) >= 1:
|
||||
for content_part in self.content:
|
||||
if isinstance(content_part, ReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
@@ -1280,6 +1292,7 @@ class Message(BaseMessage):
|
||||
}
|
||||
)
|
||||
if isinstance(content_part, RedactedReasoningContent):
|
||||
if current_model == self.model:
|
||||
content.append(
|
||||
{
|
||||
"type": "redacted_thinking",
|
||||
@@ -1349,14 +1362,17 @@ class Message(BaseMessage):
|
||||
@staticmethod
|
||||
def to_anthropic_dicts_from_list(
|
||||
messages: List[Message],
|
||||
current_model: str,
|
||||
inner_thoughts_xml_tag: str = "thinking",
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
native_content: bool = False,
|
||||
strip_request_heartbeat: bool = False,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = [
|
||||
m.to_anthropic_dict(
|
||||
current_model=current_model,
|
||||
inner_thoughts_xml_tag=inner_thoughts_xml_tag,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
native_content=native_content,
|
||||
@@ -1369,6 +1385,7 @@ class Message(BaseMessage):
|
||||
|
||||
def to_google_dict(
|
||||
self,
|
||||
current_model: str,
|
||||
put_inner_thoughts_in_kwargs: bool = True,
|
||||
# if true, then treat the content field as AssistantMessage
|
||||
native_content: bool = False,
|
||||
@@ -1484,10 +1501,11 @@ class Message(BaseMessage):
|
||||
for content in self.content:
|
||||
if isinstance(content, TextContent):
|
||||
native_part = {"text": content.text}
|
||||
if content.signature:
|
||||
if content.signature and current_model == self.model:
|
||||
native_part["thought_signature"] = content.signature
|
||||
native_google_content_parts.append(native_part)
|
||||
elif isinstance(content, ReasoningContent):
|
||||
if current_model == self.model:
|
||||
native_google_content_parts.append({"text": content.reasoning, "thought": True})
|
||||
elif isinstance(content, ToolCallContent):
|
||||
native_part = {
|
||||
@@ -1496,7 +1514,7 @@ class Message(BaseMessage):
|
||||
"args": content.input,
|
||||
},
|
||||
}
|
||||
if content.signature:
|
||||
if content.signature and current_model == self.model:
|
||||
native_part["thought_signature"] = content.signature
|
||||
native_google_content_parts.append(native_part)
|
||||
else:
|
||||
@@ -1554,11 +1572,14 @@ class Message(BaseMessage):
|
||||
@staticmethod
|
||||
def to_google_dicts_from_list(
|
||||
messages: List[Message],
|
||||
current_model: str,
|
||||
put_inner_thoughts_in_kwargs: bool = True,
|
||||
native_content: bool = False,
|
||||
):
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = [
|
||||
m.to_google_dict(
|
||||
current_model=current_model,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
native_content=native_content,
|
||||
)
|
||||
@@ -1567,6 +1588,45 @@ class Message(BaseMessage):
|
||||
result = [m for m in result if m is not None]
|
||||
return result
|
||||
|
||||
def is_approval_request(self) -> bool:
|
||||
return self.role == "approval" and self.tool_calls is not None and len(self.tool_calls) > 0
|
||||
|
||||
def is_approval_response(self) -> bool:
|
||||
return self.role == "approval" and self.tool_calls is None and self.approve is not None
|
||||
|
||||
def is_summarization_message(self) -> bool:
|
||||
return (
|
||||
self.role == "user"
|
||||
and self.content is not None
|
||||
and len(self.content) == 1
|
||||
and isinstance(self.content[0], TextContent)
|
||||
and "system_alert" in self.content[0].text
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_messages_for_llm_api(
|
||||
messages: List[Message],
|
||||
) -> List[Message]:
|
||||
messages = [m for m in messages if m is not None]
|
||||
if len(messages) == 0:
|
||||
return []
|
||||
# Add special handling for legacy bug where summarization triggers in the middle of hitl
|
||||
messages_to_filter = []
|
||||
for i in range(len(messages) - 1):
|
||||
first_message_is_approval = messages[i].is_approval_request()
|
||||
second_message_is_summary = messages[i + 1].is_summarization_message()
|
||||
third_message_is_optional_approval = i + 2 >= len(messages) or messages[i + 2].is_approval_response()
|
||||
if first_message_is_approval and second_message_is_summary and third_message_is_optional_approval:
|
||||
messages_to_filter.append(messages[i])
|
||||
for idx in reversed(messages_to_filter): # reverse to avoid index shift
|
||||
messages.remove(idx)
|
||||
|
||||
# Filter last message if it is a lone approval request without a response - this only occurs for token counting
|
||||
if messages[-1].role == "approval" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0:
|
||||
messages.remove(messages[-1])
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def generate_otid_from_id(message_id: str, index: int) -> str:
|
||||
"""
|
||||
|
||||
21
letta/schemas/run_metrics.py
Normal file
21
letta/schemas/run_metrics.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class RunMetricsBase(LettaBase):
|
||||
__id_prefix__ = "run"
|
||||
|
||||
|
||||
class RunMetrics(RunMetricsBase):
|
||||
id: str = Field(..., description="The id of the run this metric belongs to (matches runs.id).")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
project_id: Optional[str] = Field(None, description="The project that the run belongs to (cloud only).")
|
||||
run_start_ns: Optional[int] = Field(None, description="The timestamp of the start of the run in nanoseconds.")
|
||||
run_ns: Optional[int] = Field(None, description="Total time for the run in nanoseconds.")
|
||||
num_steps: Optional[int] = Field(None, description="The number of steps in the run.")
|
||||
template_id: Optional[str] = Field(None, description="The template ID that the run belongs to (cloud only).")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template ID that the run belongs to (cloud only).")
|
||||
@@ -10,18 +10,11 @@ from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from letta.database_utils import get_database_uri_for_context
|
||||
from letta.settings import settings
|
||||
|
||||
# Convert PostgreSQL URI to async format
|
||||
pg_uri = settings.letta_pg_uri
|
||||
if pg_uri.startswith("postgresql://"):
|
||||
async_pg_uri = pg_uri.replace("postgresql://", "postgresql+asyncpg://")
|
||||
else:
|
||||
# Handle other URI formats (e.g., postgresql+pg8000://)
|
||||
async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri
|
||||
|
||||
# Replace sslmode with ssl for asyncpg
|
||||
async_pg_uri = async_pg_uri.replace("sslmode=", "ssl=")
|
||||
# Convert PostgreSQL URI to async format using common utility
|
||||
async_pg_uri = get_database_uri_for_context(settings.letta_pg_uri, "async")
|
||||
|
||||
# Build engine configuration based on settings
|
||||
engine_args = {
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.letta_request import RetrieveStreamRequest
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.run_metrics import RunMetrics
|
||||
from letta.schemas.step import Step
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||
@@ -224,6 +225,23 @@ async def retrieve_run_usage(
|
||||
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
||||
|
||||
|
||||
@router.get("/{run_id}/metrics", response_model=RunMetrics, operation_id="retrieve_metrics_for_run")
|
||||
async def retrieve_metrics_for_run(
|
||||
run_id: str,
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get run metrics by run ID.
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
runs_manager = RunManager()
|
||||
return await runs_manager.get_run_metrics_async(run_id=run_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run metrics not found")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{run_id}/steps",
|
||||
response_model=List[Step],
|
||||
@@ -247,8 +265,7 @@ async def list_run_steps(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
runs_manager = RunManager()
|
||||
|
||||
try:
|
||||
steps = await runs_manager.get_run_steps(
|
||||
return await runs_manager.get_run_steps(
|
||||
run_id=run_id,
|
||||
actor=actor,
|
||||
limit=limit,
|
||||
@@ -256,9 +273,6 @@ async def list_run_steps(
|
||||
after=after,
|
||||
ascending=(order == "asc"),
|
||||
)
|
||||
return steps
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
|
||||
@@ -272,12 +286,7 @@ async def delete_run(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
runs_manager = RunManager()
|
||||
|
||||
try:
|
||||
run = await runs_manager.delete_run_by_id(run_id=run_id, actor=actor)
|
||||
return run
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
return await runs_manager.delete_run_by_id(run_id=run_id, actor=actor)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -74,7 +74,7 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
return await self.client.count_tokens(model=self.model, tools=tools)
|
||||
|
||||
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
return Message.to_anthropic_dicts_from_list(messages)
|
||||
return Message.to_anthropic_dicts_from_list(messages, current_model=self.model)
|
||||
|
||||
|
||||
class TiktokenCounter(TokenCounter):
|
||||
|
||||
@@ -2,14 +2,10 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import asc, desc, nulls_last, select
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
from letta.orm.run import Run as RunModel
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from sqlalchemy import asc, desc
|
||||
from typing import Optional
|
||||
|
||||
from letta.services.helpers.agent_manager_helper import _cursor_filter
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
|
||||
async def _apply_pagination_async(
|
||||
@@ -29,17 +25,11 @@ async def _apply_pagination_async(
|
||||
sort_nulls_last = False
|
||||
|
||||
if after:
|
||||
result = (
|
||||
await session.execute(
|
||||
select(sort_column, RunModel.id).where(RunModel.id == after)
|
||||
)
|
||||
).first()
|
||||
result = (await session.execute(select(sort_column, RunModel.id).where(RunModel.id == after))).first()
|
||||
if result:
|
||||
after_sort_value, after_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(
|
||||
after_sort_value, datetime
|
||||
):
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(after_sort_value, datetime):
|
||||
after_sort_value = after_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
query = query.where(
|
||||
_cursor_filter(
|
||||
@@ -53,17 +43,11 @@ async def _apply_pagination_async(
|
||||
)
|
||||
|
||||
if before:
|
||||
result = (
|
||||
await session.execute(
|
||||
select(sort_column, RunModel.id).where(RunModel.id == before)
|
||||
)
|
||||
).first()
|
||||
result = (await session.execute(select(sort_column, RunModel.id).where(RunModel.id == before))).first()
|
||||
if result:
|
||||
before_sort_value, before_id = result
|
||||
# SQLite does not support as granular timestamping, so we need to round the timestamp
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(
|
||||
before_sort_value, datetime
|
||||
):
|
||||
if settings.database_engine is DatabaseChoice.SQLITE and isinstance(before_sort_value, datetime):
|
||||
before_sort_value = before_sort_value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
query = query.where(
|
||||
_cursor_filter(
|
||||
|
||||
@@ -8,9 +8,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.orm.run import Run as RunModel
|
||||
from letta.orm.run_metrics import RunMetrics as RunMetricsModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
@@ -21,6 +23,7 @@ from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -62,6 +65,23 @@ class RunManager:
|
||||
run = RunModel(**run_data)
|
||||
run.organization_id = organization_id
|
||||
run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
|
||||
# Create run metrics with start timestamp
|
||||
import time
|
||||
|
||||
# Get the project_id from the agent
|
||||
agent = await session.get(AgentModel, agent_id)
|
||||
project_id = agent.project_id if agent else None
|
||||
|
||||
metrics = RunMetricsModel(
|
||||
id=run.id,
|
||||
organization_id=organization_id,
|
||||
agent_id=agent_id,
|
||||
project_id=project_id,
|
||||
run_start_ns=int(time.time() * 1e9), # Current time in nanoseconds
|
||||
num_steps=0, # Initialize to 0
|
||||
)
|
||||
await metrics.create_async(session)
|
||||
await session.commit()
|
||||
|
||||
return run.to_pydantic()
|
||||
@@ -178,6 +198,21 @@ class RunManager:
|
||||
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
final_metadata = run.metadata_
|
||||
pydantic_run = run.to_pydantic()
|
||||
|
||||
await session.commit()
|
||||
|
||||
# update run metrics table
|
||||
num_steps = len(await self.step_manager.list_steps_async(run_id=run_id, actor=actor))
|
||||
async with db_registry.async_session() as session:
|
||||
metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
||||
# Calculate runtime if run is completing
|
||||
if is_terminal_update and metrics.run_start_ns:
|
||||
import time
|
||||
|
||||
current_ns = int(time.time() * 1e9)
|
||||
metrics.run_ns = current_ns - metrics.run_start_ns
|
||||
metrics.num_steps = num_steps
|
||||
await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
|
||||
# Dispatch callback outside of database session if needed
|
||||
@@ -299,3 +334,31 @@ class RunManager:
|
||||
raise NoResultFound(f"Run with id {run_id} not found")
|
||||
pydantic_run = run.to_pydantic()
|
||||
return pydantic_run.request_config
|
||||
|
||||
@enforce_types
|
||||
async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
|
||||
"""Get metrics for a run."""
|
||||
async with db_registry.async_session() as session:
|
||||
metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
||||
return metrics.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_run_steps(
|
||||
self,
|
||||
run_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: Optional[int] = 100,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
ascending: bool = False,
|
||||
) -> List[PydanticStep]:
|
||||
"""Get steps for a run."""
|
||||
async with db_registry.async_session() as session:
|
||||
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
|
||||
if not run:
|
||||
raise NoResultFound(f"Run with id {run_id} not found")
|
||||
|
||||
steps = await self.step_manager.list_steps_async(
|
||||
actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
|
||||
)
|
||||
return steps
|
||||
|
||||
@@ -248,7 +248,11 @@ def unpack_message(packed_message: str) -> str:
|
||||
warnings.warn(f"Was unable to find 'message' field in packed message object: '{packed_message}'")
|
||||
return packed_message
|
||||
else:
|
||||
try:
|
||||
message_type = message_json["type"]
|
||||
except:
|
||||
return packed_message
|
||||
|
||||
if message_type != "user_message":
|
||||
warnings.warn(f"Expected type to be 'user_message', but was '{message_type}', so not unpacking: '{packed_message}'")
|
||||
return packed_message
|
||||
|
||||
@@ -254,18 +254,33 @@ def test_send_message_after_turning_off_requires_approval(
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) == 6 or len(messages) == 8 or len(messages) == 9
|
||||
assert messages[0].message_type == "reasoning_message"
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "tool_call_message"
|
||||
assert messages[3].message_type == "tool_return_message"
|
||||
if len(messages) == 8:
|
||||
assert messages[4].message_type == "reasoning_message"
|
||||
assert messages[5].message_type == "assistant_message"
|
||||
elif len(messages) == 9:
|
||||
assert messages[4].message_type == "reasoning_message"
|
||||
assert messages[5].message_type == "tool_call_message"
|
||||
assert messages[6].message_type == "tool_return_message"
|
||||
assert 6 <= len(messages) <= 9
|
||||
idx = 0
|
||||
|
||||
assert messages[idx].message_type == "reasoning_message"
|
||||
idx += 1
|
||||
|
||||
try:
|
||||
assert messages[idx].message_type == "assistant_message"
|
||||
idx += 1
|
||||
except:
|
||||
pass
|
||||
|
||||
assert messages[idx].message_type == "tool_call_message"
|
||||
idx += 1
|
||||
assert messages[idx].message_type == "tool_return_message"
|
||||
idx += 1
|
||||
|
||||
assert messages[idx].message_type == "reasoning_message"
|
||||
idx += 1
|
||||
try:
|
||||
assert messages[idx].message_type == "assistant_message"
|
||||
idx += 1
|
||||
except:
|
||||
assert messages[idx].message_type == "tool_call_message"
|
||||
idx += 1
|
||||
assert messages[idx].message_type == "tool_return_message"
|
||||
idx += 1
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
@@ -1074,6 +1074,243 @@ async def test_get_run_request_config_nonexistent_run(server: SyncServer, defaul
|
||||
await server.run_manager.get_run_request_config("nonexistent_run", actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# RunManager Tests - Run Metrics
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_creation(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that run metrics are created when a run is created."""
|
||||
# Create a run
|
||||
run_data = PydanticRun(
|
||||
metadata={"type": "test_metrics"},
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user)
|
||||
|
||||
# Get the run metrics
|
||||
metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert metrics is not None
|
||||
assert metrics.id == created_run.id
|
||||
assert metrics.agent_id == sarah_agent.id
|
||||
assert metrics.organization_id == default_user.organization_id
|
||||
# project_id may be None or set from the agent
|
||||
assert metrics.run_start_ns is not None
|
||||
assert metrics.run_start_ns > 0
|
||||
assert metrics.run_ns is None # Should be None until run completes
|
||||
assert metrics.num_steps is not None
|
||||
assert metrics.num_steps == 0 # Should be 0 initially
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_timestamp_tracking(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that run_start_ns is properly tracked."""
|
||||
import time
|
||||
|
||||
# Record time before creation
|
||||
before_ns = int(time.time() * 1e9)
|
||||
|
||||
# Create a run
|
||||
run_data = PydanticRun(
|
||||
metadata={"type": "test_timestamp"},
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user)
|
||||
|
||||
# Record time after creation
|
||||
after_ns = int(time.time() * 1e9)
|
||||
|
||||
# Get the run metrics
|
||||
metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Verify timestamp is within expected range
|
||||
assert metrics.run_start_ns is not None
|
||||
assert before_ns <= metrics.run_start_ns <= after_ns, f"Expected {before_ns} <= {metrics.run_start_ns} <= {after_ns}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_duration_calculation(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that run duration (run_ns) is calculated when run completes."""
|
||||
import asyncio
|
||||
|
||||
# Create a run
|
||||
run_data = PydanticRun(
|
||||
metadata={"type": "test_duration"},
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user)
|
||||
|
||||
# Get initial metrics
|
||||
initial_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
assert initial_metrics.run_ns is None # Should be None initially
|
||||
assert initial_metrics.run_start_ns is not None
|
||||
|
||||
# Wait a bit to ensure there's measurable duration
|
||||
await asyncio.sleep(0.1) # Wait 100ms
|
||||
|
||||
# Update the run to completed
|
||||
updated_run = await server.run_manager.update_run_by_id_async(
|
||||
created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user
|
||||
)
|
||||
|
||||
# Get updated metrics
|
||||
final_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert final_metrics.run_ns is not None
|
||||
assert final_metrics.run_ns > 0
|
||||
# Duration should be at least 100ms (100_000_000 nanoseconds)
|
||||
assert final_metrics.run_ns >= 100_000_000, f"Expected run_ns >= 100_000_000, got {final_metrics.run_ns}"
|
||||
# Duration should be reasonable (less than 10 seconds)
|
||||
assert final_metrics.run_ns < 10_000_000_000, f"Expected run_ns < 10_000_000_000, got {final_metrics.run_ns}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_num_steps_tracking(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that num_steps is properly tracked in run metrics."""
|
||||
# Create a run
|
||||
run_data = PydanticRun(
|
||||
metadata={"type": "test_num_steps"},
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user)
|
||||
|
||||
# Initial metrics should have 0 steps
|
||||
initial_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
assert initial_metrics.num_steps == 0
|
||||
|
||||
# Add some steps
|
||||
for i in range(3):
|
||||
await server.step_manager.log_step_async(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
provider_category="base",
|
||||
model="gpt-4o-mini",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
context_window_limit=8192,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=100 + i * 10,
|
||||
prompt_tokens=50 + i * 5,
|
||||
total_tokens=150 + i * 15,
|
||||
),
|
||||
run_id=created_run.id,
|
||||
actor=default_user,
|
||||
project_id=sarah_agent.project_id,
|
||||
)
|
||||
|
||||
# Update the run to trigger metrics update
|
||||
await server.run_manager.update_run_by_id_async(
|
||||
created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user
|
||||
)
|
||||
|
||||
# Get updated metrics
|
||||
final_metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Verify num_steps was updated
|
||||
assert final_metrics.num_steps == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_not_found(server: SyncServer, default_user):
|
||||
"""Test getting metrics for non-existent run."""
|
||||
with pytest.raises(NoResultFound):
|
||||
await server.run_manager.get_run_metrics_async(run_id="nonexistent_run", actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_partial_update(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that non-terminal updates don't calculate run_ns."""
|
||||
# Create a run
|
||||
run_data = PydanticRun(
|
||||
metadata={"type": "test_partial"},
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user)
|
||||
|
||||
# Add a step
|
||||
await server.step_manager.log_step_async(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
provider_category="base",
|
||||
model="gpt-4o-mini",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
context_window_limit=8192,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=100,
|
||||
prompt_tokens=50,
|
||||
total_tokens=150,
|
||||
),
|
||||
run_id=created_run.id,
|
||||
actor=default_user,
|
||||
project_id=sarah_agent.project_id,
|
||||
)
|
||||
|
||||
# Update to running (non-terminal)
|
||||
await server.run_manager.update_run_by_id_async(created_run.id, RunUpdate(status=RunStatus.running), actor=default_user)
|
||||
|
||||
# Get metrics
|
||||
metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Verify run_ns is still None (not calculated for non-terminal updates)
|
||||
assert metrics.run_ns is None
|
||||
# But num_steps should be updated
|
||||
assert metrics.num_steps == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_integration_with_run_steps(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test integration between run metrics and run steps."""
|
||||
# Create a run
|
||||
run_data = PydanticRun(
|
||||
metadata={"type": "test_integration"},
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
created_run = await server.run_manager.create_run(pydantic_run=run_data, actor=default_user)
|
||||
|
||||
# Add multiple steps
|
||||
step_ids = []
|
||||
for i in range(5):
|
||||
step = await server.step_manager.log_step_async(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
provider_category="base",
|
||||
model="gpt-4o-mini",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
context_window_limit=8192,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=100,
|
||||
prompt_tokens=50,
|
||||
total_tokens=150,
|
||||
),
|
||||
run_id=created_run.id,
|
||||
actor=default_user,
|
||||
project_id=sarah_agent.project_id,
|
||||
)
|
||||
step_ids.append(step.id)
|
||||
|
||||
# Get run steps
|
||||
run_steps = await server.run_manager.get_run_steps(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Verify steps are returned correctly
|
||||
assert len(run_steps) == 5
|
||||
assert all(step.run_id == created_run.id for step in run_steps)
|
||||
|
||||
# Update run to completed
|
||||
await server.run_manager.update_run_by_id_async(
|
||||
created_run.id, RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), actor=default_user
|
||||
)
|
||||
|
||||
# Get final metrics
|
||||
metrics = await server.run_manager.get_run_metrics_async(run_id=created_run.id, actor=default_user)
|
||||
|
||||
# Verify metrics reflect the steps
|
||||
assert metrics.num_steps == 5
|
||||
assert metrics.run_ns is not None
|
||||
|
||||
|
||||
# TODO: add back once metrics are added
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user