fix: batch jobs not being polled (#3034)

This commit is contained in:
Andy Li
2025-06-25 17:07:21 -07:00
committed by GitHub
parent be199b15a4
commit 2c9647424d
3 changed files with 115 additions and 49 deletions

View File

@@ -492,30 +492,32 @@ class LettaAgentBatch(BaseAgent):
msg_map: Dict[str, List[Message]],
) -> Tuple[List[LettaBatchRequest], Dict[str, AgentStepState]]:
# who continues?
continues = [aid for aid, cont in ctx.should_continue_map.items() if cont]
continues = [agent_id for agent_id, cont in ctx.should_continue_map.items() if cont]
success_flag_map = {aid: result.success_flag for aid, result in exec_results}
batch_reqs: List[LettaBatchRequest] = []
for aid in continues:
for agent_id in continues:
heartbeat = create_heartbeat_system_message(
agent_id=aid,
model=ctx.agent_state_map[aid].llm_config.model,
function_call_success=success_flag_map[aid],
agent_id=agent_id,
model=ctx.agent_state_map[agent_id].llm_config.model,
function_call_success=success_flag_map[agent_id],
timezone=ctx.agent_state_map[agent_id].timezone,
actor=self.actor,
)
batch_reqs.append(
LettaBatchRequest(
agent_id=aid, messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))]
agent_id=agent_id,
messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))],
)
)
# extend incontext ids when necessary
for aid, new_msgs in msg_map.items():
ast = ctx.agent_state_map[aid]
for agent_id, new_msgs in msg_map.items():
ast = ctx.agent_state_map[agent_id]
if not ast.message_buffer_autoclear:
await self.agent_manager.set_in_context_messages_async(
agent_id=aid,
agent_id=agent_id,
message_ids=ast.message_ids + [m.id for m in new_msgs],
actor=self.actor,
)
@@ -605,7 +607,8 @@ class LettaAgentBatch(BaseAgent):
return tool_call_name, tool_args, continue_stepping
def _prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
@staticmethod
def _prepare_tools_per_agent(agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}]
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
@@ -621,7 +624,9 @@ class LettaAgentBatch(BaseAgent):
return in_context_messages
# Not used in batch.
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
async def step(
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: str | None = None
) -> LettaResponse:
raise NotImplementedError
async def step_stream(

View File

@@ -36,25 +36,44 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
# Use a temporary connection context for the attempt initially
with db_context() as session:
engine = session.get_bind()
# Get raw connection - MUST be kept open if lock is acquired
raw_conn = engine.raw_connection()
cur = raw_conn.cursor()
engine_name = engine.name
logger.info(f"Database engine type: {engine_name}")
cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
acquired_lock = cur.fetchone()[0]
if engine_name != "postgresql":
logger.warning(f"Advisory locks not supported for {engine_name} database. Starting scheduler without leader election.")
acquired_lock = True # For SQLite, assume we can start the scheduler
else:
# Get raw connection - MUST be kept open if lock is acquired
raw_conn = engine.raw_connection()
cur = raw_conn.cursor()
cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
acquired_lock = cur.fetchone()[0]
if not acquired_lock:
cur.close()
raw_conn.close()
if cur:
cur.close()
if raw_conn:
raw_conn.close()
logger.info("Scheduler lock held by another instance.")
return False
# --- Lock Acquired ---
logger.info("Acquired scheduler lock.")
_advisory_lock_conn = raw_conn # Keep connection for lock duration
_advisory_lock_cur = cur # Keep cursor for lock duration
raw_conn = None # Prevent closing in finally block
cur = None # Prevent closing in finally block
if engine_name == "postgresql":
logger.info("Acquired PostgreSQL advisory lock.")
_advisory_lock_conn = raw_conn # Keep connection for lock duration
_advisory_lock_cur = cur # Keep cursor for lock duration
raw_conn = None # Prevent closing in finally block
cur = None # Prevent closing in finally block
else:
logger.info("Starting scheduler for non-PostgreSQL database.")
# For SQLite, we don't need to keep the connection open
if cur:
cur.close()
if raw_conn:
raw_conn.close()
raw_conn = None
cur = None
trigger = IntervalTrigger(
seconds=settings.poll_running_llm_batches_interval_seconds,
@@ -157,35 +176,30 @@ async def _release_advisory_lock():
_advisory_lock_conn = None # Clear global immediately
if lock_cur is not None and lock_conn is not None:
logger.info(f"Attempting to release advisory lock {ADVISORY_LOCK_KEY}")
logger.info(f"Attempting to release PostgreSQL advisory lock {ADVISORY_LOCK_KEY}")
try:
if not lock_conn.closed:
if not lock_cur.closed:
lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
lock_cur.fetchone() # Consume result
lock_conn.commit()
logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}")
else:
logger.warning("Advisory lock cursor closed before unlock.")
else:
logger.warning("Advisory lock connection closed before unlock.")
# Try to execute unlock - connection/cursor validity is checked by attempting the operation
lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
lock_cur.fetchone() # Consume result
lock_conn.commit()
logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}")
except Exception as e:
logger.error(f"Error executing pg_advisory_unlock: {e}", exc_info=True)
finally:
# Ensure resources are closed regardless of unlock success
try:
if lock_cur and not lock_cur.closed:
if lock_cur:
lock_cur.close()
except Exception as e:
logger.error(f"Error closing advisory lock cursor: {e}", exc_info=True)
try:
if lock_conn and not lock_conn.closed:
if lock_conn:
lock_conn.close()
logger.info("Closed database connection that held advisory lock.")
except Exception as e:
logger.error(f"Error closing advisory lock connection: {e}", exc_info=True)
else:
logger.warning("Attempted to release lock, but connection/cursor not found.")
logger.info("No PostgreSQL advisory lock to release (likely using SQLite or non-PostgreSQL database).")
async def start_scheduler_with_leader_election(server: SyncServer):
@@ -236,10 +250,18 @@ async def shutdown_scheduler_and_release_lock():
logger.info("Shutting down: Leader instance stopping scheduler and releasing lock.")
if scheduler.running:
try:
scheduler.shutdown() # wait=True by default
# Force synchronous shutdown to prevent callback scheduling
scheduler.shutdown(wait=True)
# wait for any internal cleanup to complete
await asyncio.sleep(0.1)
logger.info("APScheduler shut down.")
except Exception as e:
logger.error(f"Error shutting down APScheduler: {e}", exc_info=True)
# Handle SchedulerNotRunningError and other shutdown exceptions
logger.warning(f"Exception during APScheduler shutdown: {e}")
if "not running" not in str(e).lower():
logger.error(f"Unexpected error shutting down APScheduler: {e}", exc_info=True)
await _release_advisory_lock()
_is_scheduler_leader = False # Update state after cleanup
@@ -247,9 +269,11 @@ async def shutdown_scheduler_and_release_lock():
logger.info("Shutting down: Non-leader instance.")
# Final cleanup check for scheduler state (belt and suspenders)
if scheduler.running:
logger.warning("Scheduler still running after shutdown logic completed? Forcing shutdown.")
try:
# This should rarely be needed if shutdown logic above worked correctly
try:
if scheduler.running:
logger.warning("Scheduler still running after shutdown logic completed? Forcing shutdown.")
scheduler.shutdown(wait=False)
except:
pass
except Exception as e:
# Catch SchedulerNotRunningError and other shutdown exceptions
logger.debug(f"Expected exception during final scheduler cleanup: {e}")

