From 2c9647424d3262f6b90c840860c31cd1223b98d0 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Wed, 25 Jun 2025 17:07:21 -0700 Subject: [PATCH] fix: batch jobs not being polled (#3034) --- letta/agents/letta_agent_batch.py | 27 ++++++---- letta/jobs/scheduler.py | 90 +++++++++++++++++++------------ letta/server/rest_api/app.py | 47 ++++++++++++++-- 3 files changed, 115 insertions(+), 49 deletions(-) diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index da94511f..456633d4 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -492,30 +492,32 @@ class LettaAgentBatch(BaseAgent): msg_map: Dict[str, List[Message]], ) -> Tuple[List[LettaBatchRequest], Dict[str, AgentStepState]]: # who continues? - continues = [aid for aid, cont in ctx.should_continue_map.items() if cont] + continues = [agent_id for agent_id, cont in ctx.should_continue_map.items() if cont] success_flag_map = {aid: result.success_flag for aid, result in exec_results} batch_reqs: List[LettaBatchRequest] = [] - for aid in continues: + for agent_id in continues: heartbeat = create_heartbeat_system_message( - agent_id=aid, - model=ctx.agent_state_map[aid].llm_config.model, - function_call_success=success_flag_map[aid], + agent_id=agent_id, + model=ctx.agent_state_map[agent_id].llm_config.model, + function_call_success=success_flag_map[agent_id], + timezone=ctx.agent_state_map[agent_id].timezone, actor=self.actor, ) batch_reqs.append( LettaBatchRequest( - agent_id=aid, messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))] + agent_id=agent_id, + messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))], ) ) # extend in‑context ids when necessary - for aid, new_msgs in msg_map.items(): - ast = ctx.agent_state_map[aid] + for agent_id, new_msgs in msg_map.items(): + ast = ctx.agent_state_map[agent_id] if not ast.message_buffer_autoclear: await self.agent_manager.set_in_context_messages_async( - agent_id=aid, + agent_id=agent_id, message_ids=ast.message_ids + [m.id for m in new_msgs], actor=self.actor, ) @@ -605,7 +607,8 @@ class LettaAgentBatch(BaseAgent): return tool_call_name, tool_args, continue_stepping - def _prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]: + @staticmethod + def _prepare_tools_per_agent(agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]: tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}] valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] @@ -621,7 +624,9 @@ class LettaAgentBatch(BaseAgent): return in_context_messages # Not used in batch. - async def step(self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS) -> LettaResponse: + async def step( + self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: str | None = None + ) -> LettaResponse: raise NotImplementedError async def step_stream( diff --git a/letta/jobs/scheduler.py b/letta/jobs/scheduler.py index 6e7dad00..b453aea9 100644 --- a/letta/jobs/scheduler.py +++ b/letta/jobs/scheduler.py @@ -36,25 +36,44 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: # Use a temporary connection context for the attempt initially with db_context() as session: engine = session.get_bind() - # Get raw connection - MUST be kept open if lock is acquired - raw_conn = engine.raw_connection() - cur = raw_conn.cursor() + engine_name = engine.name + logger.info(f"Database engine type: {engine_name}") - cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) - acquired_lock = cur.fetchone()[0] + if engine_name != "postgresql": + logger.warning(f"Advisory locks not supported for {engine_name} database. Starting scheduler without leader election.") + acquired_lock = True # For SQLite, assume we can start the scheduler + else: + # Get raw connection - MUST be kept open if lock is acquired + raw_conn = engine.raw_connection() + cur = raw_conn.cursor() + + cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) + acquired_lock = cur.fetchone()[0] if not acquired_lock: - cur.close() - raw_conn.close() + if cur: + cur.close() + if raw_conn: + raw_conn.close() logger.info("Scheduler lock held by another instance.") return False # --- Lock Acquired --- - logger.info("Acquired scheduler lock.") - _advisory_lock_conn = raw_conn # Keep connection for lock duration - _advisory_lock_cur = cur # Keep cursor for lock duration - raw_conn = None # Prevent closing in finally block - cur = None # Prevent closing in finally block + if engine_name == "postgresql": + logger.info("Acquired PostgreSQL advisory lock.") + _advisory_lock_conn = raw_conn # Keep connection for lock duration + _advisory_lock_cur = cur # Keep cursor for lock duration + raw_conn = None # Prevent closing in finally block + cur = None # Prevent closing in finally block + else: + logger.info("Starting scheduler for non-PostgreSQL database.") + # For SQLite, we don't need to keep the connection open + if cur: + cur.close() + if raw_conn: + raw_conn.close() + raw_conn = None + cur = None trigger = IntervalTrigger( seconds=settings.poll_running_llm_batches_interval_seconds, @@ -157,35 +176,30 @@ async def _release_advisory_lock(): _advisory_lock_conn = None # Clear global immediately if lock_cur is not None and lock_conn is not None: - logger.info(f"Attempting to release advisory lock {ADVISORY_LOCK_KEY}") + logger.info(f"Attempting to release PostgreSQL advisory lock {ADVISORY_LOCK_KEY}") try: - if not lock_conn.closed: - if not lock_cur.closed: - lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) - lock_cur.fetchone() # Consume result - lock_conn.commit() - logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}") - else: - logger.warning("Advisory lock cursor closed before unlock.") - else: - logger.warning("Advisory lock connection closed before unlock.") + # Try to execute unlock - connection/cursor validity is checked by attempting the operation + lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) + lock_cur.fetchone() # Consume result + lock_conn.commit() + logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}") except Exception as e: logger.error(f"Error executing pg_advisory_unlock: {e}", exc_info=True) finally: # Ensure resources are closed regardless of unlock success try: - if lock_cur and not lock_cur.closed: + if lock_cur: lock_cur.close() except Exception as e: logger.error(f"Error closing advisory lock cursor: {e}", exc_info=True) try: - if lock_conn and not lock_conn.closed: + if lock_conn: lock_conn.close() logger.info("Closed database connection that held advisory lock.") except Exception as e: logger.error(f"Error closing advisory lock connection: {e}", exc_info=True) else: - logger.warning("Attempted to release lock, but connection/cursor not found.") + logger.info("No PostgreSQL advisory lock to release (likely using SQLite or non-PostgreSQL database).") async def start_scheduler_with_leader_election(server: SyncServer): @@ -236,10 +250,18 @@ async def shutdown_scheduler_and_release_lock(): logger.info("Shutting down: Leader instance stopping scheduler and releasing lock.") if scheduler.running: try: - scheduler.shutdown() # wait=True by default + # Force synchronous shutdown to prevent callback scheduling + scheduler.shutdown(wait=True) + + # wait for any internal cleanup to complete + await asyncio.sleep(0.1) + logger.info("APScheduler shut down.") except Exception as e: - logger.error(f"Error shutting down APScheduler: {e}", exc_info=True) + # Handle SchedulerNotRunningError and other shutdown exceptions + logger.warning(f"Exception during APScheduler shutdown: {e}") + if "not running" not in str(e).lower(): + logger.error(f"Unexpected error shutting down APScheduler: {e}", exc_info=True) await _release_advisory_lock() _is_scheduler_leader = False # Update state after cleanup @@ -247,9 +269,11 @@ async def shutdown_scheduler_and_release_lock(): logger.info("Shutting down: Non-leader instance.") # Final cleanup check for scheduler state (belt and suspenders) - if scheduler.running: - logger.warning("Scheduler still running after shutdown logic completed? Forcing shutdown.") - try: + # This should rarely be needed if shutdown logic above worked correctly + try: + if scheduler.running: + logger.warning("Scheduler still running after shutdown logic completed? Forcing shutdown.") scheduler.shutdown(wait=False) - except: - pass + except Exception as e: + # Catch SchedulerNotRunningError and other shutdown exceptions + logger.debug(f"Expected exception during final scheduler cleanup: {e}") diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 581aa20e..a3075a8e 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -3,6 +3,7 @@ import json import logging import os import sys +from contextlib import asynccontextmanager from pathlib import Path from typing import Optional @@ -12,10 +13,11 @@ from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware -from letta.__init__ import __version__ +from letta.__init__ import __version__ as letta_version from letta.agents.exceptions import IncompatibleAgentType from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError +from letta.jobs.scheduler import start_scheduler_with_leader_election from letta.log import get_logger from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError from letta.schemas.letta_message import create_letta_message_union_schema @@ -25,6 +27,7 @@ from letta.schemas.letta_message_content import ( create_letta_user_message_content_union_schema, ) from letta.server.constants import REST_DEFAULT_PORT +from letta.server.db import db_registry # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right? @@ -94,9 +97,7 @@ random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password() class CheckPasswordMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - # Exclude health check endpoint from password protection if request.url.path in {"/v1/health", "/v1/health/", "/latest/health/"}: return await call_next(request) @@ -113,11 +114,46 @@ class CheckPasswordMiddleware(BaseHTTPMiddleware): ) +@asynccontextmanager +async def lifespan(app_: FastAPI): + """ + FastAPI lifespan context manager with setup before the app starts pre-yield and on shutdown after the yield. + """ + worker_id = os.getpid() + + logger.info(f"[Worker {worker_id}] Starting lifespan initialization") + logger.info(f"[Worker {worker_id}] Initializing database connections") + db_registry.initialize_sync() + db_registry.initialize_async() + logger.info(f"[Worker {worker_id}] Database connections initialized") + + logger.info(f"[Worker {worker_id}] Starting scheduler with leader election") + global server + try: + await start_scheduler_with_leader_election(server) + logger.info(f"[Worker {worker_id}] Scheduler initialization completed") + except Exception as e: + logger.error(f"[Worker {worker_id}] Scheduler initialization failed: {e}", exc_info=True) + logger.info(f"[Worker {worker_id}] Lifespan startup completed") + yield + + # Cleanup on shutdown + logger.info(f"[Worker {worker_id}] Starting lifespan shutdown") + try: + from letta.jobs.scheduler import shutdown_scheduler_and_release_lock + + await shutdown_scheduler_and_release_lock() + logger.info(f"[Worker {worker_id}] Scheduler shutdown completed") + except Exception as e: + logger.error(f"[Worker {worker_id}] Scheduler shutdown failed: {e}", exc_info=True) + logger.info(f"[Worker {worker_id}] Lifespan shutdown completed") + + def create_application() -> "FastAPI": """the application start routine""" # global server # server = SyncServer(default_interface_factory=lambda: interface()) - print(f"\n[[ Letta server // v{__version__} ]]") + print(f"\n[[ Letta server // v{letta_version} ]]") if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): import sentry_sdk @@ -136,8 +172,9 @@ def create_application() -> "FastAPI": # openapi_tags=TAGS_METADATA, title="Letta", summary="Create LLM agents with long-term memory and custom tools 📚🦙", - version="1.0.0", # TODO wire this up to the version in the package + version=letta_version, debug=debug_mode, # if True, the stack trace will be printed in the response + lifespan=lifespan, ) @app.exception_handler(IncompatibleAgentType)