feat: add step count filtering to internal runs [LET-5417] (#5547)
* feat: add tool_used field to run_metrics [LET-5419] * change to tool name * use tool ids over names * feat: add internal runs route with template_family filtering * feat: add step count filtering to internal runs [LET-5417] * remove import * add auto generated * add test * fix snippets
This commit is contained in: (branch list unavailable)
Committed by: Caren Thomas
Parent commit: fc531ca6de
Commit: 77c797c752
@@ -152,13 +152,12 @@ class ToolType(str, Enum):
|
||||
LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core"
|
||||
LETTA_BUILTIN = "letta_builtin"
|
||||
LETTA_FILES_CORE = "letta_files_core"
|
||||
EXTERNAL_LANGCHAIN = "external_langchain" # DEPRECATED
|
||||
EXTERNAL_COMPOSIO = "external_composio" # DEPRECATED
|
||||
EXTERNAL_LANGCHAIN = "external_langchain" # DEPRECATED
|
||||
EXTERNAL_COMPOSIO = "external_composio" # DEPRECATED
|
||||
# TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote?
|
||||
EXTERNAL_MCP = "external_mcp"
|
||||
|
||||
|
||||
|
||||
class JobType(str, Enum):
    """Discriminator for background work records: a generic job or an agent run."""

    JOB = "job"  # generic background job
    RUN = "run"  # an agent run (see RunStatus / RunManager usage elsewhere in this change)
|
||||
@@ -222,3 +221,11 @@ class TagMatchMode(str, Enum):
|
||||
|
||||
ANY = "any"
|
||||
ALL = "all"
|
||||
|
||||
|
||||
class ComparisonOperator(str, Enum):
    """Comparison operators for filtering numeric values.

    Used by query-layer filters (e.g. the run step-count filter in
    ``RunManager.list_runs``) to select how a caller-supplied number is
    compared against a stored column.
    """

    EQ = "eq"  # equals: column == value
    GTE = "gte"  # greater than or equal: column >= value
    LTE = "lte"  # less than or equal: column <= value
|
||||
|
||||
@@ -16,7 +16,7 @@ from letta.orm.run_metrics import RunMetrics as RunMetricsModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import AgentType, MessageRole, RunStatus
|
||||
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus
|
||||
from letta.schemas.job import LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
@@ -109,6 +109,8 @@ class RunManager:
|
||||
stop_reason: Optional[str] = None,
|
||||
background: Optional[bool] = None,
|
||||
template_family: Optional[str] = None,
|
||||
step_count: Optional[int] = None,
|
||||
step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
|
||||
) -> List[PydanticRun]:
|
||||
"""List runs with filtering options."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -138,6 +140,18 @@ class RunManager:
|
||||
if template_family:
|
||||
query = query.filter(RunModel.base_template_id == template_family)
|
||||
|
||||
# Filter by step_count - join with run_metrics
|
||||
if step_count is not None:
|
||||
query = query.join(RunMetricsModel, RunModel.id == RunMetricsModel.id)
|
||||
|
||||
# Filter by step_count with the specified operator
|
||||
if step_count_operator == ComparisonOperator.EQ:
|
||||
query = query.filter(RunMetricsModel.num_steps == step_count)
|
||||
elif step_count_operator == ComparisonOperator.GTE:
|
||||
query = query.filter(RunMetricsModel.num_steps >= step_count)
|
||||
elif step_count_operator == ComparisonOperator.LTE:
|
||||
query = query.filter(RunMetricsModel.num_steps <= step_count)
|
||||
|
||||
# Apply pagination
|
||||
from letta.services.helpers.run_manager_helper import _apply_pagination_async
|
||||
|
||||
|
||||
@@ -382,6 +382,140 @@ async def test_list_runs_by_stop_reason(server: SyncServer, sarah_agent, default
|
||||
assert runs[0].id == run.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_runs_by_step_count(server: SyncServer, sarah_agent, default_user):
    """Test listing runs filtered by step count.

    Creates three runs with 0, 2, and 5 logged steps, marks them completed
    (which populates run_metrics.num_steps), then verifies that
    ``RunManager.list_runs`` honors each ``ComparisonOperator`` (EQ/GTE/LTE).
    """
    from letta.schemas.enums import ComparisonOperator

    async def _create_run_with_steps(num_steps: int):
        """Create a run for sarah_agent and log ``num_steps`` identical steps against it."""
        run = await server.run_manager.create_run(
            pydantic_run=PydanticRun(
                agent_id=sarah_agent.id,
                metadata={"steps": num_steps},
            ),
            actor=default_user,
        )
        for _ in range(num_steps):
            await server.step_manager.log_step_async(
                agent_id=sarah_agent.id,
                provider_name="openai",
                provider_category="base",
                model="gpt-4o-mini",
                model_endpoint="https://api.openai.com/v1",
                context_window_limit=8192,
                usage=UsageStatistics(
                    completion_tokens=100,
                    prompt_tokens=50,
                    total_tokens=150,
                ),
                run_id=run.id,
                actor=default_user,
                project_id=sarah_agent.project_id,
            )
        return run

    run_0 = await _create_run_with_steps(0)
    run_2 = await _create_run_with_steps(2)
    run_5 = await _create_run_with_steps(5)

    # Completing each run triggers the metrics update that records num_steps.
    for run in (run_0, run_2, run_5):
        await server.run_manager.update_run_by_id_async(
            run.id,
            RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn),
            actor=default_user,
        )

    # EQ operator: exact match on step count.
    runs_eq_2 = await server.run_manager.list_runs(
        actor=default_user,
        agent_id=sarah_agent.id,
        step_count=2,
        step_count_operator=ComparisonOperator.EQ,
    )
    assert len(runs_eq_2) == 1
    assert runs_eq_2[0].id == run_2.id

    # GTE operator: runs with at least 2 steps.
    runs_gte_2 = await server.run_manager.list_runs(
        actor=default_user,
        agent_id=sarah_agent.id,
        step_count=2,
        step_count_operator=ComparisonOperator.GTE,
    )
    assert len(runs_gte_2) == 2
    assert {run.id for run in runs_gte_2} == {run_2.id, run_5.id}

    # LTE operator: runs with at most 2 steps.
    runs_lte_2 = await server.run_manager.list_runs(
        actor=default_user,
        agent_id=sarah_agent.id,
        step_count=2,
        step_count_operator=ComparisonOperator.LTE,
    )
    assert len(runs_lte_2) == 2
    assert {run.id for run in runs_lte_2} == {run_0.id, run_2.id}

    # GTE with 0 should return all runs.
    runs_gte_0 = await server.run_manager.list_runs(
        actor=default_user,
        agent_id=sarah_agent.id,
        step_count=0,
        step_count_operator=ComparisonOperator.GTE,
    )
    assert len(runs_gte_0) == 3

    # LTE with 0 should return only the run that logged no steps.
    runs_lte_0 = await server.run_manager.list_runs(
        actor=default_user,
        agent_id=sarah_agent.id,
        step_count=0,
        step_count_operator=ComparisonOperator.LTE,
    )
    assert len(runs_lte_0) == 1
    assert runs_lte_0[0].id == run_0.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_runs_by_base_template_id(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing runs by template family."""
|
||||
|
||||
Reference in New Issue
Block a user