View File

@@ -3,6 +3,7 @@ import json
import logging
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
@@ -12,10 +13,11 @@ from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.__init__ import __version__ as letta_version
from letta.agents.exceptions import IncompatibleAgentType
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.jobs.scheduler import start_scheduler_with_leader_election
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
from letta.schemas.letta_message import create_letta_message_union_schema
@@ -25,6 +27,7 @@ from letta.schemas.letta_message_content import (
create_letta_user_message_content_union_schema,
)
from letta.server.constants import REST_DEFAULT_PORT
from letta.server.db import db_registry
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
@@ -94,9 +97,7 @@ random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password()
class CheckPasswordMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
# Exclude health check endpoint from password protection
if request.url.path in {"/v1/health", "/v1/health/", "/latest/health/"}:
return await call_next(request)
@@ -113,11 +114,46 @@ class CheckPasswordMiddleware(BaseHTTPMiddleware):
)
@asynccontextmanager
async def lifespan(app_: FastAPI):
"""
FastAPI lifespan context manager with setup before the app starts pre-yield and on shutdown after the yield.
"""
worker_id = os.getpid()
logger.info(f"[Worker {worker_id}] Starting lifespan initialization")
logger.info(f"[Worker {worker_id}] Initializing database connections")
db_registry.initialize_sync()
db_registry.initialize_async()
logger.info(f"[Worker {worker_id}] Database connections initialized")
logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
global server
try:
await start_scheduler_with_leader_election(server)
logger.info(f"[Worker {worker_id}] Scheduler initialization completed")
except Exception as e:
logger.error(f"[Worker {worker_id}] Scheduler initialization failed: {e}", exc_info=True)
logger.info(f"[Worker {worker_id}] Lifespan startup completed")
yield
# Cleanup on shutdown
logger.info(f"[Worker {worker_id}] Starting lifespan shutdown")
try:
from letta.jobs.scheduler import shutdown_scheduler_and_release_lock
await shutdown_scheduler_and_release_lock()
logger.info(f"[Worker {worker_id}] Scheduler shutdown completed")
except Exception as e:
logger.error(f"[Worker {worker_id}] Scheduler shutdown failed: {e}", exc_info=True)
logger.info(f"[Worker {worker_id}] Lifespan shutdown completed")
def create_application() -> "FastAPI":
"""the application start routine"""
# global server
# server = SyncServer(default_interface_factory=lambda: interface())
print(f"\n[[ Letta server // v{__version__} ]]")
print(f"\n[[ Letta server // v{letta_version} ]]")
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk
@@ -136,8 +172,9 @@ def create_application() -> "FastAPI":
# openapi_tags=TAGS_METADATA,
title="Letta",
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
version="1.0.0", # TODO wire this up to the version in the package
version=letta_version,
debug=debug_mode, # if True, the stack trace will be printed in the response
lifespan=lifespan,
)
@app.exception_handler(IncompatibleAgentType)