feat: add routes for steps (#839)

This commit is contained in:
cthomas
2025-01-29 20:57:09 -08:00
committed by GitHub
parent 6897794a68
commit 9f60ae65fd
4 changed files with 138 additions and 3 deletions

View File

@@ -13,18 +13,18 @@ repos:
hooks:
- id: autoflake
name: autoflake
entry: bash -c 'cd apps/core && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .'
entry: poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .
language: system
types: [python]
- id: isort
name: isort
entry: bash -c 'cd apps/core && poetry run isort --profile black .'
entry: poetry run isort --profile black .
language: system
types: [python]
exclude: ^docs/
- id: black
name: black
entry: bash -c 'cd apps/core && poetry run black --line-length 140 --target-version py310 --target-version py311 .'
entry: poetry run black --line-length 140 --target-version py310 --target-version py311 .
language: system
types: [python]
exclude: ^docs/

View File

@@ -7,6 +7,7 @@ from letta.server.rest_api.routers.v1.providers import router as providers_route
from letta.server.rest_api.routers.v1.runs import router as runs_router
from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
from letta.server.rest_api.routers.v1.sources import router as sources_router
from letta.server.rest_api.routers.v1.steps import router as steps_router
from letta.server.rest_api.routers.v1.tags import router as tags_router
from letta.server.rest_api.routers.v1.tools import router as tools_router
@@ -21,5 +22,6 @@ ROUTERS = [
sandbox_configs_router,
providers_router,
runs_router,
steps_router,
tags_router,
]

View File

@@ -0,0 +1,78 @@
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from letta.orm.errors import NoResultFound
from letta.schemas.step import Step
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
router = APIRouter(prefix="/steps", tags=["steps"])
@router.get("", response_model=List[Step], operation_id="list_steps")
def list_steps(
before: Optional[str] = Query(None, description="Return steps before this step ID"),
after: Optional[str] = Query(None, description="Return steps after this step ID"),
limit: Optional[int] = Query(50, description="Maximum number of steps to return"),
order: Optional[str] = Query("desc", description="Sort order (asc or desc)"),
start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""
List steps with optional pagination and date filters.
Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00)
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
# Convert ISO strings to datetime objects if provided
start_dt = datetime.fromisoformat(start_date) if start_date else None
end_dt = datetime.fromisoformat(end_date) if end_date else None
return server.step_manager.list_steps(
actor=actor,
before=before,
after=after,
start_date=start_dt,
end_date=end_dt,
limit=limit,
order=order,
model=model,
)
@router.get("/{step_id}", response_model=Step, operation_id="retrieve_step")
def retrieve_step(
step_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
server: SyncServer = Depends(get_letta_server),
):
"""
Get a step by ID.
"""
try:
return server.step_manager.get_step(step_id=step_id)
except NoResultFound:
raise HTTPException(status_code=404, detail="Step not found")
@router.patch("/{step_id}/transaction/{transaction_id}", response_model=Step, operation_id="update_step_transaction_id")
def update_step_transaction_id(
step_id: str,
transaction_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
server: SyncServer = Depends(get_letta_server),
):
"""
Update the transaction ID for a step.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id)
except NoResultFound:
raise HTTPException(status_code=404, detail="Step not found")

View File

@@ -1,3 +1,4 @@
import datetime
from typing import List, Literal, Optional
from sqlalchemy import select
@@ -20,6 +21,34 @@ class StepManager:
self.session_maker = db_context
@enforce_types
def list_steps(
self,
actor: PydanticUser,
before: Optional[str] = None,
after: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
order: Optional[str] = None,
model: Optional[str] = None,
) -> List[PydanticStep]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
filter_kwargs = {"organization_id": actor.organization_id, "model": model}
steps = StepModel.list(
db_session=session,
before=before,
after=after,
start_date=start_date,
end_date=end_date,
limit=limit,
ascending=True if order == "asc" else False,
**filter_kwargs,
)
return [step.to_pydantic() for step in steps]
@enforce_types
def log_step(
self,
@@ -58,6 +87,32 @@ class StepManager:
step = StepModel.read(db_session=session, identifier=step_id)
return step.to_pydantic()
@enforce_types
def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
"""Update the transaction ID for a step.
Args:
actor: The user making the request
step_id: The ID of the step to update
transaction_id: The new transaction ID to set
Returns:
The updated step
Raises:
NoResultFound: If the step does not exist
"""
with self.session_maker() as session:
step = session.get(StepModel, step_id)
if not step:
raise NoResultFound(f"Step with id {step_id} does not exist")
if step.organization_id != actor.organization_id:
raise Exception("Unauthorized")
step.tid = transaction_id
session.commit()
return step.to_pydantic()
def _verify_job_access(
self,
session: Session,