From d5e3e22ed1175a73139055cf5bcf9ca79efbe5f6 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 25 Jun 2025 16:15:30 -0700 Subject: [PATCH] feat: add feedback for steps (#2946) --- .../51999513bcf1_steps_feedback_field.py | 31 +++++++++++++++++++ letta/orm/step.py | 3 ++ letta/schemas/step.py | 5 ++- letta/server/rest_api/routers/v1/steps.py | 23 +++++++++++++- letta/services/step_manager.py | 17 +++++++++- 5 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/51999513bcf1_steps_feedback_field.py diff --git a/alembic/versions/51999513bcf1_steps_feedback_field.py b/alembic/versions/51999513bcf1_steps_feedback_field.py new file mode 100644 index 00000000..cd30a911 --- /dev/null +++ b/alembic/versions/51999513bcf1_steps_feedback_field.py @@ -0,0 +1,31 @@ +"""steps feedback field + +Revision ID: 51999513bcf1 +Revises: 61ee53ec45a5 +Create Date: 2025-06-20 14:09:22.993263 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "51999513bcf1" +down_revision: Union[str, None] = "61ee53ec45a5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("steps", sa.Column("feedback", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("steps", "feedback") + # ### end Alembic commands ### diff --git a/letta/orm/step.py b/letta/orm/step.py index bd03d935..752c492e 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -48,6 +48,9 @@ class Step(SqlalchemyBase): tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.") tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.") trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.") + feedback: Mapped[Optional[str]] = mapped_column( + None, nullable=True, doc="The feedback for this step. Must be either 'positive' or 'negative'." + ) # Relationships (foreign keys) organization: Mapped[Optional["Organization"]] = relationship("Organization") diff --git a/letta/schemas/step.py b/letta/schemas/step.py index 2e0604d8..398199b5 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Literal, Optional from pydantic import Field @@ -32,3 +32,6 @@ class Step(StepBase): tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.") trace_id: Optional[str] = Field(None, description="The trace id of the agent step.") messages: List[Message] = Field([], description="The messages generated during this step.") + feedback: Optional[Literal["positive", "negative"]] = Field( + None, description="The feedback for this step. Must be either 'positive' or 'negative'." + ) diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py index cf6c9565..0b0e0000 100644 --- a/letta/server/rest_api/routers/v1/steps.py +++ b/letta/server/rest_api/routers/v1/steps.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional from fastapi import APIRouter, Depends, Header, HTTPException, Query @@ -22,6 +22,8 @@ async def list_steps( model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"), agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"), trace_ids: Optional[list[str]] = Query(None, description="Filter by trace ids returned by the server"), + feedback: Optional[Literal["positive", "negative"]] = Query(None, description="Filter by feedback"), + tags: Optional[list[str]] = Query(None, description="Filter by tags"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -46,6 +48,8 @@ async def list_steps( model=model, agent_id=agent_id, trace_ids=trace_ids, + feedback=feedback, + tags=tags, ) @@ -65,6 +69,23 @@ async def retrieve_step( raise HTTPException(status_code=404, detail="Step not found") +@router.patch("/{step_id}/feedback", response_model=Step, operation_id="add_feedback") +async def add_feedback( + step_id: str, + feedback: Optional[Literal["positive", "negative"]], + actor_id: Optional[str] = Header(None, alias="user_id"), + server: SyncServer = Depends(get_letta_server), +): + """ + Add feedback to a step. + """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.step_manager.add_feedback_async(step_id=step_id, feedback=feedback, actor=actor) + except NoResultFound: + raise HTTPException(status_code=404, detail="Step not found") + + @router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id") def update_step_transaction_id( step_id: str, diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 9401af48..8e907f91 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -34,6 +34,7 @@ class StepManager: model: Optional[str] = None, agent_id: Optional[str] = None, trace_ids: Optional[list[str]] = None, + feedback: Optional[Literal["positive", "negative"]] = None, ) -> List[PydanticStep]: """List all jobs with optional pagination and status filter.""" async with db_registry.async_session() as session: @@ -44,7 +45,8 @@ class StepManager: filter_kwargs["agent_id"] = agent_id if trace_ids: filter_kwargs["trace_id"] = trace_ids - + if feedback: + filter_kwargs["feedback"] = feedback steps = await StepModel.list_async( db_session=session, before=before, @@ -150,6 +152,19 @@ class StepManager: step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor) return step.to_pydantic() + @enforce_types + @trace_method + async def add_feedback_async( + self, step_id: str, feedback: Optional[Literal["positive", "negative"]], actor: PydanticUser + ) -> PydanticStep: + async with db_registry.async_session() as session: + step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor) + if not step: + raise NoResultFound(f"Step with id {step_id} does not exist") + step.feedback = feedback + step = await step.update_async(session) + return step.to_pydantic() + @enforce_types @trace_method def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep: