diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d9b7491..53f5d73a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,18 +13,18 @@ repos: hooks: - id: autoflake name: autoflake - entry: bash -c 'cd apps/core && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .' + entry: poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports . language: system types: [python] - id: isort name: isort - entry: bash -c 'cd apps/core && poetry run isort --profile black .' + entry: poetry run isort --profile black . language: system types: [python] exclude: ^docs/ - id: black name: black - entry: bash -c 'cd apps/core && poetry run black --line-length 140 --target-version py310 --target-version py311 .' + entry: poetry run black --line-length 140 --target-version py310 --target-version py311 . language: system types: [python] exclude: ^docs/ diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 5611c055..6a683ac8 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -7,6 +7,7 @@ from letta.server.rest_api.routers.v1.providers import router as providers_route from letta.server.rest_api.routers.v1.runs import router as runs_router from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router from letta.server.rest_api.routers.v1.sources import router as sources_router +from letta.server.rest_api.routers.v1.steps import router as steps_router from letta.server.rest_api.routers.v1.tags import router as tags_router from letta.server.rest_api.routers.v1.tools import router as tools_router @@ -21,5 +22,6 @@ ROUTERS = [ sandbox_configs_router, providers_router, runs_router, + steps_router, tags_router, ] diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py new file mode 100644 index 00000000..cb82bf59 --- /dev/null +++ b/letta/server/rest_api/routers/v1/steps.py @@ -0,0 +1,78 @@ +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, Depends, Header, HTTPException, Query + +from letta.orm.errors import NoResultFound +from letta.schemas.step import Step +from letta.server.rest_api.utils import get_letta_server +from letta.server.server import SyncServer + +router = APIRouter(prefix="/steps", tags=["steps"]) + + +@router.get("", response_model=List[Step], operation_id="list_steps") +def list_steps( + before: Optional[str] = Query(None, description="Return steps before this step ID"), + after: Optional[str] = Query(None, description="Return steps after this step ID"), + limit: Optional[int] = Query(50, description="Maximum number of steps to return"), + order: Optional[str] = Query("desc", description="Sort order (asc or desc)"), + start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'), + end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'), + model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"), + server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), +): + """ + List steps with optional pagination and date filters. + Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00) + """ + actor = server.user_manager.get_user_or_default(user_id=user_id) + + # Convert ISO strings to datetime objects if provided + start_dt = datetime.fromisoformat(start_date) if start_date else None + end_dt = datetime.fromisoformat(end_date) if end_date else None + + return server.step_manager.list_steps( + actor=actor, + before=before, + after=after, + start_date=start_dt, + end_date=end_dt, + limit=limit, + order=order, + model=model, + ) + + +@router.get("/{step_id}", response_model=Step, operation_id="retrieve_step") +def retrieve_step( + step_id: str, + user_id: Optional[str] = Header(None, alias="user_id"), + server: SyncServer = Depends(get_letta_server), +): + """ + Get a step by ID. + """ + try: + return server.step_manager.get_step(step_id=step_id) + except NoResultFound: + raise HTTPException(status_code=404, detail="Step not found") + + +@router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id") +def update_step_transaction_id( + step_id: str, + transaction_id: str, + user_id: Optional[str] = Header(None, alias="user_id"), + server: SyncServer = Depends(get_letta_server), +): + """ + Update the transaction ID for a step. + """ + actor = server.user_manager.get_user_or_default(user_id=user_id) + + try: + return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id) + except NoResultFound: + raise HTTPException(status_code=404, detail="Step not found") diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index cbeee458..5e6fbce5 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -1,3 +1,4 @@ +import datetime from typing import List, Literal, Optional from sqlalchemy import select @@ -20,6 +21,34 @@ class StepManager: self.session_maker = db_context + @enforce_types + def list_steps( + self, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + limit: Optional[int] = 50, + order: Optional[str] = None, + model: Optional[str] = None, + ) -> List[PydanticStep]: + """List all jobs with optional pagination and status filter.""" + with self.session_maker() as session: + filter_kwargs = {"organization_id": actor.organization_id, "model": model} + + steps = StepModel.list( + db_session=session, + before=before, + after=after, + start_date=start_date, + end_date=end_date, + limit=limit, + ascending=True if order == "asc" else False, + **filter_kwargs, + ) + return [step.to_pydantic() for step in steps] + @enforce_types def log_step( self, @@ -58,6 +87,32 @@ class StepManager: step = StepModel.read(db_session=session, identifier=step_id) return step.to_pydantic() + @enforce_types + def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep: + """Update the transaction ID for a step. + + Args: + actor: The user making the request + step_id: The ID of the step to update + transaction_id: The new transaction ID to set + + Returns: + The updated step + + Raises: + NoResultFound: If the step does not exist + """ + with self.session_maker() as session: + step = session.get(StepModel, step_id) + if not step: + raise NoResultFound(f"Step with id {step_id} does not exist") + if step.organization_id != actor.organization_id: + raise Exception("Unauthorized") + + step.tid = transaction_id + session.commit() + return step.to_pydantic() + def _verify_job_access( self, session: Session,