From 77c797c752d9cd4fb4162237ceb188e6109d16fe Mon Sep 17 00:00:00 2001 From: Christina Tong Date: Wed, 22 Oct 2025 11:22:23 -0700 Subject: [PATCH] feat: add step count filtering to internal runs [LET-5417] (#5547) * feat: add tool_used field to run_metrics [LET-5419] * change to tool name * use tool ids over names * feat: add internal runs route with template_family filtering * feat: add step count filtering to internal runs [LET-5417] * remove import * add auto generated * add test * fix snippets --- letta/schemas/enums.py | 13 ++- letta/services/run_manager.py | 16 +++- tests/managers/test_run_manager.py | 134 +++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 4 deletions(-) diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index a5d48142..4120b81d 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -152,13 +152,12 @@ class ToolType(str, Enum): LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core" LETTA_BUILTIN = "letta_builtin" LETTA_FILES_CORE = "letta_files_core" - EXTERNAL_LANGCHAIN = "external_langchain" # DEPRECATED - EXTERNAL_COMPOSIO = "external_composio" # DEPRECATED + EXTERNAL_LANGCHAIN = "external_langchain" # DEPRECATED + EXTERNAL_COMPOSIO = "external_composio" # DEPRECATED # TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote? EXTERNAL_MCP = "external_mcp" - class JobType(str, Enum): JOB = "job" RUN = "run" @@ -222,3 +221,11 @@ class TagMatchMode(str, Enum): ANY = "any" ALL = "all" + + +class ComparisonOperator(str, Enum): + """Comparison operators for filtering numeric values""" + + EQ = "eq" # equals + GTE = "gte" # greater than or equal + LTE = "lte" # less than or equal diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index fa4671eb..8a2b3a4b 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -16,7 +16,7 @@ from letta.orm.run_metrics import RunMetrics as RunMetricsModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step as StepModel from letta.otel.tracing import log_event, trace_method -from letta.schemas.enums import AgentType, MessageRole, RunStatus +from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus from letta.schemas.job import LettaRequestConfig from letta.schemas.letta_message import LettaMessage, LettaMessageUnion from letta.schemas.letta_response import LettaResponse @@ -109,6 +109,8 @@ class RunManager: stop_reason: Optional[str] = None, background: Optional[bool] = None, template_family: Optional[str] = None, + step_count: Optional[int] = None, + step_count_operator: ComparisonOperator = ComparisonOperator.EQ, ) -> List[PydanticRun]: """List runs with filtering options.""" async with db_registry.async_session() as session: @@ -138,6 +140,18 @@ class RunManager: if template_family: query = query.filter(RunModel.base_template_id == template_family) + # Filter by step_count - join with run_metrics + if step_count is not None: + query = query.join(RunMetricsModel, RunModel.id == RunMetricsModel.id) + + # Filter by step_count with the specified operator + if step_count_operator == ComparisonOperator.EQ: + query = query.filter(RunMetricsModel.num_steps == step_count) + elif step_count_operator == ComparisonOperator.GTE: + query = query.filter(RunMetricsModel.num_steps >= step_count) + elif step_count_operator == ComparisonOperator.LTE: + query = query.filter(RunMetricsModel.num_steps <= step_count) + # Apply pagination from letta.services.helpers.run_manager_helper import _apply_pagination_async diff --git a/tests/managers/test_run_manager.py b/tests/managers/test_run_manager.py index 9428233f..2c2b17a5 100644 --- a/tests/managers/test_run_manager.py +++ b/tests/managers/test_run_manager.py @@ -382,6 +382,140 @@ async def test_list_runs_by_stop_reason(server: SyncServer, sarah_agent, default assert runs[0].id == run.id +@pytest.mark.asyncio +async def test_list_runs_by_step_count(server: SyncServer, sarah_agent, default_user): + """Test listing runs filtered by step count.""" + from letta.schemas.enums import ComparisonOperator + + # Create runs with different numbers of steps + runs_data = [] + + # Run with 0 steps + run_0 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=sarah_agent.id, + metadata={"steps": 0}, + ), + actor=default_user, + ) + runs_data.append((run_0, 0)) + + # Run with 2 steps + run_2 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=sarah_agent.id, + metadata={"steps": 2}, + ), + actor=default_user, + ) + for i in range(2): + await server.step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=100, + prompt_tokens=50, + total_tokens=150, + ), + run_id=run_2.id, + actor=default_user, + project_id=sarah_agent.project_id, + ) + runs_data.append((run_2, 2)) + + # Run with 5 steps + run_5 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=sarah_agent.id, + metadata={"steps": 5}, + ), + actor=default_user, + ) + for i in range(5): + await server.step_manager.log_step_async( + agent_id=sarah_agent.id, + provider_name="openai", + provider_category="base", + model="gpt-4o-mini", + model_endpoint="https://api.openai.com/v1", + context_window_limit=8192, + usage=UsageStatistics( + completion_tokens=100, + prompt_tokens=50, + total_tokens=150, + ), + run_id=run_5.id, + actor=default_user, + project_id=sarah_agent.project_id, + ) + runs_data.append((run_5, 5)) + + # Update all runs to trigger metrics update + for run, _ in runs_data: + await server.run_manager.update_run_by_id_async( + run.id, + RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn), + actor=default_user, + ) + + # Test EQ operator - exact match + runs_eq_2 = await server.run_manager.list_runs( + actor=default_user, + agent_id=sarah_agent.id, + step_count=2, + step_count_operator=ComparisonOperator.EQ, + ) + assert len(runs_eq_2) == 1 + assert runs_eq_2[0].id == run_2.id + + # Test GTE operator - greater than or equal + runs_gte_2 = await server.run_manager.list_runs( + actor=default_user, + agent_id=sarah_agent.id, + step_count=2, + step_count_operator=ComparisonOperator.GTE, + ) + assert len(runs_gte_2) == 2 + run_ids_gte = {run.id for run in runs_gte_2} + assert run_2.id in run_ids_gte + assert run_5.id in run_ids_gte + + # Test LTE operator - less than or equal + runs_lte_2 = await server.run_manager.list_runs( + actor=default_user, + agent_id=sarah_agent.id, + step_count=2, + step_count_operator=ComparisonOperator.LTE, + ) + assert len(runs_lte_2) == 2 + run_ids_lte = {run.id for run in runs_lte_2} + assert run_0.id in run_ids_lte + assert run_2.id in run_ids_lte + + # Test GTE with 0 - should return all runs + runs_gte_0 = await server.run_manager.list_runs( + actor=default_user, + agent_id=sarah_agent.id, + step_count=0, + step_count_operator=ComparisonOperator.GTE, + ) + assert len(runs_gte_0) == 3 + + # Test LTE with 0 - should return only run with 0 steps + runs_lte_0 = await server.run_manager.list_runs( + actor=default_user, + agent_id=sarah_agent.id, + step_count=0, + step_count_operator=ComparisonOperator.LTE, + ) + assert len(runs_lte_0) == 1 + assert runs_lte_0[0].id == run_0.id + + @pytest.mark.asyncio async def test_list_runs_by_base_template_id(server: SyncServer, sarah_agent, default_user): """Test listing runs by template family."""