diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 3a1c3659..bae5d07e 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -3,14 +3,21 @@ from typing import Any, AsyncGenerator, List, Optional, Union import openai +from letta.helpers.datetime_helpers import get_utc_time +from letta.log import get_logger +from letta.schemas.agent import AgentState from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse -from letta.schemas.message import MessageCreate +from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.user import User from letta.services.agent_manager import AgentManager +from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.message_manager import MessageManager +from letta.utils import united_diff + +logger = get_logger(__name__) class BaseAgent(ABC): @@ -64,3 +71,107 @@ class BaseAgent(ABC): return "" return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages] + + def _rebuild_memory( + self, + in_context_messages: List[Message], + agent_state: AgentState, + num_messages: int | None = None, # storing these calculations is specific to the voice agent + num_archival_memories: int | None = None, + ) -> List[Message]: + try: + # Refresh Memory + # TODO: This only happens for the summary block (voice?) + # [DB Call] loading blocks (modifies: agent_state.memory.blocks) + self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) + + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this + curr_system_message = in_context_messages[0] + curr_memory_str = agent_state.memory.compile() + curr_system_message_text = curr_system_message.content[0].text + if curr_memory_str in curr_system_message_text: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.debug( + f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + ) + return in_context_messages + + memory_edit_timestamp = get_utc_time() + + # [DB Call] size of messages and archival memories + num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) + num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, + ) + + diff = united_diff(curr_system_message_text, new_system_message_str) + if len(diff) > 0: + logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") + + # [DB Call] Update Messages + new_system_message = self.message_manager.update_message_by_id( + curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + ) + # Skip pulling down the agent's memory again to save on a db call + return [new_system_message] + in_context_messages[1:] + + else: + return in_context_messages + except: + logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") + raise + + async def _rebuild_memory_async(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: + """ + Async version of function above. For now before breaking up components, changes should be made in both places. + """ + try: + # [DB Call] loading blocks (modifies: agent_state.memory.blocks) + await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor) + + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this + curr_system_message = in_context_messages[0] + curr_memory_str = agent_state.memory.compile() + curr_system_message_text = curr_system_message.content[0].text + if curr_memory_str in curr_system_message_text: + logger.debug( + f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + ) + return in_context_messages + + memory_edit_timestamp = get_utc_time() + + # [DB Call] size of messages and archival memories + # todo: blocking for now + num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) + num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, + ) + + diff = united_diff(curr_system_message_text, new_system_message_str) + if len(diff) > 0: + logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") + + # [DB Call] Update Messages + new_system_message = self.message_manager.update_message_by_id_async( + curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + ) + return [new_system_message] + in_context_messages[1:] + + else: + return in_context_messages + except: + logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") + raise diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index ca7e1fb7..a1c6952c 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -32,6 +32,7 @@ from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager +from letta.settings import settings from letta.system import package_function_response from letta.tracing import log_event, trace_method from letta.utils import united_diff @@ -171,6 +172,7 @@ class LettaAgent(BaseAgent): yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n" @trace_method + # When raising an error this doesn't show up async def _get_ai_reply( self, llm_client: LLMClientBase, @@ -179,7 +181,10 @@ class LettaAgent(BaseAgent): tool_rules_solver: ToolRulesSolver, stream: bool, ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + if settings.experimental_enable_async_db_engine: + in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state) + else: + in_context_messages = self._rebuild_memory(in_context_messages, agent_state) tools = [ t @@ -296,51 +301,6 @@ class LettaAgent(BaseAgent): return persisted_messages, continue_stepping - def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: - try: - self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) - - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages - - memory_edit_timestamp = get_utc_time() - - num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) - - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=num_messages, - archival_memory_size=num_archival_memories, - ) - - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor - ) - - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] - - else: - return in_context_messages - except: - logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") - raise - @trace_method async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: """ diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 58cb5be7..2c134cb7 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -1,11 +1,12 @@ import json import uuid from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union from aiomultiprocess import Pool from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult +from letta.agents.base_agent import BaseAgent from letta.agents.helpers import _prepare_in_context_messages from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time @@ -16,11 +17,12 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.agent import AgentState, AgentStepState -from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.enums import AgentStepStatus, JobStatus, MessageStreamStatus, ProviderType from letta.schemas.job import JobUpdate +from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_request import LettaBatchRequest -from letta.schemas.letta_response import LettaBatchResponse +from letta.schemas.letta_response import LettaBatchResponse, LettaResponse from letta.schemas.llm_batch_job import LLMBatchItem from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.openai.chat_completion_response import ToolCall as OpenAIToolCall @@ -95,7 +97,7 @@ async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[ # TODO: Limitations -> # TODO: Only works with anthropic for now -class LettaAgentBatch: +class LettaAgentBatch(BaseAgent): def __init__( self, @@ -539,43 +541,20 @@ class LettaAgentBatch: return in_context_messages # TODO: Make this a bullk function - def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: - agent_state = self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) + def _rebuild_memory( + self, + in_context_messages: List[Message], + agent_state: AgentState, + num_messages: int | None = None, + num_archival_memories: int | None = None, + ) -> List[Message]: + return super()._rebuild_memory(in_context_messages, agent_state) - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages + # Not used in batch. + async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + raise NotImplementedError - memory_edit_timestamp = get_utc_time() - - num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) - - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=num_messages, - archival_memory_size=num_archival_memories, - ) - - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor - ) - - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] - - else: - return in_context_messages + async def step_stream( + self, input_messages: List[MessageCreate], max_steps: int = 10 + ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: + raise NotImplementedError diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 27c3c7e3..294b0ad8 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -293,48 +293,17 @@ class VoiceAgent(BaseAgent): agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor ) - def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: - # Refresh memory - # TODO: This only happens for the summary block - # TODO: We want to extend this refresh to be general, and stick it in agent_manager - block_ids = [block.id for block in agent_state.memory.blocks] - agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(block_ids=block_ids, actor=self.actor) - - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages - - memory_edit_timestamp = get_utc_time() - - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=self.num_messages, - archival_memory_size=self.num_archival_memories, + def _rebuild_memory( + self, + in_context_messages: List[Message], + agent_state: AgentState, + num_messages: int | None = None, + num_archival_memories: int | None = None, + ) -> List[Message]: + return super()._rebuild_memory( + in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories ) - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor - ) - - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] - - else: - return in_context_messages - def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest: tool_schemas = self._build_tool_schemas(agent_state) tool_choice = "auto" if tool_schemas else None diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index ca2e19b4..dcb4cebf 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union from sqlalchemy import String, and_, func, or_, select from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, Session, mapped_column from letta.log import get_logger @@ -300,6 +301,44 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}") return found[0] + @classmethod + @handle_db_timeout + async def read_async( + cls, + db_session: "Session", + identifier: Optional[str] = None, + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + **kwargs, + ) -> "SqlalchemyBase": + """The primary accessor for an ORM record. Async version of read method. + Args: + db_session: the database session to use when retrieving the record + identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility + actor: if specified, results will be scoped only to records the user is able to access + access: if actor is specified, records will be filtered to the minimum permission level for the actor + kwargs: additional arguments to pass to the read, used for more complex objects + Returns: + The matching object + Raises: + NoResultFound: if the object is not found + """ + # this is ok because read_multiple will check if the + identifiers = [] if identifier is None else [identifier] + found = await cls.read_multiple_async(db_session, identifiers, actor, access, access_type, **kwargs) + if len(found) == 0: + # for backwards compatibility. + conditions = [] + if identifier: + conditions.append(f"id={identifier}") + if actor: + conditions.append(f"access level in {access} for {actor}") + if hasattr(cls, "is_deleted"): + conditions.append("is_deleted=False") + raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}") + return found[0] + @classmethod @handle_db_timeout def read_multiple( @@ -323,6 +362,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): Raises: NoResultFound: if the object is not found """ + query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs) + results = db_session.execute(query).scalars().all() + return cls._read_multiple_postprocess(results, identifiers, query_conditions) + + @classmethod + @handle_db_timeout + async def read_multiple_async( + cls, + db_session: "AsyncSession", + identifiers: List[str] = [], + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + **kwargs, + ) -> List["SqlalchemyBase"]: + """ + Async version of read_multiple(...) + The primary accessor for ORM record(s) + """ + query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs) + results = await db_session.execute(query) + return cls._read_multiple_postprocess(results.scalars().all(), identifiers, query_conditions) + + @classmethod + def _read_multiple_preprocess( + cls, + identifiers: List[str], + actor: Optional["User"], + access: Optional[List[Literal["read", "write", "admin"]]], + access_type: AccessType, + **kwargs, + ): logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}") # Start the query @@ -350,7 +421,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): query = query.where(cls.is_deleted == False) query_conditions.append("is_deleted=False") - results = db_session.execute(query).scalars().all() + return query, query_conditions + + @classmethod + def _read_multiple_postprocess(cls, results, identifiers: List[str], query_conditions) -> List["SqlalchemyBase"]: if results: # if empty list a.k.a. no results if len(identifiers) > 0: # find which identifiers were not found @@ -471,6 +545,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): db_session.refresh(self) return self + @handle_db_timeout + async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase": + """Async version of update function""" + logger.debug(...) + if actor: + self._set_created_and_updated_by_fields(actor.id) + self.set_updated_at() + + db_session.add(self) + if no_commit: + await db_session.flush() + else: + await db_session.commit() + await db_session.refresh(self) + return self + @classmethod @handle_db_timeout def size( diff --git a/letta/serialize_schemas/marshmallow_agent.py b/letta/serialize_schemas/marshmallow_agent.py index 8938586c..55003db1 100644 --- a/letta/serialize_schemas/marshmallow_agent.py +++ b/letta/serialize_schemas/marshmallow_agent.py @@ -1,6 +1,7 @@ from typing import Dict from marshmallow import fields, post_dump, pre_load +from sqlalchemy.orm import sessionmaker import letta from letta.orm import Agent @@ -14,7 +15,6 @@ from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigFie from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema -from letta.server.db import SessionLocal class MarshmallowAgentSchema(BaseSchema): @@ -41,7 +41,7 @@ class MarshmallowAgentSchema(BaseSchema): tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema)) tags = fields.List(fields.Nested(SerializedAgentTagSchema)) - def __init__(self, *args, session: SessionLocal, actor: User, **kwargs): + def __init__(self, *args, session: sessionmaker, actor: User, **kwargs): super().__init__(*args, actor=actor, **kwargs) self.session = session @@ -60,9 +60,9 @@ class MarshmallowAgentSchema(BaseSchema): After dumping the agent, load all its Message rows and serialize them here. """ # TODO: This is hacky, but want to move fast, please refactor moving forward - from letta.server.db import db_context as session_maker + from letta.server.db import db_registry - with session_maker() as session: + with db_registry.session() as session: agent_id = data.get("id") msgs = ( session.query(MessageModel) diff --git a/letta/server/db.py b/letta/server/db.py index 87d87300..57fbdd53 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -1,28 +1,19 @@ import os import threading -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager +from typing import Any, AsyncGenerator, Generator from rich.console import Console from rich.panel import Panel from rich.text import Text -from sqlalchemy import create_engine +from sqlalchemy import Engine, create_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker from letta.config import LettaConfig from letta.log import get_logger -from letta.orm import Base from letta.settings import settings -# Use globals for the lock and initialization flag -_engine_lock = threading.Lock() -_engine_initialized = False - -# Create variables in global scope but don't initialize them yet -config = LettaConfig.load() -logger = get_logger(__name__) -engine = None -SessionLocal = None - def print_sqlite_schema_error(): """Print a formatted error message for SQLite schema issues""" @@ -54,86 +45,187 @@ def db_error_handler(): exit(1) -def initialize_engine(): - """Initialize the database engine only when needed.""" - global engine, SessionLocal, _engine_initialized +class DatabaseRegistry: + """Registry for database connections and sessions. - with _engine_lock: - # Check again inside the lock to prevent race conditions - if _engine_initialized: - return + This class manages both synchronous and asynchronous database connections + and provides context managers for session handling. + """ - if settings.letta_pg_uri_no_default: - logger.info("Creating postgres engine") - config.recall_storage_type = "postgres" - config.recall_storage_uri = settings.letta_pg_uri_no_default - config.archival_storage_type = "postgres" - config.archival_storage_uri = settings.letta_pg_uri_no_default + def __init__(self): + self._engines: dict[str, Engine] = {} + self._async_engines: dict[str, AsyncEngine] = {} + self._session_factories: dict[str, sessionmaker] = {} + self._async_session_factories: dict[str, async_sessionmaker] = {} + self._initialized: dict[str, bool] = {"sync": False, "async": False} + self._lock = threading.Lock() + self.config = LettaConfig.load() + self.logger = get_logger(__name__) - # create engine - engine = create_engine( - settings.letta_pg_uri, - # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8", - pool_size=settings.pg_pool_size, - max_overflow=settings.pg_max_overflow, - pool_timeout=settings.pg_pool_timeout, - pool_recycle=settings.pg_pool_recycle, - echo=settings.pg_echo, - # connect_args={"client_encoding": "utf8"}, - ) - else: - # TODO: don't rely on config storage - engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") - logger.info("Creating sqlite engine " + engine_path) + def initialize_sync(self, force: bool = False) -> None: + """Initialize the synchronous database engine if not already initialized.""" + with self._lock: + if self._initialized.get("sync") and not force: + return - engine = create_engine(engine_path) + # Postgres engine + if settings.letta_pg_uri_no_default: + self.logger.info("Creating postgres engine") + self.config.recall_storage_type = "postgres" + self.config.recall_storage_uri = settings.letta_pg_uri_no_default + self.config.archival_storage_type = "postgres" + self.config.archival_storage_uri = settings.letta_pg_uri_no_default - # Store the original connect method - original_connect = engine.connect + engine = create_engine( + settings.letta_pg_uri, + # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8", + pool_size=settings.pg_pool_size, + max_overflow=settings.pg_max_overflow, + pool_timeout=settings.pg_pool_timeout, + pool_recycle=settings.pg_pool_recycle, + echo=settings.pg_echo, + # connect_args={"client_encoding": "utf8"}, + ) - def wrapped_connect(*args, **kwargs): - with db_error_handler(): - # Get the connection - connection = original_connect(*args, **kwargs) + self._engines["default"] = engine + # SQLite engine + else: + from letta.orm import Base - # Store the original execution method - original_execute = connection.execute + # TODO: don't rely on config storage + engine_path = "sqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db") + self.logger.info("Creating sqlite engine " + engine_path) - # Wrap the execute method of the connection - def wrapped_execute(*args, **kwargs): - with db_error_handler(): - return original_execute(*args, **kwargs) + engine = create_engine(engine_path) - # Replace the connection's execute method - connection.execute = wrapped_execute + # Wrap the engine with error handling + self._wrap_sqlite_engine(engine) - return connection + Base.metadata.create_all(bind=engine) + self._engines["default"] = engine - # Replace the engine's connect method - engine.connect = wrapped_connect + # Create session factory + self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=self._engines["default"]) + self._initialized["sync"] = True - Base.metadata.create_all(bind=engine) + def initialize_async(self, force: bool = False) -> None: + """Initialize the asynchronous database engine if not already initialized.""" + with self._lock: + if self._initialized.get("async") and not force: + return - # Create the session factory - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - _engine_initialized = True + if settings.letta_pg_uri_no_default: + self.logger.info("Creating async postgres engine") + + # Create async engine - convert URI to async format + pg_uri = settings.letta_pg_uri + if pg_uri.startswith("postgresql://"): + async_pg_uri = pg_uri.replace("postgresql://", "postgresql+asyncpg://") + else: + async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri + + async_engine = create_async_engine( + async_pg_uri, + pool_size=settings.pg_pool_size, + max_overflow=settings.pg_max_overflow, + pool_timeout=settings.pg_pool_timeout, + pool_recycle=settings.pg_pool_recycle, + echo=settings.pg_echo, + ) + + self._async_engines["default"] = async_engine + + # Create async session factory + self._async_session_factories["default"] = async_sessionmaker( + autocommit=False, autoflush=False, bind=self._async_engines["default"], class_=AsyncSession + ) + self._initialized["async"] = True + else: + self.logger.warning("Async SQLite is currently not supported. Please use PostgreSQL for async database operations.") + # TODO (cliandy): unclear around async sqlite support in sqlalchemy, we will not currently support this + self._initialized["async"] = False + + def _wrap_sqlite_engine(self, engine: Engine) -> None: + """Wrap SQLite engine with error handling.""" + original_connect = engine.connect + + def wrapped_connect(*args, **kwargs): + with db_error_handler(): + connection = original_connect(*args, **kwargs) + original_execute = connection.execute + + def wrapped_execute(*args, **kwargs): + with db_error_handler(): + return original_execute(*args, **kwargs) + + connection.execute = wrapped_execute + return connection + + engine.connect = wrapped_connect + + def get_engine(self, name: str = "default") -> Engine: + """Get a database engine by name.""" + self.initialize_sync() + return self._engines.get(name) + + def get_async_engine(self, name: str = "default") -> AsyncEngine: + """Get an async database engine by name.""" + self.initialize_async() + return self._async_engines.get(name) + + def get_session_factory(self, name: str = "default") -> sessionmaker: + """Get a session factory by name.""" + self.initialize_sync() + return self._session_factories.get(name) + + def get_async_session_factory(self, name: str = "default") -> async_sessionmaker: + """Get an async session factory by name.""" + self.initialize_async() + return self._async_session_factories.get(name) + + @contextmanager + def session(self, name: str = "default") -> Generator[Any, None, None]: + """Context manager for database sessions.""" + session_factory = self.get_session_factory(name) + if not session_factory: + raise ValueError(f"No session factory found for '{name}'") + + session = session_factory() + try: + yield session + finally: + session.close() + + @asynccontextmanager + async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]: + """Async context manager for database sessions.""" + session_factory = self.get_async_session_factory(name) + if not session_factory: + raise ValueError(f"No async session factory found for '{name}' or async database is not configured") + + session = session_factory() + try: + yield session + finally: + await session.close() + + +# Create a singleton instance +db_registry = DatabaseRegistry() def get_db(): - """Get a database session, initializing the engine if needed.""" - global engine, SessionLocal - - # Make sure engine is initialized - if not _engine_initialized: - initialize_engine() - - # Now SessionLocal should be defined and callable - db = SessionLocal() - try: - yield db - finally: - db.close() + """Get a database session.""" + with db_registry.session() as session: + yield session -# Define db_context as a context manager that uses get_db +async def get_db_async(): + """Get an async database session.""" + async with db_registry.async_session() as session: + yield session + + +# Prefer calling db_registry.session() or db_registry.async_session() directly +# This is for backwards compatibility db_context = contextmanager(get_db) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 190ca8a8..0ff701f4 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -56,6 +56,7 @@ from letta.serialize_schemas import MarshmallowAgentSchema from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema from letta.serialize_schemas.pydantic_agent_schema import AgentSchema +from letta.server.db import db_registry from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( _apply_filters, @@ -85,9 +86,6 @@ class AgentManager: """Manager class to handle business logic related to Agents.""" def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context self.block_manager = BlockManager() self.tool_manager = ToolManager() self.source_manager = SourceManager() @@ -200,7 +198,7 @@ class AgentManager: identity_ids = agent_create.identity_ids or [] tag_values = agent_create.tags or [] - with self.session_maker() as session: + with db_registry.session() as session: with session.begin(): name_to_id, id_to_name = self._resolve_tools( session, @@ -356,7 +354,7 @@ class AgentManager: new_idents = set(agent_update.identity_ids or []) new_tags = set(agent_update.tags or []) - with self.session_maker() as session, session.begin(): + with db_registry.session() as session, session.begin(): agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) agent.updated_at = datetime.now(timezone.utc) @@ -503,7 +501,7 @@ class AgentManager: Returns: List[PydanticAgentState]: The filtered list of matching agents. """ - with self.session_maker() as session: + with db_registry.session() as session: query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id) query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) @@ -541,7 +539,7 @@ class AgentManager: Returns: List[PydanticAgentState: The filtered list of matching agents. """ - with self.session_maker() as session: + with db_registry.session() as session: query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id) if match_all: @@ -569,20 +567,20 @@ class AgentManager: """ Get the total count of agents for the given user. """ - with self.session_maker() as session: + with db_registry.session() as session: return AgentModel.size(db_session=session, actor=actor) @enforce_types def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return agent.to_pydantic() @enforce_types def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) return agent.to_pydantic() @@ -599,7 +597,7 @@ class AgentManager: Raises: NoResultFound: If agent doesn't exist """ - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve the agent logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}") agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -635,7 +633,7 @@ class AgentManager: @enforce_types def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema: - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) schema = MarshmallowAgentSchema(session=session, actor=actor) data = schema.dump(agent) @@ -665,7 +663,7 @@ class AgentManager: serialized_agent_dict[MarshmallowAgentSchema.FIELD_MESSAGE_IDS] = message_ids - with self.session_maker() as session: + with db_registry.session() as session: schema = MarshmallowAgentSchema(session=session, actor=actor) agent = schema.load(serialized_agent_dict, session=session) @@ -728,7 +726,7 @@ class AgentManager: Returns: PydanticAgentState: The updated agent as a Pydantic model. """ - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve the agent agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -767,7 +765,7 @@ class AgentManager: @enforce_types def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]: - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) if manager_type: return [group.to_pydantic() for group in agent.groups if group.manager_type == manager_type] @@ -908,7 +906,7 @@ class AgentManager: Returns: PydanticAgentState: The updated agent state with no linked messages. """ - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve the existing agent (will raise NoResultFound if invalid) agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -985,6 +983,17 @@ class AgentManager: ) return agent_state + @enforce_types + async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: + block_ids = [b.id for b in agent_state.memory.blocks] + if not block_ids: + return agent_state + + agent_state.memory.blocks = await self.block_manager.get_all_blocks_by_ids_async( + block_ids=[b.id for b in agent_state.memory.blocks], actor=actor + ) + return agent_state + # ====================================================================================================================== # Source Management # ====================================================================================================================== @@ -1003,7 +1012,7 @@ class AgentManager: IntegrityError: If the source is already attached to the agent """ - with self.session_maker() as session: + with db_registry.session() as session: # Verify both agent and source exist and user has permission to access them agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1056,7 +1065,7 @@ class AgentManager: Returns: List[str]: List of source IDs attached to the agent """ - with self.session_maker() as session: + with db_registry.session() as session: # Verify agent exists and user has permission to access it agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1073,7 +1082,7 @@ class AgentManager: source_id: ID of the source to detach actor: User performing the action """ - with self.session_maker() as session: + with db_registry.session() as session: # Verify agent exists and user has permission to access it agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1101,7 +1110,7 @@ class AgentManager: actor: PydanticUser, ) -> PydanticBlock: """Gets a block attached to an agent by its label.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) for block in agent.core_memory: if block.label == block_label: @@ -1117,7 +1126,7 @@ class AgentManager: actor: PydanticUser, ) -> PydanticAgentState: """Updates which block is assigned to a specific label for an agent.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor) @@ -1135,7 +1144,7 @@ class AgentManager: @enforce_types def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: """Attaches a block to an agent.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) @@ -1151,7 +1160,7 @@ class AgentManager: actor: PydanticUser, ) -> PydanticAgentState: """Detaches a block from an agent.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) original_length = len(agent.core_memory) @@ -1171,7 +1180,7 @@ class AgentManager: actor: PydanticUser, ) -> PydanticAgentState: """Detaches a block with the specified label from an agent.""" - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) original_length = len(agent.core_memory) @@ -1215,7 +1224,7 @@ class AgentManager: embedded_text = np.array(embedded_text) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - with self.session_maker() as session: + with db_registry.session() as session: # Start with base query for source passages source_passages = None if not agent_only: # Include source passages @@ -1389,7 +1398,7 @@ class AgentManager: agent_only: bool = False, ) -> List[PydanticPassage]: """Lists all passages attached to an agent.""" - with self.session_maker() as session: + with db_registry.session() as session: main_query = self._build_passage_query( actor=actor, agent_id=agent_id, @@ -1447,7 +1456,7 @@ class AgentManager: agent_only: bool = False, ) -> int: """Returns the count of passages matching the given criteria.""" - with self.session_maker() as session: + with db_registry.session() as session: main_query = self._build_passage_query( actor=actor, agent_id=agent_id, @@ -1487,7 +1496,7 @@ class AgentManager: Returns: PydanticAgentState: The updated agent state. """ - with self.session_maker() as session: + with db_registry.session() as session: # Verify the agent exists and user has permission to access it agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1522,7 +1531,7 @@ class AgentManager: Returns: PydanticAgentState: The updated agent state. """ - with self.session_maker() as session: + with db_registry.session() as session: # Verify the agent exists and user has permission to access it agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -1551,7 +1560,7 @@ class AgentManager: Returns: List[PydanticTool]: List of tools attached to the agent. """ - with self.session_maker() as session: + with db_registry.session() as session: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return [tool.to_pydantic() for tool in agent.tools] @@ -1574,7 +1583,7 @@ class AgentManager: Returns: List[str]: List of all tags. """ - with self.session_maker() as session: + with db_registry.session() as session: query = ( session.query(AgentsTags.tag) .join(AgentModel, AgentModel.id == AgentsTags.agent_id) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index e05b7f7b..8f07f380 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -12,6 +12,7 @@ from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, Human, Persona from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types, list_human_files, list_persona_files logger = get_logger(__name__) @@ -20,12 +21,6 @@ logger = get_logger(__name__) class BlockManager: """Manager class to handle business logic related to Blocks.""" - def __init__(self): - # Fetching the db_context similarly as in ToolManager - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: """Create a new block based on the Block schema.""" @@ -34,7 +29,7 @@ class BlockManager: update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True)) self.update_block(block.id, update_data, actor) else: - with self.session_maker() as session: + with db_registry.session() as session: data = block.model_dump(to_orm=True, exclude_none=True) block = BlockModel(**data, organization_id=actor.organization_id) block.create(session, actor=actor) @@ -53,7 +48,7 @@ class BlockManager: if not blocks: return [] - with self.session_maker() as session: + with db_registry.session() as session: block_models = [ BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks ] @@ -68,7 +63,7 @@ class BlockManager: """Update a block by its ID with the given BlockUpdate object.""" # Safety check for block - with self.session_maker() as session: + with db_registry.session() as session: block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) @@ -81,7 +76,7 @@ class BlockManager: @enforce_types def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: """Delete a block by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: block = BlockModel.read(db_session=session, identifier=block_id) block.hard_delete(db_session=session, actor=actor) return block.to_pydantic() @@ -100,7 +95,7 @@ class BlockManager: limit: Optional[int] = 50, ) -> List[PydanticBlock]: """Retrieve blocks based on various optional filters.""" - with self.session_maker() as session: + with db_registry.session() as session: # Prepare filters filters = {"organization_id": actor.organization_id} if label: @@ -126,7 +121,7 @@ class BlockManager: @enforce_types def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: """Retrieve a block by its name.""" - with self.session_maker() as session: + with db_registry.session() as session: try: block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) return block.to_pydantic() @@ -136,12 +131,24 @@ class BlockManager: @enforce_types def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: """Retrieve blocks by their ids.""" - with self.session_maker() as session: + with db_registry.session() as session: blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)] # backwards compatibility. previous implementation added None for every block not found. blocks.extend([None for _ in range(len(block_ids) - len(blocks))]) return blocks + @enforce_types + async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: + """Retrieve blocks by their ids. Async implementation.""" + async with db_registry.async_session() as session: + blocks = [ + block.to_pydantic() + for block in await BlockModel.read_multiple_async(db_session=session, identifiers=block_ids, actor=actor) + ] + # backwards compatibility. previous implementation added None for every block not found. + blocks.extend([None for _ in range(len(block_ids) - len(blocks))]) + return blocks + @enforce_types def add_default_blocks(self, actor: PydanticUser): for persona_file in list_persona_files(): @@ -161,7 +168,7 @@ class BlockManager: """ Retrieve all agents associated with a given block. """ - with self.session_maker() as session: + with db_registry.session() as session: block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) agents_orm = block.agents agents_pydantic = [agent.to_pydantic() for agent in agents_orm] @@ -176,7 +183,7 @@ class BlockManager: """ Get the total count of blocks for the given user. """ - with self.session_maker() as session: + with db_registry.session() as session: return BlockModel.size(db_session=session, actor=actor) # Block History Functions @@ -199,7 +206,7 @@ class BlockManager: strictly linear history. - A single commit at the end ensures atomicity. """ - with self.session_maker() as session: + with db_registry.session() as session: # 1) Load the Block if use_preloaded_block is not None: block = session.merge(use_preloaded_block) @@ -291,7 +298,7 @@ class BlockManager: If older sequences have been pruned, we jump to the largest sequence number that is still < current_seq. """ - with self.session_maker() as session: + with db_registry.session() as session: # 1) Load the current block block = ( session.merge(use_preloaded_block) @@ -333,7 +340,7 @@ class BlockManager: If some middle checkpoints have been pruned, we jump to the smallest sequence > current_seq that remains. """ - with self.session_maker() as session: + with db_registry.session() as session: block = ( session.merge(use_preloaded_block) if use_preloaded_block @@ -383,7 +390,7 @@ class BlockManager: NoResultFound if any block_id doesn’t exist or isn’t visible to this actor ValueError if any new value exceeds its block’s limit """ - with self.session_maker() as session: + with db_registry.session() as session: q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id) blocks = q.all() diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 6c49cfcb..7adf49ec 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -11,16 +11,12 @@ from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types class GroupManager: - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def list_groups( self, @@ -31,7 +27,7 @@ class GroupManager: after: Optional[str] = None, limit: Optional[int] = 50, ) -> list[PydanticGroup]: - with self.session_maker() as session: + with db_registry.session() as session: filters = {"organization_id": actor.organization_id} if project_id: filters["project_id"] = project_id @@ -48,13 +44,13 @@ class GroupManager: @enforce_types def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup: - with self.session_maker() as session: + with db_registry.session() as session: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) return group.to_pydantic() @enforce_types def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup: - with self.session_maker() as session: + with db_registry.session() as session: new_group = GroupModel() new_group.organization_id = actor.organization_id new_group.description = group.description @@ -99,7 +95,7 @@ class GroupManager: @enforce_types def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: - with self.session_maker() as session: + with db_registry.session() as session: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) sleeptime_agent_frequency = None @@ -161,7 +157,7 @@ class GroupManager: @enforce_types def delete_group(self, group_id: str, actor: PydanticUser) -> None: - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve the agent group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) group.hard_delete(session) @@ -178,7 +174,7 @@ class GroupManager: assistant_message_tool_name: str = "send_message", assistant_message_tool_kwarg: str = "message", ) -> list[LettaMessage]: - with self.session_maker() as session: + with db_registry.session() as session: filters = { "organization_id": actor.organization_id, "group_id": group_id, @@ -204,7 +200,7 @@ class GroupManager: @enforce_types def reset_messages(self, group_id: str, actor: PydanticUser) -> None: - with self.session_maker() as session: + with db_registry.session() as session: # Ensure group is loadable by user group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) @@ -217,7 +213,7 @@ class GroupManager: @enforce_types def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int: - with self.session_maker() as session: + with db_registry.session() as session: # Ensure group is loadable by user group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) @@ -228,7 +224,7 @@ class GroupManager: @enforce_types def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str: - with self.session_maker() as session: + with db_registry.session() as session: # Ensure group is loadable by user group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) @@ -247,7 +243,7 @@ class GroupManager: """ Get the total count of groups for the given user. """ - with self.session_maker() as session: + with db_registry.session() as session: return GroupModel.size(db_session=session, actor=actor) def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True): diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 798b01a0..3ca05793 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -10,16 +10,12 @@ from letta.orm.identity import Identity as IdentityModel from letta.schemas.identity import Identity as PydanticIdentity from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types class IdentityManager: - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def list_identities( self, @@ -32,7 +28,7 @@ class IdentityManager: limit: Optional[int] = 50, actor: PydanticUser = None, ) -> list[PydanticIdentity]: - with self.session_maker() as session: + with db_registry.session() as session: filters = {"organization_id": actor.organization_id} if project_id: filters["project_id"] = project_id @@ -52,13 +48,13 @@ class IdentityManager: @enforce_types def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: - with self.session_maker() as session: + with db_registry.session() as session: identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) return identity.to_pydantic() @enforce_types def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: - with self.session_maker() as session: + with db_registry.session() as session: new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True)) new_identity.organization_id = actor.organization_id self._process_relationship( @@ -82,7 +78,7 @@ class IdentityManager: @enforce_types def upsert_identity(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity: - with self.session_maker() as session: + with db_registry.session() as session: existing_identity = IdentityModel.read( db_session=session, identifier_key=identity.identifier_key, @@ -107,7 +103,7 @@ class IdentityManager: @enforce_types def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity: - with self.session_maker() as session: + with db_registry.session() as session: try: existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) except NoResultFound: @@ -167,7 +163,7 @@ class IdentityManager: @enforce_types def upsert_identity_properties(self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser) -> PydanticIdentity: - with self.session_maker() as session: + with db_registry.session() as session: existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) if existing_identity is None: raise HTTPException(status_code=404, detail="Identity not found") @@ -181,7 +177,7 @@ class IdentityManager: @enforce_types def delete_identity(self, identity_id: str, actor: PydanticUser) -> None: - with self.session_maker() as session: + with db_registry.session() as session: identity = IdentityModel.read(db_session=session, identifier=identity_id) if identity is None: raise HTTPException(status_code=404, detail="Identity not found") @@ -198,7 +194,7 @@ class IdentityManager: """ Get the total count of identities for the given user. """ - with self.session_maker() as session: + with db_registry.session() as session: return IdentityModel.size(db_session=session, actor=actor) def _process_relationship( diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 153c5fab..74576d4d 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -24,24 +24,19 @@ from letta.schemas.run import Run as PydanticRun from letta.schemas.step import Step as PydanticStep from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types class JobManager: """Manager class to handle business logic related to Jobs.""" - def __init__(self): - # Fetching the db_context similarly as in OrganizationManager - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def create_job( self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: """Create a new job based on the JobCreate schema.""" - with self.session_maker() as session: + with db_registry.session() as session: # Associate the job with the user pydantic_job.user_id = actor.id job_data = pydantic_job.model_dump(to_orm=True) @@ -52,7 +47,7 @@ class JobManager: @enforce_types def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: """Update a job by its ID with the given JobUpdate object.""" - with self.session_maker() as session: + with db_registry.session() as session: # Fetch the job by ID job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"]) @@ -76,7 +71,7 @@ class JobManager: @enforce_types def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Fetch a job by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve job by ID using the Job model's read method job = JobModel.read(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER) return job.to_pydantic() @@ -93,7 +88,7 @@ class JobManager: ascending: bool = True, ) -> List[PydanticJob]: """List all jobs with optional pagination and status filter.""" - with self.session_maker() as session: + with db_registry.session() as session: filter_kwargs = {"user_id": actor.id, "job_type": job_type} # Add status filter if provided @@ -113,7 +108,7 @@ class JobManager: @enforce_types def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Delete a job by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: job = self._verify_job_access(session=session, job_id=job_id, actor=actor) job.hard_delete(db_session=session, actor=actor) return job.to_pydantic() @@ -147,7 +142,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - with self.session_maker() as session: + with db_registry.session() as session: # Build filters filters = {} if role is not None: @@ -195,7 +190,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - with self.session_maker() as session: + with db_registry.session() as session: # Build filters filters = {} filters["job_id"] = job_id @@ -227,7 +222,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - with self.session_maker() as session: + with db_registry.session() as session: # First verify job exists and user has access self._verify_job_access(session, job_id, actor, access=["write"]) @@ -251,7 +246,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - with self.session_maker() as session: + with db_registry.session() as session: # First verify job exists and user has access self._verify_job_access(session, job_id, actor) @@ -293,7 +288,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - with self.session_maker() as session: + with db_registry.session() as session: # First verify job exists and user has access self._verify_job_access(session, job_id, actor, access=["write"]) @@ -453,7 +448,7 @@ class JobManager: Returns: The request config for the job """ - with self.session_maker() as session: + with db_registry.session() as session: job = session.query(JobModel).filter(JobModel.id == run_id).first() request_config = job.request_config or LettaRequestConfig() return request_config diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index caebaaf0..7d7b4b54 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -16,6 +16,7 @@ from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types logger = get_logger(__name__) @@ -24,11 +25,6 @@ logger = get_logger(__name__) class LLMBatchManager: """Manager for handling both LLMBatchJob and LLMBatchItem operations.""" - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def create_llm_batch_job( self, @@ -39,7 +35,7 @@ class LLMBatchManager: status: JobStatus = JobStatus.created, ) -> PydanticLLMBatchJob: """Create a new LLM batch job.""" - with self.session_maker() as session: + with db_registry.session() as session: batch = LLMBatchJob( status=status, llm_provider=llm_provider, @@ -53,7 +49,7 @@ class LLMBatchManager: @enforce_types def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob: """Retrieve a single batch job by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) return batch.to_pydantic() @@ -66,7 +62,7 @@ class LLMBatchManager: latest_polling_response: Optional[BetaMessageBatch] = None, ) -> PydanticLLMBatchJob: """Update a batch job’s status and optionally its polling response.""" - with self.session_maker() as session: + with db_registry.session() as session: batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) batch.status = status batch.latest_polling_response = latest_polling_response @@ -85,7 +81,7 @@ class LLMBatchManager: """ now = datetime.datetime.now(datetime.timezone.utc) - with self.session_maker() as session: + with db_registry.session() as session: mappings = [] for llm_batch_id, status, response in updates: mappings.append( @@ -119,7 +115,7 @@ class LLMBatchManager: The results are ordered by their id in ascending order. """ - with self.session_maker() as session: + with db_registry.session() as session: query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id) if actor is not None: @@ -140,7 +136,7 @@ class LLMBatchManager: @enforce_types def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None: """Hard delete a batch job by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) batch.hard_delete(db_session=session, actor=actor) @@ -158,7 +154,7 @@ class LLMBatchManager: Retrieve messages across all LLM batch jobs associated with a Letta batch job. Optimized for PostgreSQL performance using ID-based keyset pagination. """ - with self.session_maker() as session: + with db_registry.session() as session: # If cursor is provided, get sequence_id for that message cursor_sequence_id = None if cursor: @@ -203,7 +199,7 @@ class LLMBatchManager: @enforce_types def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: """Return all running LLM batch jobs, optionally filtered by actor's organization.""" - with self.session_maker() as session: + with db_registry.session() as session: query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running) if actor is not None: @@ -224,7 +220,7 @@ class LLMBatchManager: step_state: Optional[AgentStepState] = None, ) -> PydanticLLMBatchItem: """Create a new batch item.""" - with self.session_maker() as session: + with db_registry.session() as session: item = LLMBatchItem( llm_batch_id=llm_batch_id, agent_id=agent_id, @@ -249,7 +245,7 @@ class LLMBatchManager: Returns: List of created batch items as Pydantic models """ - with self.session_maker() as session: + with db_registry.session() as session: # Convert Pydantic models to ORM objects orm_items = [] for item in llm_batch_items: @@ -274,7 +270,7 @@ class LLMBatchManager: @enforce_types def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: """Retrieve a single batch item by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) return item.to_pydantic() @@ -289,7 +285,7 @@ class LLMBatchManager: step_state: Optional[AgentStepState] = None, ) -> PydanticLLMBatchItem: """Update fields on a batch item.""" - with self.session_maker() as session: + with db_registry.session() as session: item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) if request_status: @@ -325,7 +321,7 @@ class LLMBatchManager: The results are ordered by their id in ascending order. """ - with self.session_maker() as session: + with db_registry.session() as session: query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id) if actor is not None: @@ -367,7 +363,7 @@ class LLMBatchManager: if len(llm_batch_id_agent_id_pairs) != len(field_updates): raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length") - with self.session_maker() as session: + with db_registry.session() as session: # Lookup primary keys for all requested (batch_id, agent_id) pairs items = ( session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id) @@ -434,7 +430,7 @@ class LLMBatchManager: @enforce_types def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None: """Hard delete a batch item by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) item.hard_delete(db_session=session, actor=actor) @@ -449,6 +445,6 @@ class LLMBatchManager: Returns: int: The total number of batch items associated with the given llm_batch_id. """ - with self.session_maker() as session: + with db_registry.session() as session: count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar() return count or 0 diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index e87f9917..c6ca4579 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -12,6 +12,7 @@ from letta.schemas.letta_message import LettaMessageUpdateUnion from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types logger = get_logger(__name__) @@ -20,15 +21,10 @@ logger = get_logger(__name__) class MessageManager: """Manager class to handle business logic related to Messages.""" - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: """Fetch a message by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: try: message = MessageModel.read(db_session=session, identifier=message_id, actor=actor) return message.to_pydantic() @@ -38,7 +34,7 @@ class MessageManager: @enforce_types def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: """Fetch messages by ID and return them in the requested order.""" - with self.session_maker() as session: + with db_registry.session() as session: results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids)) if len(results) != len(message_ids): @@ -53,7 +49,7 @@ class MessageManager: @enforce_types def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: """Create a new message.""" - with self.session_maker() as session: + with db_registry.session() as session: # Set the organization id of the Pydantic message pydantic_msg.organization_id = actor.organization_id msg_data = pydantic_msg.model_dump(to_orm=True) @@ -86,7 +82,7 @@ class MessageManager: orm_messages.append(MessageModel(**msg_data)) # Use the batch_create method for efficient creation - with self.session_maker() as session: + with db_registry.session() as session: created_messages = MessageModel.batch_create(orm_messages, session, actor=actor) # Convert back to Pydantic models @@ -173,7 +169,7 @@ class MessageManager: """ Updates an existing record in the database with values from the provided record object. """ - with self.session_maker() as session: + with db_registry.session() as session: # Fetch existing message from database message = MessageModel.read( db_session=session, @@ -181,31 +177,57 @@ class MessageManager: actor=actor, ) - # Some safety checks specific to messages - if message_update.tool_calls and message.role != MessageRole.assistant: - raise ValueError( - f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}." - ) - if message_update.tool_call_id and message.role != MessageRole.tool: - raise ValueError( - f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}." - ) - - # get update dictionary - update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - # Remove redundant update fields - update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value} - - for key, value in update_data.items(): - setattr(message, key, value) + message = self._update_message_by_id_impl(message_id, message_update, actor, message) message.update(db_session=session, actor=actor) - return message.to_pydantic() + @enforce_types + async def update_message_by_id_async(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage: + """ + Updates an existing record in the database with values from the provided record object. + Async version of the function above. + """ + async with db_registry.async_session() as session: + # Fetch existing message from database + message = await MessageModel.read_async( + db_session=session, + identifier=message_id, + actor=actor, + ) + + message = self._update_message_by_id_impl(message_id, message_update, actor, message) + await message.update_async(db_session=session, actor=actor) + return message.to_pydantic() + + def _update_message_by_id_impl( + self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel + ) -> MessageModel: + """ + Modifies the existing message object to update the database in the sync/async functions. + """ + # Some safety checks specific to messages + if message_update.tool_calls and message.role != MessageRole.assistant: + raise ValueError( + f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}." + ) + if message_update.tool_call_id and message.role != MessageRole.tool: + raise ValueError( + f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}." + ) + + # get update dictionary + update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + # Remove redundant update fields + update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value} + + for key, value in update_data.items(): + setattr(message, key, value) + return message + @enforce_types def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool: """Delete a message.""" - with self.session_maker() as session: + with db_registry.session() as session: try: msg = MessageModel.read( db_session=session, @@ -229,7 +251,7 @@ class MessageManager: actor: The user requesting the count role: The role of the message """ - with self.session_maker() as session: + with db_registry.session() as session: return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id) @enforce_types @@ -293,7 +315,7 @@ class MessageManager: NoResultFound: If the provided after/before message IDs do not exist. """ - with self.session_maker() as session: + with db_registry.session() as session: # Permission check: raise if the agent doesn't exist or actor is not allowed. AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -356,7 +378,7 @@ class MessageManager: Efficiently deletes all messages associated with a given agent_id, while enforcing permission checks and avoiding any ORM‑level loads. """ - with self.session_maker() as session: + with db_registry.session() as session: # 1) verify the agent exists and the actor has access AgentModel.read(db_session=session, identifier=agent_id, actor=actor) diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 1d0b637d..00a52833 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -4,6 +4,7 @@ from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.organization import OrganizationUpdate +from letta.server.db import db_registry from letta.utils import enforce_types @@ -13,14 +14,6 @@ class OrganizationManager: DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000" DEFAULT_ORG_NAME = "default_org" - def __init__(self): - # TODO: Please refactor this out - # I am currently working on a ORM refactor and would like to make a more minimal set of changes - # - Matt - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def get_default_organization(self) -> PydanticOrganization: """Fetch the default organization.""" @@ -29,7 +22,7 @@ class OrganizationManager: @enforce_types def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]: """Fetch an organization by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: organization = OrganizationModel.read(db_session=session, identifier=org_id) return organization.to_pydantic() @@ -44,7 +37,7 @@ class OrganizationManager: @enforce_types def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: - with self.session_maker() as session: + with db_registry.session() as session: org = OrganizationModel(**pydantic_org.model_dump(to_orm=True)) org.create(session) return org.to_pydantic() @@ -57,7 +50,7 @@ class OrganizationManager: @enforce_types def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: """Update an organization.""" - with self.session_maker() as session: + with db_registry.session() as session: org = OrganizationModel.read(db_session=session, identifier=org_id) if name: org.name = name @@ -67,7 +60,7 @@ class OrganizationManager: @enforce_types def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization: """Update an organization.""" - with self.session_maker() as session: + with db_registry.session() as session: org = OrganizationModel.read(db_session=session, identifier=org_id) if org_update.name: org.name = org_update.name @@ -79,14 +72,14 @@ class OrganizationManager: @enforce_types def delete_organization_by_id(self, org_id: str): """Delete an organization by marking it as deleted.""" - with self.session_maker() as session: + with db_registry.session() as session: organization = OrganizationModel.read(db_session=session, identifier=org_id) organization.hard_delete(session) @enforce_types def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: """List all organizations with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: organizations = OrganizationModel.list( db_session=session, after=after, diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index b891657e..8d735d9b 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -10,21 +10,17 @@ from letta.orm.passage import AgentPassage, SourcePassage from letta.schemas.agent import AgentState from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types class PassageManager: """Manager class to handle business logic related to Passages.""" - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: """Fetch a passage by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: # Try source passages first try: passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) @@ -69,7 +65,7 @@ class PassageManager: else: raise ValueError("Passage must have either agent_id or source_id") - with self.session_maker() as session: + with db_registry.session() as session: passage.create(session, actor=actor) return passage.to_pydantic() @@ -145,7 +141,7 @@ class PassageManager: if not passage_id: raise ValueError("Passage ID must be provided.") - with self.session_maker() as session: + with db_registry.session() as session: # Try source passages first try: curr_passage = SourcePassage.read( @@ -179,7 +175,7 @@ class PassageManager: if not passage_id: raise ValueError("Passage ID must be provided.") - with self.session_maker() as session: + with db_registry.session() as session: # Try source passages first try: passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) @@ -217,7 +213,7 @@ class PassageManager: actor: The user requesting the count agent_id: The agent ID of the messages """ - with self.session_maker() as session: + with db_registry.session() as session: return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) def estimate_embeddings_size( diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index e77a3f2f..9bb4a817 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -5,20 +5,16 @@ from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types class ProviderManager: - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider: """Create a new provider if it doesn't already exist.""" - with self.session_maker() as session: + with db_registry.session() as session: provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok} provider = PydanticProvider(**provider_create_args) @@ -38,7 +34,7 @@ class ProviderManager: @enforce_types def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider: """Update provider details.""" - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve the existing provider by ID existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor) @@ -54,7 +50,7 @@ class ProviderManager: @enforce_types def delete_provider_by_id(self, provider_id: str, actor: PydanticUser): """Delete a provider.""" - with self.session_maker() as session: + with db_registry.session() as session: # Clear api key field existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor) existing_provider.api_key = None @@ -80,7 +76,7 @@ class ProviderManager: filter_kwargs["name"] = name if provider_type: filter_kwargs["provider_type"] = provider_type - with self.session_maker() as session: + with db_registry.session() as session: providers = ProviderModel.list( db_session=session, after=after, diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 9feaf2a0..5b25b25e 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -11,6 +11,7 @@ from letta.schemas.sandbox_config import LocalSandboxConfig from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types, printd logger = get_logger(__name__) @@ -19,11 +20,6 @@ logger = get_logger(__name__) class SandboxConfigManager: """Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable.""" - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig: sandbox_config = self.get_sandbox_config_by_type(sandbox_type, actor=actor) @@ -69,7 +65,7 @@ class SandboxConfigManager: return db_sandbox else: # If the sandbox configuration doesn't exist, create a new one - with self.session_maker() as session: + with db_registry.session() as session: db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True)) db_sandbox.create(session, actor=actor) return db_sandbox.to_pydantic() @@ -79,7 +75,7 @@ class SandboxConfigManager: self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser ) -> PydanticSandboxConfig: """Update an existing sandbox configuration.""" - with self.session_maker() as session: + with db_registry.session() as session: sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) # We need to check that the sandbox_update provided is the same type as the original sandbox if sandbox.type != sandbox_update.config.type: @@ -104,7 +100,7 @@ class SandboxConfigManager: @enforce_types def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: """Delete a sandbox configuration by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) sandbox.hard_delete(db_session=session, actor=actor) return sandbox.to_pydantic() @@ -122,14 +118,14 @@ class SandboxConfigManager: if sandbox_type: kwargs.update({"type": sandbox_type}) - with self.session_maker() as session: + with db_registry.session() as session: sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs) return [sandbox.to_pydantic() for sandbox in sandboxes] @enforce_types def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: """Retrieve a sandbox configuration by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: try: sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) return sandbox.to_pydantic() @@ -139,7 +135,7 @@ class SandboxConfigManager: @enforce_types def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: """Retrieve a sandbox config by its type.""" - with self.session_maker() as session: + with db_registry.session() as session: try: sandboxes = SandboxConfigModel.list( db_session=session, @@ -175,7 +171,7 @@ class SandboxConfigManager: return db_env_var else: - with self.session_maker() as session: + with db_registry.session() as session: env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True)) env_var.create(session, actor=actor) return env_var.to_pydantic() @@ -185,7 +181,7 @@ class SandboxConfigManager: self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser ) -> PydanticEnvVar: """Update an existing sandbox environment variable.""" - with self.session_maker() as session: + with db_registry.session() as session: env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value} @@ -204,7 +200,7 @@ class SandboxConfigManager: @enforce_types def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: """Delete a sandbox environment variable by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) env_var.hard_delete(db_session=session, actor=actor) return env_var.to_pydantic() @@ -218,7 +214,7 @@ class SandboxConfigManager: limit: Optional[int] = 50, ) -> List[PydanticEnvVar]: """List all sandbox environment variables with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: env_vars = SandboxEnvVarModel.list( db_session=session, after=after, @@ -233,7 +229,7 @@ class SandboxConfigManager: self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> List[PydanticEnvVar]: """List all sandbox environment variables with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: env_vars = SandboxEnvVarModel.list( db_session=session, after=after, @@ -258,7 +254,7 @@ class SandboxConfigManager: self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None ) -> Optional[PydanticEnvVar]: """Retrieve a sandbox environment variable by its key and sandbox_config_id.""" - with self.session_maker() as session: + with db_registry.session() as session: try: env_var = SandboxEnvVarModel.list( db_session=session, diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index c872f490..6247967c 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -8,17 +8,13 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types, printd class SourceManager: """Manager class to handle business logic related to Sources.""" - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: """Create a new source based on the PydanticSource schema.""" @@ -27,7 +23,7 @@ class SourceManager: if db_source: return db_source else: - with self.session_maker() as session: + with db_registry.session() as session: # Provide default embedding config if not given source.organization_id = actor.organization_id source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True)) @@ -37,7 +33,7 @@ class SourceManager: @enforce_types def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: """Update a source by its ID with the given SourceUpdate object.""" - with self.session_maker() as session: + with db_registry.session() as session: source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) # get update dictionary @@ -59,7 +55,7 @@ class SourceManager: @enforce_types def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: """Delete a source by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: source = SourceModel.read(db_session=session, identifier=source_id) source.hard_delete(db_session=session, actor=actor) return source.to_pydantic() @@ -67,7 +63,7 @@ class SourceManager: @enforce_types def list_sources(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]: """List all sources with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: sources = SourceModel.list( db_session=session, after=after, @@ -85,7 +81,7 @@ class SourceManager: """ Get the total count of sources for the given user. """ - with self.session_maker() as session: + with db_registry.session() as session: return SourceModel.size(db_session=session, actor=actor) @enforce_types @@ -100,7 +96,7 @@ class SourceManager: Returns: List[PydanticAgentState]: List of agents that have this source attached """ - with self.session_maker() as session: + with db_registry.session() as session: # Verify source exists and user has permission to access it source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) @@ -112,7 +108,7 @@ class SourceManager: @enforce_types def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: """Retrieve a source by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: try: source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) return source.to_pydantic() @@ -122,7 +118,7 @@ class SourceManager: @enforce_types def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]: """Retrieve a source by its name.""" - with self.session_maker() as session: + with db_registry.session() as session: sources = SourceModel.list( db_session=session, name=source_name, @@ -141,7 +137,7 @@ class SourceManager: if db_file: return db_file else: - with self.session_maker() as session: + with db_registry.session() as session: file_metadata.organization_id = actor.organization_id file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True)) file_metadata.create(session, actor=actor) @@ -151,7 +147,7 @@ class SourceManager: @enforce_types def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]: """Retrieve a file by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: try: file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor) return file.to_pydantic() @@ -163,7 +159,7 @@ class SourceManager: self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> List[PydanticFileMetadata]: """List all files with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: files = FileMetadataModel.list( db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id ) @@ -172,7 +168,7 @@ class SourceManager: @enforce_types def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: """Delete a file by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: file = FileMetadataModel.read(db_session=session, identifier=file_id) file.hard_delete(db_session=session, actor=actor) return file.to_pydantic() diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index fc5ed3cf..cf34915d 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -11,17 +11,13 @@ from letta.orm.step import Step as StepModel from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.step import Step as PydanticStep from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.tracing import get_trace_id from letta.utils import enforce_types class StepManager: - def __init__(self): - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def list_steps( self, @@ -36,7 +32,7 @@ class StepManager: agent_id: Optional[str] = None, ) -> List[PydanticStep]: """List all jobs with optional pagination and status filter.""" - with self.session_maker() as session: + with db_registry.session() as session: filter_kwargs = {"organization_id": actor.organization_id} if model: filter_kwargs["model"] = model @@ -85,7 +81,7 @@ class StepManager: "tid": None, "trace_id": get_trace_id(), # Get the current trace ID } - with self.session_maker() as session: + with db_registry.session() as session: if job_id: self._verify_job_access(session, job_id, actor, access=["write"]) new_step = StepModel(**step_data) @@ -94,7 +90,7 @@ class StepManager: @enforce_types def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep: - with self.session_maker() as session: + with db_registry.session() as session: step = StepModel.read(db_session=session, identifier=step_id, actor=actor) return step.to_pydantic() @@ -113,7 +109,7 @@ class StepManager: Raises: NoResultFound: If the step does not exist """ - with self.session_maker() as session: + with db_registry.session() as session: step = session.get(StepModel, step_id) if not step: raise NoResultFound(f"Step with id {step_id} does not exist") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 6b877db4..5b0cff89 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -23,6 +23,7 @@ from letta.orm.tool import Tool as ToolModel from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry from letta.utils import enforce_types, printd logger = get_logger(__name__) @@ -31,12 +32,6 @@ logger = get_logger(__name__) class ToolManager: """Manager class to handle business logic related to Tools.""" - def __init__(self): - # Fetching the db_context similarly as in OrganizationManager - from letta.server.db import db_context - - self.session_maker = db_context - # TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object @enforce_types def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: @@ -89,7 +84,7 @@ class ToolManager: @enforce_types def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" - with self.session_maker() as session: + with db_registry.session() as session: # Set the organization id at the ORM layer pydantic_tool.organization_id = actor.organization_id # Auto-generate description if not provided @@ -104,7 +99,7 @@ class ToolManager: @enforce_types def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool: """Fetch a tool by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve tool by id using the Tool model's read method tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) # Convert the SQLAlchemy Tool object to PydanticTool @@ -114,7 +109,7 @@ class ToolManager: def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" try: - with self.session_maker() as session: + with db_registry.session() as session: tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) return tool.to_pydantic() except NoResultFound: @@ -124,7 +119,7 @@ class ToolManager: def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" try: - with self.session_maker() as session: + with db_registry.session() as session: tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) return tool.id except NoResultFound: @@ -133,7 +128,7 @@ class ToolManager: @enforce_types def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: tools = ToolModel.list( db_session=session, after=after, @@ -166,7 +161,7 @@ class ToolManager: If include_builtin is True, it will also count the built-in tools. """ - with self.session_maker() as session: + with db_registry.session() as session: if include_base_tools: return ToolModel.size(db_session=session, actor=actor) return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET) @@ -176,7 +171,7 @@ class ToolManager: self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None ) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" - with self.session_maker() as session: + with db_registry.session() as session: # Fetch the tool by ID tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) @@ -202,7 +197,7 @@ class ToolManager: @enforce_types def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: """Delete a tool by its ID.""" - with self.session_maker() as session: + with db_registry.session() as session: try: tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) tool.hard_delete(db_session=session, actor=actor) diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 50b5939f..9f6a72a5 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -5,6 +5,7 @@ from letta.orm.organization import Organization as OrganizationModel from letta.orm.user import User as UserModel from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate +from letta.server.db import db_registry from letta.services.organization_manager import OrganizationManager from letta.utils import enforce_types @@ -15,16 +16,10 @@ class UserManager: DEFAULT_USER_NAME = "default_user" DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000" - def __init__(self): - # Fetching the db_context similarly as in OrganizationManager - from letta.server.db import db_context - - self.session_maker = db_context - @enforce_types def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser: """Create the default user.""" - with self.session_maker() as session: + with db_registry.session() as session: # Make sure the org id exists try: OrganizationModel.read(db_session=session, identifier=org_id) @@ -44,7 +39,7 @@ class UserManager: @enforce_types def create_user(self, pydantic_user: PydanticUser) -> PydanticUser: """Create a new user if it doesn't already exist.""" - with self.session_maker() as session: + with db_registry.session() as session: new_user = UserModel(**pydantic_user.model_dump(to_orm=True)) new_user.create(session) return new_user.to_pydantic() @@ -52,7 +47,7 @@ class UserManager: @enforce_types def update_user(self, user_update: UserUpdate) -> PydanticUser: """Update user details.""" - with self.session_maker() as session: + with db_registry.session() as session: # Retrieve the existing user by ID existing_user = UserModel.read(db_session=session, identifier=user_update.id) @@ -68,7 +63,7 @@ class UserManager: @enforce_types def delete_user_by_id(self, user_id: str): """Delete a user and their associated records (agents, sources, mappings).""" - with self.session_maker() as session: + with db_registry.session() as session: # Delete from user table user = UserModel.read(db_session=session, identifier=user_id) user.hard_delete(session) @@ -78,7 +73,7 @@ class UserManager: @enforce_types def get_user_by_id(self, user_id: str) -> PydanticUser: """Fetch a user by ID.""" - with self.session_maker() as session: + with db_registry.session() as session: user = UserModel.read(db_session=session, identifier=user_id) return user.to_pydantic() @@ -104,7 +99,7 @@ class UserManager: @enforce_types def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: """List all users with optional pagination.""" - with self.session_maker() as session: + with db_registry.session() as session: users = UserModel.list( db_session=session, after=after, diff --git a/letta/settings.py b/letta/settings.py index 1daa2e88..d634b404 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -203,6 +203,7 @@ class Settings(BaseSettings): use_experimental: bool = False use_vertex_structured_outputs_experimental: bool = False use_vertex_async_loop_experimental: bool = False + experimental_enable_async_db_engine: bool = False # LLM provider client settings httpx_max_retries: int = 5 diff --git a/letta/utils.py b/letta/utils.py index fbb926b8..a23735b2 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -812,7 +812,7 @@ def printd(*args, **kwargs): print(*args, **kwargs) -def united_diff(str1, str2): +def united_diff(str1: str, str2: str) -> str: lines1 = str1.splitlines(True) lines2 = str2.splitlines(True) diff = difflib.unified_diff(lines1, lines2) diff --git a/poetry.lock b/poetry.lock index cdf6c460..2df13969 100644 --- a/poetry.lock +++ b/poetry.lock @@ -326,6 +326,73 @@ files = [ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "asyncpg" +version = "0.30.0" +description = "An asyncio PostgreSQL driver" +optional = false +python-versions = ">=3.8.0" +groups = ["main"] +files = [ + {file = "asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e"}, + {file = "asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0"}, + {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f"}, + {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af"}, + {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75"}, + {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f"}, + {file = "asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf"}, + {file = "asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50"}, + {file = "asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a"}, + {file = "asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed"}, + {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a"}, + {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956"}, + {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056"}, + {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454"}, + {file = "asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d"}, + {file = "asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f"}, + {file = "asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e"}, + {file = "asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a"}, + {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3"}, + {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737"}, + {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a"}, + {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af"}, + {file = "asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e"}, + {file = "asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305"}, + {file = "asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70"}, + {file = "asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3"}, + {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33"}, + {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4"}, + {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4"}, + {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba"}, + {file = "asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590"}, + {file = "asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e"}, + {file = "asyncpg-0.30.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:29ff1fc8b5bf724273782ff8b4f57b0f8220a1b2324184846b39d1ab4122031d"}, + {file = "asyncpg-0.30.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64e899bce0600871b55368b8483e5e3e7f1860c9482e7f12e0a771e747988168"}, + {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b290f4726a887f75dcd1b3006f484252db37602313f806e9ffc4e5996cfe5cb"}, + {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f86b0e2cd3f1249d6fe6fd6cfe0cd4538ba994e2d8249c0491925629b9104d0f"}, + {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:393af4e3214c8fa4c7b86da6364384c0d1b3298d45803375572f415b6f673f38"}, + {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fd4406d09208d5b4a14db9a9dbb311b6d7aeeab57bded7ed2f8ea41aeef39b34"}, + {file = "asyncpg-0.30.0-cp38-cp38-win32.whl", hash = "sha256:0b448f0150e1c3b96cb0438a0d0aa4871f1472e58de14a3ec320dbb2798fb0d4"}, + {file = "asyncpg-0.30.0-cp38-cp38-win_amd64.whl", hash = "sha256:f23b836dd90bea21104f69547923a02b167d999ce053f3d502081acea2fba15b"}, + {file = "asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad"}, + {file = "asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff"}, + {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708"}, + {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144"}, + {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb"}, + {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547"}, + {file = "asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a"}, + {file = "asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773"}, + {file = "asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.11.0\""} + +[package.extras] +docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"] +gssauth = ["gssapi ; platform_system != \"Windows\"", "sspilib ; platform_system == \"Windows\""] +test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi ; platform_system == \"Linux\"", "k5test ; platform_system == \"Linux\"", "mypy (>=1.8.0,<1.9.0)", "sspilib ; platform_system == \"Windows\"", "uvloop (>=0.15.3) ; platform_system != \"Windows\" and python_version < \"3.14.0\""] + [[package]] name = "attrs" version = "25.3.0" @@ -7503,4 +7570,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.1" python-versions = "<3.14,>=3.10" -content-hash = "f82fec7b3f35d4222c43b692db8cd005eaf8bcf6761bb202d0dbf64121c6b2ab" +content-hash = "862dc5a31d4385e89dc9a751cd171a611da3102c6832447a5f61926b25f03e06" diff --git a/pyproject.toml b/pyproject.toml index 66a987ca..e77dc3cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ firecrawl-py = "^1.15.0" apscheduler = "^3.11.0" aiomultiprocess = "^0.9.1" matplotlib = "^3.10.1" +asyncpg = "^0.30.0" [tool.poetry.extras] diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 0749b399..17dc8430 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -15,6 +15,7 @@ from letta.schemas.enums import JobStatus, ToolRuleType from letta.schemas.group import GroupUpdate, ManagerType, SleeptimeManagerUpdate from letta.schemas.message import MessageCreate from letta.schemas.run import Run +from letta.server.db import db_registry from letta.server.server import SyncServer from letta.utils import get_human_text, get_persona_text @@ -37,7 +38,7 @@ def org_id(server): yield org.id # cleanup - with server.organization_manager.session_maker() as session: + with db_registry.session() as session: session.execute(delete(Step)) session.execute(delete(Provider)) session.commit() diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index cbaa54dd..150922c4 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -15,6 +15,7 @@ from letta.schemas.group import ( SupervisorManager, ) from letta.schemas.message import MessageCreate +from letta.server.db import db_registry from letta.server.server import SyncServer @@ -36,7 +37,7 @@ def org_id(server): yield org.id # cleanup - with server.organization_manager.session_maker() as session: + with db_registry.session() as session: session.execute(delete(Step)) session.execute(delete(Provider)) session.commit() diff --git a/tests/test_server.py b/tests/test_server.py index b6440c42..a3932d81 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -19,6 +19,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.providers import ProviderCreate from letta.schemas.sandbox_config import SandboxType from letta.schemas.user import User +from letta.server.db import db_registry utils.DEBUG = True from letta.config import LettaConfig @@ -284,7 +285,7 @@ def org_id(server): yield org.id # cleanup - with server.organization_manager.session_maker() as session: + with db_registry.session() as session: session.execute(delete(Step)) session.execute(delete(Provider)) session.commit()