diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 500fdd19..1ae32249 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union from asyncpg.exceptions import DeadlockDetectedError, QueryCanceledError from sqlalchemy import Sequence, String, and_, delete, func, or_, select +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, Session, mapped_column @@ -544,12 +545,31 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False, - ) -> "SqlalchemyBase": - """Async version of create function""" + ignore_conflicts: bool = False, + ) -> Optional["SqlalchemyBase"]: + """Async version of create function + + Args: + ignore_conflicts: If True, uses INSERT ... ON CONFLICT DO NOTHING and returns + None if a conflict occurred (no exception raised). + """ logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}") if actor: self._set_created_and_updated_by_fields(actor.id) + + if ignore_conflicts: + values = { + col.name: getattr(self, col.key) + for col in self.__table__.columns + if not (getattr(self, col.key) is None and col.server_default is not None) + } + stmt = pg_insert(self.__table__).values(**values).on_conflict_do_nothing() + result = await db_session.execute(stmt) + if not no_commit: + await db_session.commit() + return self if result.rowcount > 0 else None + for attempt in range(_DEADLOCK_MAX_RETRIES): try: db_session.add(self) diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 52b46495..2f66d140 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -723,19 +723,12 @@ class ProviderManager: f"org_id={pydantic_model.organization_id}" ) - # Convert to ORM model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True)) - try: - await model.create_async(session) - logger.info(f" ✓ Successfully created LLM model {llm_config.handle} with ID {model.id}") - except Exception as e: - logger.info(f" ✗ Failed to create LLM model {llm_config.handle}: {e}") - # Log the full error details - import traceback - - logger.info(f" Full traceback: {traceback.format_exc()}") - # Roll back the session to clear the failed transaction - await session.rollback() + result = await model.create_async(session, ignore_conflicts=True) + if result: + logger.info(f" ✓ Successfully created LLM model {llm_config.handle}") + else: + logger.info(f" LLM model {llm_config.handle} already exists (concurrent insert), skipping") else: # Check if max_context_window or model_endpoint_type needs to be updated existing_model = existing[0] @@ -813,19 +806,12 @@ class ProviderManager: f"org_id={pydantic_model.organization_id}" ) - # Convert to ORM model = ProviderModelORM(**pydantic_model.model_dump(to_orm=True)) - try: - await model.create_async(session) - logger.info(f" ✓ Successfully created embedding model {embedding_config.handle} with ID {model.id}") - except Exception as e: - logger.error(f" ✗ Failed to create embedding model {embedding_config.handle}: {e}") - # Log the full error details - import traceback - - logger.error(f" Full traceback: {traceback.format_exc()}") - # Roll back the session to clear the failed transaction - await session.rollback() + result = await model.create_async(session, ignore_conflicts=True) + if result: + logger.info(f" ✓ Successfully created embedding model {embedding_config.handle}") + else: + logger.info(f" Embedding model {embedding_config.handle} already exists (concurrent insert), skipping") else: # Check if model_endpoint_type needs to be updated existing_model = existing[0]