fix: batch jobs not being polled (#3034)
This commit is contained in:
@@ -492,30 +492,32 @@ class LettaAgentBatch(BaseAgent):
|
||||
msg_map: Dict[str, List[Message]],
|
||||
) -> Tuple[List[LettaBatchRequest], Dict[str, AgentStepState]]:
|
||||
# who continues?
|
||||
continues = [aid for aid, cont in ctx.should_continue_map.items() if cont]
|
||||
continues = [agent_id for agent_id, cont in ctx.should_continue_map.items() if cont]
|
||||
|
||||
success_flag_map = {aid: result.success_flag for aid, result in exec_results}
|
||||
|
||||
batch_reqs: List[LettaBatchRequest] = []
|
||||
for aid in continues:
|
||||
for agent_id in continues:
|
||||
heartbeat = create_heartbeat_system_message(
|
||||
agent_id=aid,
|
||||
model=ctx.agent_state_map[aid].llm_config.model,
|
||||
function_call_success=success_flag_map[aid],
|
||||
agent_id=agent_id,
|
||||
model=ctx.agent_state_map[agent_id].llm_config.model,
|
||||
function_call_success=success_flag_map[agent_id],
|
||||
timezone=ctx.agent_state_map[agent_id].timezone,
|
||||
actor=self.actor,
|
||||
)
|
||||
batch_reqs.append(
|
||||
LettaBatchRequest(
|
||||
agent_id=aid, messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))]
|
||||
agent_id=agent_id,
|
||||
messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))],
|
||||
)
|
||||
)
|
||||
|
||||
# extend in‑context ids when necessary
|
||||
for aid, new_msgs in msg_map.items():
|
||||
ast = ctx.agent_state_map[aid]
|
||||
for agent_id, new_msgs in msg_map.items():
|
||||
ast = ctx.agent_state_map[agent_id]
|
||||
if not ast.message_buffer_autoclear:
|
||||
await self.agent_manager.set_in_context_messages_async(
|
||||
agent_id=aid,
|
||||
agent_id=agent_id,
|
||||
message_ids=ast.message_ids + [m.id for m in new_msgs],
|
||||
actor=self.actor,
|
||||
)
|
||||
@@ -605,7 +607,8 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
return tool_call_name, tool_args, continue_stepping
|
||||
|
||||
def _prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
|
||||
@staticmethod
|
||||
def _prepare_tools_per_agent(agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
|
||||
tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}]
|
||||
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
|
||||
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
@@ -621,7 +624,9 @@ class LettaAgentBatch(BaseAgent):
|
||||
return in_context_messages
|
||||
|
||||
# Not used in batch.
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse:
|
||||
async def step(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: str | None = None
|
||||
) -> LettaResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def step_stream(
|
||||
|
||||
@@ -36,25 +36,44 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
# Use a temporary connection context for the attempt initially
|
||||
with db_context() as session:
|
||||
engine = session.get_bind()
|
||||
# Get raw connection - MUST be kept open if lock is acquired
|
||||
raw_conn = engine.raw_connection()
|
||||
cur = raw_conn.cursor()
|
||||
engine_name = engine.name
|
||||
logger.info(f"Database engine type: {engine_name}")
|
||||
|
||||
cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
|
||||
acquired_lock = cur.fetchone()[0]
|
||||
if engine_name != "postgresql":
|
||||
logger.warning(f"Advisory locks not supported for {engine_name} database. Starting scheduler without leader election.")
|
||||
acquired_lock = True # For SQLite, assume we can start the scheduler
|
||||
else:
|
||||
# Get raw connection - MUST be kept open if lock is acquired
|
||||
raw_conn = engine.raw_connection()
|
||||
cur = raw_conn.cursor()
|
||||
|
||||
cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
|
||||
acquired_lock = cur.fetchone()[0]
|
||||
|
||||
if not acquired_lock:
|
||||
cur.close()
|
||||
raw_conn.close()
|
||||
if cur:
|
||||
cur.close()
|
||||
if raw_conn:
|
||||
raw_conn.close()
|
||||
logger.info("Scheduler lock held by another instance.")
|
||||
return False
|
||||
|
||||
# --- Lock Acquired ---
|
||||
logger.info("Acquired scheduler lock.")
|
||||
_advisory_lock_conn = raw_conn # Keep connection for lock duration
|
||||
_advisory_lock_cur = cur # Keep cursor for lock duration
|
||||
raw_conn = None # Prevent closing in finally block
|
||||
cur = None # Prevent closing in finally block
|
||||
if engine_name == "postgresql":
|
||||
logger.info("Acquired PostgreSQL advisory lock.")
|
||||
_advisory_lock_conn = raw_conn # Keep connection for lock duration
|
||||
_advisory_lock_cur = cur # Keep cursor for lock duration
|
||||
raw_conn = None # Prevent closing in finally block
|
||||
cur = None # Prevent closing in finally block
|
||||
else:
|
||||
logger.info("Starting scheduler for non-PostgreSQL database.")
|
||||
# For SQLite, we don't need to keep the connection open
|
||||
if cur:
|
||||
cur.close()
|
||||
if raw_conn:
|
||||
raw_conn.close()
|
||||
raw_conn = None
|
||||
cur = None
|
||||
|
||||
trigger = IntervalTrigger(
|
||||
seconds=settings.poll_running_llm_batches_interval_seconds,
|
||||
@@ -157,35 +176,30 @@ async def _release_advisory_lock():
|
||||
_advisory_lock_conn = None # Clear global immediately
|
||||
|
||||
if lock_cur is not None and lock_conn is not None:
|
||||
logger.info(f"Attempting to release advisory lock {ADVISORY_LOCK_KEY}")
|
||||
logger.info(f"Attempting to release PostgreSQL advisory lock {ADVISORY_LOCK_KEY}")
|
||||
try:
|
||||
if not lock_conn.closed:
|
||||
if not lock_cur.closed:
|
||||
lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
|
||||
lock_cur.fetchone() # Consume result
|
||||
lock_conn.commit()
|
||||
logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}")
|
||||
else:
|
||||
logger.warning("Advisory lock cursor closed before unlock.")
|
||||
else:
|
||||
logger.warning("Advisory lock connection closed before unlock.")
|
||||
# Try to execute unlock - connection/cursor validity is checked by attempting the operation
|
||||
lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
|
||||
lock_cur.fetchone() # Consume result
|
||||
lock_conn.commit()
|
||||
logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing pg_advisory_unlock: {e}", exc_info=True)
|
||||
finally:
|
||||
# Ensure resources are closed regardless of unlock success
|
||||
try:
|
||||
if lock_cur and not lock_cur.closed:
|
||||
if lock_cur:
|
||||
lock_cur.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing advisory lock cursor: {e}", exc_info=True)
|
||||
try:
|
||||
if lock_conn and not lock_conn.closed:
|
||||
if lock_conn:
|
||||
lock_conn.close()
|
||||
logger.info("Closed database connection that held advisory lock.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing advisory lock connection: {e}", exc_info=True)
|
||||
else:
|
||||
logger.warning("Attempted to release lock, but connection/cursor not found.")
|
||||
logger.info("No PostgreSQL advisory lock to release (likely using SQLite or non-PostgreSQL database).")
|
||||
|
||||
|
||||
async def start_scheduler_with_leader_election(server: SyncServer):
|
||||
@@ -236,10 +250,18 @@ async def shutdown_scheduler_and_release_lock():
|
||||
logger.info("Shutting down: Leader instance stopping scheduler and releasing lock.")
|
||||
if scheduler.running:
|
||||
try:
|
||||
scheduler.shutdown() # wait=True by default
|
||||
# Force synchronous shutdown to prevent callback scheduling
|
||||
scheduler.shutdown(wait=True)
|
||||
|
||||
# wait for any internal cleanup to complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info("APScheduler shut down.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down APScheduler: {e}", exc_info=True)
|
||||
# Handle SchedulerNotRunningError and other shutdown exceptions
|
||||
logger.warning(f"Exception during APScheduler shutdown: {e}")
|
||||
if "not running" not in str(e).lower():
|
||||
logger.error(f"Unexpected error shutting down APScheduler: {e}", exc_info=True)
|
||||
|
||||
await _release_advisory_lock()
|
||||
_is_scheduler_leader = False # Update state after cleanup
|
||||
@@ -247,9 +269,11 @@ async def shutdown_scheduler_and_release_lock():
|
||||
logger.info("Shutting down: Non-leader instance.")
|
||||
|
||||
# Final cleanup check for scheduler state (belt and suspenders)
|
||||
if scheduler.running:
|
||||
logger.warning("Scheduler still running after shutdown logic completed? Forcing shutdown.")
|
||||
try:
|
||||
# This should rarely be needed if shutdown logic above worked correctly
|
||||
try:
|
||||
if scheduler.running:
|
||||
logger.warning("Scheduler still running after shutdown logic completed? Forcing shutdown.")
|
||||
scheduler.shutdown(wait=False)
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
# Catch SchedulerNotRunningError and other shutdown exceptions
|
||||
logger.debug(f"Expected exception during final scheduler cleanup: {e}")
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -12,10 +13,11 @@ from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.__init__ import __version__
|
||||
from letta.__init__ import __version__ as letta_version
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.jobs.scheduler import start_scheduler_with_leader_election
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import create_letta_message_union_schema
|
||||
@@ -25,6 +27,7 @@ from letta.schemas.letta_message_content import (
|
||||
create_letta_user_message_content_union_schema,
|
||||
)
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
from letta.server.db import db_registry
|
||||
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||
@@ -94,9 +97,7 @@ random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password()
|
||||
|
||||
|
||||
class CheckPasswordMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
|
||||
# Exclude health check endpoint from password protection
|
||||
if request.url.path in {"/v1/health", "/v1/health/", "/latest/health/"}:
|
||||
return await call_next(request)
|
||||
@@ -113,11 +114,46 @@ class CheckPasswordMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app_: FastAPI):
|
||||
"""
|
||||
FastAPI lifespan context manager with setup before the app starts pre-yield and on shutdown after the yield.
|
||||
"""
|
||||
worker_id = os.getpid()
|
||||
|
||||
logger.info(f"[Worker {worker_id}] Starting lifespan initialization")
|
||||
logger.info(f"[Worker {worker_id}] Initializing database connections")
|
||||
db_registry.initialize_sync()
|
||||
db_registry.initialize_async()
|
||||
logger.info(f"[Worker {worker_id}] Database connections initialized")
|
||||
|
||||
logger.info(f"[Worker {worker_id}] Starting scheduler with leader election")
|
||||
global server
|
||||
try:
|
||||
await start_scheduler_with_leader_election(server)
|
||||
logger.info(f"[Worker {worker_id}] Scheduler initialization completed")
|
||||
except Exception as e:
|
||||
logger.error(f"[Worker {worker_id}] Scheduler initialization failed: {e}", exc_info=True)
|
||||
logger.info(f"[Worker {worker_id}] Lifespan startup completed")
|
||||
yield
|
||||
|
||||
# Cleanup on shutdown
|
||||
logger.info(f"[Worker {worker_id}] Starting lifespan shutdown")
|
||||
try:
|
||||
from letta.jobs.scheduler import shutdown_scheduler_and_release_lock
|
||||
|
||||
await shutdown_scheduler_and_release_lock()
|
||||
logger.info(f"[Worker {worker_id}] Scheduler shutdown completed")
|
||||
except Exception as e:
|
||||
logger.error(f"[Worker {worker_id}] Scheduler shutdown failed: {e}", exc_info=True)
|
||||
logger.info(f"[Worker {worker_id}] Lifespan shutdown completed")
|
||||
|
||||
|
||||
def create_application() -> "FastAPI":
|
||||
"""the application start routine"""
|
||||
# global server
|
||||
# server = SyncServer(default_interface_factory=lambda: interface())
|
||||
print(f"\n[[ Letta server // v{__version__} ]]")
|
||||
print(f"\n[[ Letta server // v{letta_version} ]]")
|
||||
|
||||
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
||||
import sentry_sdk
|
||||
@@ -136,8 +172,9 @@ def create_application() -> "FastAPI":
|
||||
# openapi_tags=TAGS_METADATA,
|
||||
title="Letta",
|
||||
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
|
||||
version="1.0.0", # TODO wire this up to the version in the package
|
||||
version=letta_version,
|
||||
debug=debug_mode, # if True, the stack trace will be printed in the response
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
|
||||
Reference in New Issue
Block a user