231 lines
8.8 KiB
Python
231 lines
8.8 KiB
Python
import datetime
|
||
from typing import List, Optional
|
||
|
||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||
from sqlalchemy import tuple_
|
||
|
||
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo
|
||
from letta.log import get_logger
|
||
from letta.orm.llm_batch_items import LLMBatchItem
|
||
from letta.orm.llm_batch_job import LLMBatchJob
|
||
from letta.schemas.agent import AgentStepState
|
||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
||
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
|
||
from letta.schemas.llm_config import LLMConfig
|
||
from letta.schemas.user import User as PydanticUser
|
||
from letta.utils import enforce_types
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class LLMBatchManager:
    """Manager for handling both LLMBatchJob and LLMBatchItem operations."""

    def __init__(self):
        # Imported lazily to avoid a circular import between the server and managers.
        from letta.server.db import db_context

        self.session_maker = db_context

    @enforce_types
    def create_batch_job(
        self,
        llm_provider: ProviderType,
        create_batch_response: BetaMessageBatch,
        actor: PydanticUser,
        status: JobStatus = JobStatus.created,
    ) -> PydanticLLMBatchJob:
        """Create a new LLM batch job.

        Args:
            llm_provider: Provider the batch was submitted to.
            create_batch_response: Provider response returned at batch creation.
            actor: User performing the operation; the row is scoped to their organization.
            status: Initial job status (defaults to ``JobStatus.created``).

        Returns:
            The persisted batch job as a Pydantic model.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob(
                status=status,
                llm_provider=llm_provider,
                create_batch_response=create_batch_response,
                organization_id=actor.organization_id,
            )
            batch.create(session, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def get_batch_job_by_id(self, batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
        """Retrieve a single batch job by ID.

        Args:
            batch_id: Primary key of the batch job.
            actor: Optional user used for access scoping during the read.

        Returns:
            The batch job as a Pydantic model.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def update_batch_status(
        self,
        batch_id: str,
        status: JobStatus,
        actor: Optional[PydanticUser] = None,
        latest_polling_response: Optional[BetaMessageBatch] = None,
    ) -> PydanticLLMBatchJob:
        """Update a batch job’s status and optionally its polling response.

        Note:
            ``latest_polling_response`` is assigned unconditionally, so passing
            ``None`` clears any previously stored polling response.
            ``last_polled_at`` is always refreshed to the current UTC time.

        Returns:
            The updated batch job as a Pydantic model.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.status = status
            batch.latest_polling_response = latest_polling_response
            batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
            batch = batch.update(db_session=session, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def bulk_update_batch_statuses(
        self,
        updates: List[BatchPollingResult],
    ) -> None:
        """
        Efficiently update many LLMBatchJob rows. This is used by the cron jobs.

        `updates` = [(batch_id, new_status, polling_response_or_None), …]
        """
        # One timestamp for the whole batch so all rows agree on poll time.
        now = datetime.datetime.now(datetime.timezone.utc)

        with self.session_maker() as session:
            mappings = [
                {
                    "id": batch_id,
                    "status": status,
                    "latest_polling_response": response,
                    "last_polled_at": now,
                }
                for batch_id, status, response in updates
            ]
            # PK-keyed bulk update: one UPDATE per row, no ORM object loading.
            session.bulk_update_mappings(LLMBatchJob, mappings)
            session.commit()

    @enforce_types
    def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch job by ID."""
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.hard_delete(db_session=session, actor=actor)

    @enforce_types
    def list_running_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
        """Return all running LLM batch jobs, optionally filtered by actor's organization."""
        with self.session_maker() as session:
            query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running)

            if actor is not None:
                query = query.filter(LLMBatchJob.organization_id == actor.organization_id)

            results = query.all()
            return [batch.to_pydantic() for batch in results]

    @enforce_types
    def create_batch_item(
        self,
        batch_id: str,
        agent_id: str,
        llm_config: LLMConfig,
        actor: PydanticUser,
        request_status: JobStatus = JobStatus.created,
        step_status: AgentStepStatus = AgentStepStatus.paused,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Create a new batch item.

        Args:
            batch_id: Parent batch job this item belongs to.
            agent_id: Agent whose request this item tracks.
            llm_config: LLM configuration used for the request.
            actor: User performing the operation; scopes the row to their organization.
            request_status: Initial request status (defaults to ``created``).
            step_status: Initial agent step status (defaults to ``paused``).
            step_state: Optional serialized agent step state to resume from.

        Returns:
            The persisted batch item as a Pydantic model.
        """
        with self.session_maker() as session:
            item = LLMBatchItem(
                batch_id=batch_id,
                agent_id=agent_id,
                llm_config=llm_config,
                request_status=request_status,
                step_status=step_status,
                step_state=step_state,
                organization_id=actor.organization_id,
            )
            item.create(session, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
        """Retrieve a single batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def update_batch_item(
        self,
        item_id: str,
        actor: PydanticUser,
        request_status: Optional[JobStatus] = None,
        step_status: Optional[AgentStepStatus] = None,
        llm_request_response: Optional[BetaMessageBatchIndividualResponse] = None,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Update fields on a batch item.

        Only the fields passed as non-None are modified; all others are left
        untouched.

        Returns:
            The updated batch item as a Pydantic model.
        """
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)

            if request_status:
                item.request_status = request_status
            if step_status:
                item.step_status = step_status
            if llm_request_response:
                item.batch_request_result = llm_request_response
            if step_state:
                item.step_state = step_state

            return item.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def bulk_update_batch_items_by_agent(
        self,
        updates: List[ItemUpdateInfo],
    ) -> None:
        """
        Efficiently update LLMBatchItem rows by (batch_id, agent_id).

        Args:
            updates: List of tuples:
                (batch_id, agent_id, new_request_status, batch_request_result)
        """
        with self.session_maker() as session:
            # bulk_update_mappings requires each row's primary key, so first
            # resolve (batch_id, agent_id) -> id in a single query using a
            # composite tuple IN-clause.
            key_pairs = [(batch_id, agent_id) for (batch_id, agent_id, *_rest) in updates]

            rows = (
                session.query(LLMBatchItem.id, LLMBatchItem.batch_id, LLMBatchItem.agent_id)
                .filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(key_pairs))
                .all()
            )
            pair_to_pk = {(row_batch_id, row_agent_id): row_id for row_id, row_batch_id, row_agent_id in rows}

            # Build the PK-based mappings, skipping pairs with no matching row
            # (stale or mismatched updates) rather than failing the whole batch.
            mappings = []
            for batch_id, agent_id, new_status, new_result in updates:
                pk_id = pair_to_pk.get((batch_id, agent_id))
                if pk_id is None:
                    continue
                mappings.append(
                    {
                        "id": pk_id,
                        "request_status": new_status,
                        "batch_request_result": new_result,
                    }
                )

            if mappings:
                session.bulk_update_mappings(LLMBatchItem, mappings)
                session.commit()

    @enforce_types
    def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            item.hard_delete(db_session=session, actor=actor)