diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py index 0b0e0000..f149c851 100644 --- a/letta/server/rest_api/routers/v1/steps.py +++ b/letta/server/rest_api/routers/v1/steps.py @@ -7,6 +7,7 @@ from letta.orm.errors import NoResultFound from letta.schemas.step import Step from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer +from letta.services.step_manager import FeedbackType router = APIRouter(prefix="/steps", tags=["steps"]) @@ -72,7 +73,7 @@ async def retrieve_step( @router.patch("/{step_id}/feedback", response_model=Step, operation_id="add_feedback") async def add_feedback( step_id: str, - feedback: Optional[Literal["positive", "negative"]], + feedback: Optional[FeedbackType], actor_id: Optional[str] = Header(None, alias="user_id"), server: SyncServer = Depends(get_letta_server), ): diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 8e907f91..3f5ca389 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from typing import List, Literal, Optional from sqlalchemy import select @@ -18,6 +19,11 @@ from letta.server.db import db_registry from letta.utils import enforce_types +class FeedbackType(str, Enum): + POSITIVE = "positive" + NEGATIVE = "negative" + + class StepManager: @enforce_types @@ -154,9 +160,7 @@ class StepManager: @enforce_types @trace_method - async def add_feedback_async( - self, step_id: str, feedback: Optional[Literal["positive", "negative"]], actor: PydanticUser - ) -> PydanticStep: + async def add_feedback_async(self, step_id: str, feedback: Optional[FeedbackType], actor: PydanticUser) -> PydanticStep: async with db_registry.async_session() as session: step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor) if not step: