feat: add routes for steps (#839)
This commit is contained in:
@@ -13,18 +13,18 @@ repos:
|
||||
hooks:
|
||||
- id: autoflake
|
||||
name: autoflake
|
||||
entry: bash -c 'cd apps/core && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .'
|
||||
entry: poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .
|
||||
language: system
|
||||
types: [python]
|
||||
- id: isort
|
||||
name: isort
|
||||
entry: bash -c 'cd apps/core && poetry run isort --profile black .'
|
||||
entry: poetry run isort --profile black .
|
||||
language: system
|
||||
types: [python]
|
||||
exclude: ^docs/
|
||||
- id: black
|
||||
name: black
|
||||
entry: bash -c 'cd apps/core && poetry run black --line-length 140 --target-version py310 --target-version py311 .'
|
||||
entry: poetry run black --line-length 140 --target-version py310 --target-version py311 .
|
||||
language: system
|
||||
types: [python]
|
||||
exclude: ^docs/
|
||||
|
||||
@@ -7,6 +7,7 @@ from letta.server.rest_api.routers.v1.providers import router as providers_route
|
||||
from letta.server.rest_api.routers.v1.runs import router as runs_router
|
||||
from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
|
||||
from letta.server.rest_api.routers.v1.sources import router as sources_router
|
||||
from letta.server.rest_api.routers.v1.steps import router as steps_router
|
||||
from letta.server.rest_api.routers.v1.tags import router as tags_router
|
||||
from letta.server.rest_api.routers.v1.tools import router as tools_router
|
||||
|
||||
@@ -21,5 +22,6 @@ ROUTERS = [
|
||||
sandbox_configs_router,
|
||||
providers_router,
|
||||
runs_router,
|
||||
steps_router,
|
||||
tags_router,
|
||||
]
|
||||
|
||||
78
letta/server/rest_api/routers/v1/steps.py
Normal file
78
letta/server/rest_api/routers/v1/steps.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.step import Step
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/steps", tags=["steps"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[Step], operation_id="list_steps")
|
||||
def list_steps(
|
||||
before: Optional[str] = Query(None, description="Return steps before this step ID"),
|
||||
after: Optional[str] = Query(None, description="Return steps after this step ID"),
|
||||
limit: Optional[int] = Query(50, description="Maximum number of steps to return"),
|
||||
order: Optional[str] = Query("desc", description="Sort order (asc or desc)"),
|
||||
start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
List steps with optional pagination and date filters.
|
||||
Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00)
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Convert ISO strings to datetime objects if provided
|
||||
start_dt = datetime.fromisoformat(start_date) if start_date else None
|
||||
end_dt = datetime.fromisoformat(end_date) if end_date else None
|
||||
|
||||
return server.step_manager.list_steps(
|
||||
actor=actor,
|
||||
before=before,
|
||||
after=after,
|
||||
start_date=start_dt,
|
||||
end_date=end_dt,
|
||||
limit=limit,
|
||||
order=order,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{step_id}", response_model=Step, operation_id="retrieve_step")
|
||||
def retrieve_step(
|
||||
step_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get a step by ID.
|
||||
"""
|
||||
try:
|
||||
return server.step_manager.get_step(step_id=step_id)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
|
||||
@router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id")
|
||||
def update_step_transaction_id(
|
||||
step_id: str,
|
||||
transaction_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update the transaction ID for a step.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
@@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -20,6 +21,34 @@ class StepManager:
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def list_steps(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
order: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[PydanticStep]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
with self.session_maker() as session:
|
||||
filter_kwargs = {"organization_id": actor.organization_id, "model": model}
|
||||
|
||||
steps = StepModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
ascending=True if order == "asc" else False,
|
||||
**filter_kwargs,
|
||||
)
|
||||
return [step.to_pydantic() for step in steps]
|
||||
|
||||
@enforce_types
|
||||
def log_step(
|
||||
self,
|
||||
@@ -58,6 +87,32 @@ class StepManager:
|
||||
step = StepModel.read(db_session=session, identifier=step_id)
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
|
||||
"""Update the transaction ID for a step.
|
||||
|
||||
Args:
|
||||
actor: The user making the request
|
||||
step_id: The ID of the step to update
|
||||
transaction_id: The new transaction ID to set
|
||||
|
||||
Returns:
|
||||
The updated step
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the step does not exist
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
step = session.get(StepModel, step_id)
|
||||
if not step:
|
||||
raise NoResultFound(f"Step with id {step_id} does not exist")
|
||||
if step.organization_id != actor.organization_id:
|
||||
raise Exception("Unauthorized")
|
||||
|
||||
step.tid = transaction_id
|
||||
session.commit()
|
||||
return step.to_pydantic()
|
||||
|
||||
def _verify_job_access(
|
||||
self,
|
||||
session: Session,
|
||||
|
||||
Reference in New Issue
Block a user