diff --git a/letta/server/db.py b/letta/server/db.py index 11365522..dd484bca 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -1,3 +1,4 @@ +import asyncio import os import threading import time @@ -116,6 +117,13 @@ class DatabaseRegistry: self.config = LettaConfig.load() self.logger = get_logger(__name__) + if settings.db_max_concurrent_sessions: + self._db_semaphore = asyncio.Semaphore(settings.db_max_concurrent_sessions) + self.logger.info(f"Initialized database throttling with max {settings.db_max_concurrent_sessions} concurrent sessions") + else: + self.logger.info("Database throttling is disabled") + self._db_semaphore = None + def initialize_sync(self, force: bool = False) -> None: """Initialize the synchronous database engine if not already initialized.""" with self._lock: @@ -364,16 +372,28 @@ class DatabaseRegistry: @trace_method @asynccontextmanager async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]: - """Async context manager for database sessions.""" - session_factory = self.get_async_session_factory(name) - if not session_factory: - raise ValueError(f"No async session factory found for '{name}' or async database is not configured") + """Async context manager for database sessions with throttling.""" + if self._db_semaphore: + async with self._db_semaphore: + session_factory = self.get_async_session_factory(name) + if not session_factory: + raise ValueError(f"No async session factory found for '{name}' or async database is not configured") - session = session_factory() - try: - yield session - finally: - await session.close() + session = session_factory() + try: + yield session + finally: + await session.close() + else: + session_factory = self.get_async_session_factory(name) + if not session_factory: + raise ValueError(f"No async session factory found for '{name}' or async database is not configured") + + session = session_factory() + try: + yield session + finally: + await session.close() @trace_method def session_caller_trace(self, caller_info: str): diff --git a/letta/settings.py b/letta/settings.py index e438faa2..619fd466 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -214,6 +214,7 @@ class Settings(BaseSettings): pool_pre_ping: bool = True # Pre ping to check for dead connections pool_use_lifo: bool = True disable_sqlalchemy_pooling: bool = False + db_max_concurrent_sessions: Optional[int] = None redis_host: Optional[str] = Field(default=None, description="Host for Redis instance") redis_port: Optional[int] = Field(default=6379, description="Port for Redis instance")