feat: add agents_runs table (#4768)
This commit is contained in:
committed by
Caren Thomas
parent
00292363c4
commit
c85bfefa52
51
alembic/versions/5973fd8b8c60_add_agents_runs_table.py
Normal file
51
alembic/versions/5973fd8b8c60_add_agents_runs_table.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""add agents_runs table
|
||||
|
||||
Revision ID: 5973fd8b8c60
|
||||
Revises: eff256d296cb
|
||||
Create Date: 2025-09-18 10:52:46.270241
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5973fd8b8c60"
|
||||
down_revision: Union[str, None] = "eff256d296cb"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"agents_runs",
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("run_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_id"],
|
||||
["agents.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["run_id"],
|
||||
["jobs.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("agent_id", "run_id"),
|
||||
sa.UniqueConstraint("agent_id", "run_id", name="unique_agent_run"),
|
||||
)
|
||||
op.create_index("ix_agents_runs_agent_id_run_id", "agents_runs", ["agent_id", "run_id"], unique=False)
|
||||
op.create_index("ix_agents_runs_run_id_agent_id", "agents_runs", ["run_id", "agent_id"], unique=False)
|
||||
op.add_column("jobs", sa.Column("background", sa.Boolean(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("jobs", "background")
|
||||
op.drop_index("ix_agents_runs_run_id_agent_id", table_name="agents_runs")
|
||||
op.drop_index("ix_agents_runs_agent_id_run_id", table_name="agents_runs")
|
||||
op.drop_table("agents_runs")
|
||||
# ### end Alembic commands ###
|
||||
@@ -9681,6 +9681,24 @@
|
||||
"description": "List all runs.",
|
||||
"operationId": "list_runs",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The unique identifier of the agent associated with the run.",
|
||||
"title": "Agent Id"
|
||||
},
|
||||
"description": "The unique identifier of the agent associated with the run."
|
||||
},
|
||||
{
|
||||
"name": "agent_ids",
|
||||
"in": "query",
|
||||
@@ -9697,10 +9715,10 @@
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The unique identifier of the agent associated with the run.",
|
||||
"description": "(DEPRECATED) The unique identifiers of the agents associated with the run.",
|
||||
"title": "Agent Ids"
|
||||
},
|
||||
"description": "The unique identifier of the agent associated with the run."
|
||||
"description": "(DEPRECATED) The unique identifiers of the agents associated with the run."
|
||||
},
|
||||
{
|
||||
"name": "background",
|
||||
@@ -9855,23 +9873,20 @@
|
||||
"deprecated": true,
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_ids",
|
||||
"name": "agent_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The unique identifier of the agent associated with the run.",
|
||||
"title": "Agent Ids"
|
||||
"title": "Agent Id"
|
||||
},
|
||||
"description": "The unique identifier of the agent associated with the run."
|
||||
},
|
||||
@@ -16696,6 +16711,30 @@
|
||||
"$ref": "#/components/schemas/JobType",
|
||||
"default": "batch"
|
||||
},
|
||||
"background": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Background",
|
||||
"description": "Whether the job was created in background mode."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Agent Id",
|
||||
"description": "The agent associated with this job/run."
|
||||
},
|
||||
"callback_url": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -23596,6 +23635,30 @@
|
||||
"description": "The type of the job.",
|
||||
"default": "job"
|
||||
},
|
||||
"background": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Background",
|
||||
"description": "Whether the job was created in background mode."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Agent Id",
|
||||
"description": "The agent associated with this job/run."
|
||||
},
|
||||
"callback_url": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -26860,6 +26923,30 @@
|
||||
"$ref": "#/components/schemas/JobType",
|
||||
"default": "run"
|
||||
},
|
||||
"background": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Background",
|
||||
"description": "Whether the job was created in background mode."
|
||||
},
|
||||
"agent_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Agent Id",
|
||||
"description": "The agent associated with this job/run."
|
||||
},
|
||||
"callback_url": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "docs",
|
||||
"$schema": "../../node_modules/nx/schemas/project-schema.json",
|
||||
"sourceRoot": "apps/fern",
|
||||
"sourceRoot": "apps/core/fern",
|
||||
"projectType": "application",
|
||||
"tags": [],
|
||||
"targets": {
|
||||
|
||||
@@ -1214,8 +1214,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
job_id=run_id,
|
||||
new_status=JobStatus.failed if is_error else JobStatus.completed,
|
||||
actor=self.actor,
|
||||
metadata=job_update_metadata,
|
||||
stop_reason=self.stop_reason.stop_reason if self.stop_reason else StopReasonType.error,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
if request_span:
|
||||
request_span.end()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.agents_runs import AgentsRuns
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.archive import Archive
|
||||
from letta.orm.archives_agents import ArchivesAgents
|
||||
|
||||
@@ -132,6 +132,13 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
lazy="selectin",
|
||||
doc="Tags associated with the agent.",
|
||||
)
|
||||
runs: Mapped[List["AgentsRuns"]] = relationship(
|
||||
"AgentsRuns",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
doc="Runs associated with the agent.",
|
||||
)
|
||||
identities: Mapped[List["Identity"]] = relationship(
|
||||
"Identity",
|
||||
secondary="identities_agents",
|
||||
|
||||
26
letta/orm/agents_runs.py
Normal file
26
letta/orm/agents_runs.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.job import Job
|
||||
|
||||
|
||||
class AgentsRuns(Base):
|
||||
__tablename__ = "agents_runs"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("agent_id", "run_id", name="unique_agent_run"),
|
||||
Index("ix_agents_runs_agent_id_run_id", "agent_id", "run_id"),
|
||||
Index("ix_agents_runs_run_id_agent_id", "run_id", "agent_id"),
|
||||
)
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String, ForeignKey("jobs.id"), primary_key=True)
|
||||
|
||||
# relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="runs")
|
||||
run: Mapped["Job"] = relationship("Job", back_populates="agent")
|
||||
@@ -11,6 +11,7 @@ from letta.schemas.job import Job as PydanticJob, LettaRequestConfig
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agents_runs import AgentsRuns
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
@@ -30,6 +31,9 @@ class Job(SqlalchemyBase, UserMixin):
|
||||
status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the job.")
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="The unix timestamp of when the job was completed.")
|
||||
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(String, nullable=True, doc="The reason why the job was stopped.")
|
||||
background: Mapped[Optional[bool]] = mapped_column(
|
||||
Boolean, nullable=True, default=False, doc="Whether the job was created in background mode."
|
||||
)
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, doc="The metadata of the job.")
|
||||
job_type: Mapped[JobType] = mapped_column(
|
||||
String,
|
||||
@@ -59,6 +63,7 @@ class Job(SqlalchemyBase, UserMixin):
|
||||
steps: Mapped[List["Step"]] = relationship("Step", back_populates="job", cascade="save-update")
|
||||
# organization relationship (nullable for backward compatibility)
|
||||
organization: Mapped[Optional["Organization"]] = relationship("Organization", back_populates="jobs")
|
||||
agent: Mapped[List["AgentsRuns"]] = relationship("AgentsRuns", back_populates="run", cascade="all, delete-orphan")
|
||||
|
||||
@property
|
||||
def messages(self) -> List["Message"]:
|
||||
|
||||
@@ -24,9 +24,9 @@ class JobBase(OrmMetadataBase):
|
||||
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
|
||||
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
||||
|
||||
## TODO: Run-specific fields
|
||||
# background: Optional[bool] = Field(None, description="Whether the job was created in background mode.")
|
||||
# agent_id: Optional[str] = Field(None, description="The agent associated with this job/run.")
|
||||
# Run-specific fields
|
||||
background: Optional[bool] = Field(None, description="Whether the job was created in background mode.")
|
||||
agent_id: Optional[str] = Field(None, description="The agent associated with this job/run.")
|
||||
|
||||
callback_url: Optional[str] = Field(None, description="If set, POST to this URL when the job completes.")
|
||||
callback_sent_at: Optional[datetime] = Field(None, description="Timestamp when the callback was last attempted.")
|
||||
|
||||
@@ -241,4 +241,5 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
"groups",
|
||||
"batch_items",
|
||||
"organization",
|
||||
"runs", # Exclude the runs relationship (agents_runs association table)
|
||||
)
|
||||
|
||||
@@ -1171,9 +1171,10 @@ async def send_message(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
agent_id=agent_id,
|
||||
background=False,
|
||||
metadata={
|
||||
"job_type": "send_message",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
@@ -1305,10 +1306,10 @@ async def send_message_streaming(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
agent_id=agent_id,
|
||||
background=request.background or False,
|
||||
metadata={
|
||||
"job_type": "send_message_streaming",
|
||||
"agent_id": agent_id,
|
||||
"background": request.background or False,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
@@ -1482,6 +1483,7 @@ async def cancel_agent_run(
|
||||
statuses=[JobStatus.created, JobStatus.running],
|
||||
job_type=JobType.RUN,
|
||||
ascending=False,
|
||||
agent_ids=[agent_id],
|
||||
)
|
||||
run_ids = [Run.from_job(job).id for job in job_ids]
|
||||
else:
|
||||
@@ -1651,6 +1653,8 @@ async def send_message_async(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.created,
|
||||
callback_url=request.callback_url,
|
||||
agent_id=agent_id,
|
||||
background=True, # Async endpoints are always background
|
||||
metadata={
|
||||
"job_type": "send_message_async",
|
||||
"agent_id": agent_id,
|
||||
|
||||
@@ -30,7 +30,8 @@ router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
@router.get("/", response_model=List[Run], operation_id="list_runs")
|
||||
def list_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
agent_id: Optional[str] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="(DEPRECATED) The unique identifiers of the agents associated with the run."),
|
||||
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
|
||||
stop_reason: Optional[StopReasonType] = Query(None, description="Filter runs by stop reason."),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
@@ -62,19 +63,18 @@ def list_runs(
|
||||
after=after,
|
||||
ascending=False,
|
||||
stop_reason=stop_reason,
|
||||
# agent_id=agent_id,
|
||||
agent_ids=agent_ids if agent_ids else [agent_id],
|
||||
background=background,
|
||||
)
|
||||
]
|
||||
if agent_ids:
|
||||
runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
if background is not None:
|
||||
runs = [run for run in runs if "background" in run.metadata and run.metadata["background"] == background]
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Run], operation_id="list_active_runs", deprecated=True)
|
||||
def list_active_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
agent_id: Optional[str] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
@@ -83,15 +83,11 @@ def list_active_runs(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=headers.actor_id)
|
||||
|
||||
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
|
||||
active_runs = server.job_manager.list_jobs(
|
||||
actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN, agent_ids=[agent_id], background=background
|
||||
)
|
||||
active_runs = [Run.from_job(job) for job in active_runs]
|
||||
|
||||
if agent_ids:
|
||||
active_runs = [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
|
||||
if background is not None:
|
||||
active_runs = [run for run in active_runs if "background" in run.metadata and run.metadata["background"] == background]
|
||||
|
||||
return active_runs
|
||||
|
||||
|
||||
@@ -104,7 +100,7 @@ async def retrieve_run(
|
||||
"""
|
||||
Get the status of a run.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(user_id=headers.actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
try:
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
|
||||
@@ -317,7 +313,7 @@ async def retrieve_stream(
|
||||
|
||||
run = Run.from_job(job)
|
||||
|
||||
if "background" not in run.metadata or not run.metadata["background"]:
|
||||
if not run.background:
|
||||
raise HTTPException(status_code=400, detail="Run was not created in background mode, so it cannot be retrieved.")
|
||||
|
||||
if run.created_at < get_utc_time() - timedelta(hours=3):
|
||||
|
||||
@@ -39,13 +39,29 @@ class JobManager:
|
||||
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
|
||||
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
|
||||
"""Create a new job based on the JobCreate schema."""
|
||||
from letta.orm.agents_runs import AgentsRuns
|
||||
|
||||
with db_registry.session() as session:
|
||||
# Associate the job with the user
|
||||
pydantic_job.user_id = actor.id
|
||||
|
||||
# Get agent_id if present
|
||||
agent_id = getattr(pydantic_job, "agent_id", None)
|
||||
|
||||
job_data = pydantic_job.model_dump(to_orm=True)
|
||||
# Remove agent_id from job_data as it's not a field in the Job ORM model
|
||||
# The relationship is handled through the AgentsRuns association table
|
||||
job_data.pop("agent_id", None)
|
||||
job = JobModel(**job_data)
|
||||
job.organization_id = actor.organization_id
|
||||
job.create(session, actor=actor) # Save job in the database
|
||||
|
||||
# If this is a Run with an agent_id, create the agents_runs association
|
||||
if agent_id and isinstance(pydantic_job, PydanticRun):
|
||||
agents_run = AgentsRuns(agent_id=agent_id, run_id=job.id)
|
||||
session.add(agents_run)
|
||||
session.commit()
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -54,15 +70,37 @@ class JobManager:
|
||||
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
|
||||
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
|
||||
"""Create a new job based on the JobCreate schema."""
|
||||
from letta.orm.agents_runs import AgentsRuns
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Associate the job with the user
|
||||
pydantic_job.user_id = actor.id
|
||||
|
||||
# Get agent_id if present
|
||||
agent_id = getattr(pydantic_job, "agent_id", None)
|
||||
|
||||
job_data = pydantic_job.model_dump(to_orm=True)
|
||||
# Remove agent_id from job_data as it's not a field in the Job ORM model
|
||||
# The relationship is handled through the AgentsRuns association table
|
||||
job_data.pop("agent_id", None)
|
||||
job = JobModel(**job_data)
|
||||
job.organization_id = actor.organization_id
|
||||
job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database
|
||||
result = job.to_pydantic()
|
||||
|
||||
# If this is a Run with an agent_id, create the agents_runs association
|
||||
if agent_id and isinstance(pydantic_job, PydanticRun):
|
||||
agents_run = AgentsRuns(agent_id=agent_id, run_id=job.id)
|
||||
session.add(agents_run)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Convert to pydantic first, then add agent_id if needed
|
||||
result = super(JobModel, job).to_pydantic()
|
||||
|
||||
# Add back the agent_id field to the result if it was present
|
||||
if agent_id and isinstance(pydantic_job, PydanticRun):
|
||||
result.agent_id = agent_id
|
||||
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -275,8 +313,15 @@ class JobManager:
|
||||
job_type: JobType = JobType.JOB,
|
||||
ascending: bool = True,
|
||||
stop_reason: Optional[StopReasonType] = None,
|
||||
# agent_id: Optional[str] = None,
|
||||
agent_ids: Optional[List[str]] = None,
|
||||
background: Optional[bool] = None,
|
||||
) -> List[PydanticJob]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from letta.orm.agents_runs import AgentsRuns
|
||||
|
||||
with db_registry.session() as session:
|
||||
filter_kwargs = {"user_id": actor.id, "job_type": job_type}
|
||||
|
||||
@@ -288,14 +333,66 @@ class JobManager:
|
||||
if stop_reason is not None:
|
||||
filter_kwargs["stop_reason"] = stop_reason
|
||||
|
||||
jobs = JobModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=ascending,
|
||||
**filter_kwargs,
|
||||
)
|
||||
# Add background filter if provided
|
||||
if background is not None:
|
||||
filter_kwargs["background"] = background
|
||||
|
||||
# Build query
|
||||
query = select(JobModel)
|
||||
|
||||
# Apply basic filters
|
||||
for key, value in filter_kwargs.items():
|
||||
if isinstance(value, list):
|
||||
query = query.where(getattr(JobModel, key).in_(value))
|
||||
else:
|
||||
query = query.where(getattr(JobModel, key) == value)
|
||||
|
||||
# If agent_id filter is provided, join with agents_runs table
|
||||
if agent_ids:
|
||||
query = query.join(AgentsRuns, JobModel.id == AgentsRuns.run_id)
|
||||
query = query.where(AgentsRuns.agent_id.in_(agent_ids))
|
||||
|
||||
# Apply pagination and ordering
|
||||
if ascending:
|
||||
query = query.order_by(JobModel.created_at.asc(), JobModel.id.asc())
|
||||
else:
|
||||
query = query.order_by(JobModel.created_at.desc(), JobModel.id.desc())
|
||||
|
||||
# Apply cursor-based pagination
|
||||
if before:
|
||||
before_job = session.get(JobModel, before)
|
||||
if before_job:
|
||||
if ascending:
|
||||
query = query.where(
|
||||
(JobModel.created_at < before_job.created_at)
|
||||
| ((JobModel.created_at == before_job.created_at) & (JobModel.id < before_job.id))
|
||||
)
|
||||
else:
|
||||
query = query.where(
|
||||
(JobModel.created_at > before_job.created_at)
|
||||
| ((JobModel.created_at == before_job.created_at) & (JobModel.id > before_job.id))
|
||||
)
|
||||
|
||||
if after:
|
||||
after_job = session.get(JobModel, after)
|
||||
if after_job:
|
||||
if ascending:
|
||||
query = query.where(
|
||||
(JobModel.created_at > after_job.created_at)
|
||||
| ((JobModel.created_at == after_job.created_at) & (JobModel.id > after_job.id))
|
||||
)
|
||||
else:
|
||||
query = query.where(
|
||||
(JobModel.created_at < after_job.created_at)
|
||||
| ((JobModel.created_at == after_job.created_at) & (JobModel.id < after_job.id))
|
||||
)
|
||||
|
||||
# Apply limit
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute query
|
||||
jobs = session.execute(query).scalars().all()
|
||||
return [job.to_pydantic() for job in jobs]
|
||||
|
||||
@enforce_types
|
||||
@@ -311,10 +408,15 @@ class JobManager:
|
||||
ascending: bool = True,
|
||||
source_id: Optional[str] = None,
|
||||
stop_reason: Optional[StopReasonType] = None,
|
||||
# agent_id: Optional[str] = None,
|
||||
agent_ids: Optional[List[str]] = None,
|
||||
background: Optional[bool] = None,
|
||||
) -> List[PydanticJob]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
from sqlalchemy import and_, or_, select
|
||||
|
||||
from letta.orm.agents_runs import AgentsRuns
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# build base query
|
||||
query = select(JobModel).where(JobModel.user_id == actor.id).where(JobModel.job_type == job_type)
|
||||
@@ -323,15 +425,24 @@ class JobManager:
|
||||
if statuses:
|
||||
query = query.where(JobModel.status.in_(statuses))
|
||||
|
||||
# add stop_reason filter if provided
|
||||
if stop_reason is not None:
|
||||
query = query.where(JobModel.stop_reason == stop_reason)
|
||||
|
||||
# add background filter if provided
|
||||
if background is not None:
|
||||
query = query.where(JobModel.background == background)
|
||||
|
||||
# add source_id filter if provided
|
||||
if source_id:
|
||||
column = getattr(JobModel, "metadata_")
|
||||
column = column.op("->>")("source_id")
|
||||
query = query.where(column == source_id)
|
||||
|
||||
# add stop_reason filter if provided
|
||||
if stop_reason is not None:
|
||||
query = query.where(JobModel.stop_reason == stop_reason)
|
||||
# If agent_id filter is provided, join with agents_runs table
|
||||
if agent_ids:
|
||||
query = query.join(AgentsRuns, JobModel.id == AgentsRuns.run_id)
|
||||
query = query.where(AgentsRuns.agent_id.in_(agent_ids))
|
||||
|
||||
# handle cursor-based pagination
|
||||
if before or after:
|
||||
|
||||
@@ -9020,15 +9020,29 @@ async def test_list_jobs_by_stop_reason(server: SyncServer, sarah_agent, default
|
||||
status=JobStatus.pending,
|
||||
job_type=JobType.RUN,
|
||||
stop_reason=StopReasonType.requires_approval,
|
||||
agent_id=sarah_agent.id,
|
||||
background=True,
|
||||
)
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user)
|
||||
assert run.stop_reason == StopReasonType.requires_approval
|
||||
assert run.background == True
|
||||
assert run.agent_id == sarah_agent.id
|
||||
|
||||
# list jobs by stop reason
|
||||
jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, stop_reason=StopReasonType.requires_approval)
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].id == run.id
|
||||
|
||||
# list jobs by background
|
||||
jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, background=True)
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].id == run.id
|
||||
|
||||
# list jobs by agent_id
|
||||
jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, agent_ids=[sarah_agent.id])
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].id == run.id
|
||||
|
||||
|
||||
async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
|
||||
"""Test that job callbacks are properly dispatched when a job is completed."""
|
||||
|
||||
Reference in New Issue
Block a user