feat: add step count filtering to internal runs [LET-5417] (#5547)

* feat: add tool_used field to run_metrics [LET-5419]

* change to tool name

* use tool ids over names

* feat: add internal runs route with template_family filtering

* feat: add step count filtering to internal runs [LET-5417]

* remove import

* add auto generated

* add test

* fix snippets
This commit is contained in:
Christina Tong
2025-10-22 11:22:23 -07:00
committed by Caren Thomas
parent fc531ca6de
commit 77c797c752
3 changed files with 159 additions and 4 deletions

View File

@@ -152,13 +152,12 @@ class ToolType(str, Enum):
LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core"
LETTA_BUILTIN = "letta_builtin"
LETTA_FILES_CORE = "letta_files_core"
EXTERNAL_LANGCHAIN = "external_langchain" # DEPRECATED
EXTERNAL_COMPOSIO = "external_composio" # DEPRECATED
EXTERNAL_LANGCHAIN = "external_langchain" # DEPRECATED
EXTERNAL_COMPOSIO = "external_composio" # DEPRECATED
# TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote?
EXTERNAL_MCP = "external_mcp"
class JobType(str, Enum):
    """Discriminator for job-like records: a plain background job or an agent run."""

    JOB = "job"
    RUN = "run"
@@ -222,3 +221,11 @@ class TagMatchMode(str, Enum):
ANY = "any"
ALL = "all"
class ComparisonOperator(str, Enum):
    """Comparison operators for filtering numeric values.

    `compare` centralizes the operator semantics so that callers (e.g. the
    run-manager's step-count filter) do not each need to maintain their own
    if/elif chain over the members.
    """

    EQ = "eq"  # equals
    GTE = "gte"  # greater than or equal
    LTE = "lte"  # less than or equal

    def compare(self, left, right) -> bool:
        """Return ``left <op> right`` for this operator.

        Args:
            left: Left-hand operand (typically an int such as a step count).
            right: Right-hand operand to compare against.

        Returns:
            True if the comparison holds, False otherwise.
        """
        if self is ComparisonOperator.EQ:
            return left == right
        if self is ComparisonOperator.GTE:
            return left >= right
        # Only three members exist, so the remaining case is LTE.
        return left <= right

View File

@@ -16,7 +16,7 @@ from letta.orm.run_metrics import RunMetrics as RunMetricsModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import AgentType, MessageRole, RunStatus
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.letta_response import LettaResponse
@@ -109,6 +109,8 @@ class RunManager:
stop_reason: Optional[str] = None,
background: Optional[bool] = None,
template_family: Optional[str] = None,
step_count: Optional[int] = None,
step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
) -> List[PydanticRun]:
"""List runs with filtering options."""
async with db_registry.async_session() as session:
@@ -138,6 +140,18 @@ class RunManager:
if template_family:
query = query.filter(RunModel.base_template_id == template_family)
# Filter by step_count - join with run_metrics
if step_count is not None:
query = query.join(RunMetricsModel, RunModel.id == RunMetricsModel.id)
# Filter by step_count with the specified operator
if step_count_operator == ComparisonOperator.EQ:
query = query.filter(RunMetricsModel.num_steps == step_count)
elif step_count_operator == ComparisonOperator.GTE:
query = query.filter(RunMetricsModel.num_steps >= step_count)
elif step_count_operator == ComparisonOperator.LTE:
query = query.filter(RunMetricsModel.num_steps <= step_count)
# Apply pagination
from letta.services.helpers.run_manager_helper import _apply_pagination_async

View File

@@ -382,6 +382,140 @@ async def test_list_runs_by_stop_reason(server: SyncServer, sarah_agent, default
assert runs[0].id == run.id
@pytest.mark.asyncio
async def test_list_runs_by_step_count(server: SyncServer, sarah_agent, default_user):
    """Test listing runs filtered by step count."""
    from letta.schemas.enums import ComparisonOperator

    async def make_run(n_steps: int):
        # Create one run and log exactly `n_steps` steps against it.
        run = await server.run_manager.create_run(
            pydantic_run=PydanticRun(
                agent_id=sarah_agent.id,
                metadata={"steps": n_steps},
            ),
            actor=default_user,
        )
        for _ in range(n_steps):
            await server.step_manager.log_step_async(
                agent_id=sarah_agent.id,
                provider_name="openai",
                provider_category="base",
                model="gpt-4o-mini",
                model_endpoint="https://api.openai.com/v1",
                context_window_limit=8192,
                usage=UsageStatistics(
                    completion_tokens=100,
                    prompt_tokens=50,
                    total_tokens=150,
                ),
                run_id=run.id,
                actor=default_user,
                project_id=sarah_agent.project_id,
            )
        return run

    run_0 = await make_run(0)
    run_2 = await make_run(2)
    run_5 = await make_run(5)

    # Completing each run triggers the metrics update that records num_steps.
    for run in (run_0, run_2, run_5):
        await server.run_manager.update_run_by_id_async(
            run.id,
            RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn),
            actor=default_user,
        )

    async def list_by_steps(count: int, op):
        return await server.run_manager.list_runs(
            actor=default_user,
            agent_id=sarah_agent.id,
            step_count=count,
            step_count_operator=op,
        )

    # EQ operator - exact match
    matched = await list_by_steps(2, ComparisonOperator.EQ)
    assert [r.id for r in matched] == [run_2.id]

    # GTE operator - runs with at least 2 steps
    matched = await list_by_steps(2, ComparisonOperator.GTE)
    assert {r.id for r in matched} == {run_2.id, run_5.id}

    # LTE operator - runs with at most 2 steps
    matched = await list_by_steps(2, ComparisonOperator.LTE)
    assert {r.id for r in matched} == {run_0.id, run_2.id}

    # GTE with 0 - should return all three runs
    assert len(await list_by_steps(0, ComparisonOperator.GTE)) == 3

    # LTE with 0 - should return only the zero-step run
    matched = await list_by_steps(0, ComparisonOperator.LTE)
    assert [r.id for r in matched] == [run_0.id]
@pytest.mark.asyncio
async def test_list_runs_by_base_template_id(server: SyncServer, sarah_agent, default_user):
"""Test listing runs by template family."""