From d4b510e358cf88cd8fc74a055532b4c3f0f65964 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 27 Jun 2025 14:35:29 -0700 Subject: [PATCH] feat: allow filtering steps by feedback (#3061) --- letta/orm/sqlalchemy_base.py | 9 +++++++++ letta/server/rest_api/routers/v1/steps.py | 2 ++ letta/services/step_manager.py | 2 ++ tests/test_managers.py | 13 +++++++++++++ 4 files changed, 26 insertions(+) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 8c33b4b5..af8a1077 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -183,6 +183,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): identifier_keys: Optional[List[str]] = None, identity_id: Optional[str] = None, query_options: Sequence[ORMOption] | None = None, # ← new + has_feedback: Optional[bool] = None, **kwargs, ) -> List["SqlalchemyBase"]: """ @@ -281,6 +282,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): identifier_keys: Optional[List[str]] = None, identity_id: Optional[str] = None, check_is_deleted: bool = False, + has_feedback: Optional[bool] = None, **kwargs, ): """ @@ -337,6 +339,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if end_date: query = query.filter(cls.created_at < end_date) + # Feedback filtering + if has_feedback is not None and hasattr(cls, "feedback"): + if has_feedback: + query = query.filter(cls.feedback.isnot(None)) + else: + query = query.filter(cls.feedback.is_(None)) + # Handle pagination based on before/after if before_obj or after_obj: conditions = [] diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py index f149c851..9fe11ee8 100644 --- a/letta/server/rest_api/routers/v1/steps.py +++ b/letta/server/rest_api/routers/v1/steps.py @@ -24,6 +24,7 @@ async def list_steps( agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"), trace_ids: Optional[list[str]] = Query(None, description="Filter by trace ids returned by the server"), feedback: Optional[Literal["positive", "negative"]] = Query(None, description="Filter by feedback"), + has_feedback: Optional[bool] = Query(None, description="Filter by whether steps have feedback (true) or not (false)"), tags: Optional[list[str]] = Query(None, description="Filter by tags"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -50,6 +51,7 @@ async def list_steps( agent_id=agent_id, trace_ids=trace_ids, feedback=feedback, + has_feedback=has_feedback, tags=tags, ) diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index a5fed8bf..d4d5c7b3 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -41,6 +41,7 @@ class StepManager: agent_id: Optional[str] = None, trace_ids: Optional[list[str]] = None, feedback: Optional[Literal["positive", "negative"]] = None, + has_feedback: Optional[bool] = None, ) -> List[PydanticStep]: """List all jobs with optional pagination and status filter.""" async with db_registry.async_session() as session: @@ -61,6 +62,7 @@ class StepManager: end_date=end_date, limit=limit, ascending=True if order == "asc" else False, + has_feedback=has_feedback, **filter_kwargs, ) return [step.to_pydantic() for step in steps] diff --git a/tests/test_managers.py b/tests/test_managers.py index 4ba3c053..c7031145 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -6200,6 +6200,19 @@ async def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, def steps = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user) assert len(steps) == 2 + # add step feedback + step_manager = server.step_manager + + # Add feedback to first step + await step_manager.add_feedback_async(step_id=steps[0].id, feedback="positive", actor=default_user) + + # Test has_feedback filtering + steps_with_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, has_feedback=True, actor=default_user) + assert len(steps_with_feedback) == 1 + + steps_without_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user) + assert len(steps_without_feedback) == 2 + def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): """Test getting usage statistics for a nonexistent job."""