From 1be576a28e0d362d10a1354545f4e05fadd19563 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 22 Oct 2024 14:47:09 -0700 Subject: [PATCH] feat: Add ORM for organization model (#1914) --- alembic/env.py | 2 +- letta/agent_store/db.py | 6 +- letta/base.py | 3 - letta/constants.py | 5 +- letta/metadata.py | 60 +---- letta/orm/__all__.py | 0 letta/orm/__init__.py | 0 letta/orm/base.py | 75 ++++++ letta/orm/enums.py | 8 + letta/orm/errors.py | 2 + letta/orm/mixins.py | 40 ++++ letta/orm/organization.py | 35 +++ letta/orm/sqlalchemy_base.py | 214 ++++++++++++++++++ letta/schemas/organization.py | 6 +- .../rest_api/routers/v1/organizations.py | 9 +- letta/server/server.py | 35 +-- letta/services/__init__.py | 0 letta/services/organization_manager.py | 66 ++++++ poetry.lock | 13 +- pyproject.toml | 1 + tests/test_admin_client.py | 146 ------------ tests/test_client.py | 1 - tests/test_server.py | 62 ++++- 23 files changed, 541 insertions(+), 248 deletions(-) delete mode 100644 letta/base.py create mode 100644 letta/orm/__all__.py create mode 100644 letta/orm/__init__.py create mode 100644 letta/orm/base.py create mode 100644 letta/orm/enums.py create mode 100644 letta/orm/errors.py create mode 100644 letta/orm/mixins.py create mode 100644 letta/orm/organization.py create mode 100644 letta/orm/sqlalchemy_base.py create mode 100644 letta/services/__init__.py create mode 100644 letta/services/organization_manager.py delete mode 100644 tests/test_admin_client.py diff --git a/alembic/env.py b/alembic/env.py index f19996b1..3c084a82 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,8 +4,8 @@ from logging.config import fileConfig from sqlalchemy import engine_from_config, pool from alembic import context -from letta.base import Base from letta.config import LettaConfig +from letta.orm.base import Base from letta.settings import settings letta_config = LettaConfig.load() diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 5e4fc5ae..840c03ce 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -25,10 +25,10 @@ from sqlalchemy_json import MutableJson from tqdm import tqdm from letta.agent_store.storage import StorageConnector, TableType -from letta.base import Base from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn +from letta.orm.base import Base # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.message import Message @@ -509,8 +509,10 @@ class SQLLiteStorageConnector(SQLStorageConnector): self.session_maker = db_context + # Need this in order to allow UUIDs to be stored successfully in the sqlite database # import sqlite3 - + # import uuid + # # sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le) # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b)) diff --git a/letta/base.py b/letta/base.py deleted file mode 100644 index 860e5425..00000000 --- a/letta/base.py +++ /dev/null @@ -1,3 +0,0 @@ -from sqlalchemy.ext.declarative import declarative_base - -Base = declarative_base() diff --git a/letta/constants.py b/letta/constants.py index 9db6b7bb..f317a0e1 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -5,7 +5,10 @@ LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") # Defaults DEFAULT_USER_ID = "user-00000000" -DEFAULT_ORG_ID = "org-00000000" +# This UUID follows the UUID4 rules: +# The 13th character (4) indicates it's version 4. +# The first character of the third segment (8) ensures the variant is correctly set. +DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000" DEFAULT_USER_NAME = "default" DEFAULT_ORG_NAME = "default" diff --git a/letta/metadata.py b/letta/metadata.py index 1d36d216..328a829d 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -20,8 +20,8 @@ from sqlalchemy import ( ) from sqlalchemy.sql import func -from letta.base import Base from letta.config import LettaConfig +from letta.orm.base import Base from letta.schemas.agent import AgentState from letta.schemas.api_key import APIKey from letta.schemas.block import Block, Human, Persona @@ -34,7 +34,6 @@ from letta.schemas.memory import Memory # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction -from letta.schemas.organization import Organization from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.user import User @@ -174,21 +173,6 @@ class UserModel(Base): return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id) -class OrganizationModel(Base): - __tablename__ = "organizations" - __table_args__ = {"extend_existing": True} - - id = Column(String, primary_key=True) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True)) - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Organization: - return Organization(id=self.id, name=self.name, created_at=self.created_at) - - # TODO: eventually store providers? # class Provider(Base): # __tablename__ = "providers" @@ -551,14 +535,6 @@ class MetadataStore: session.add(UserModel(**vars(user))) session.commit() - @enforce_types - def create_organization(self, organization: Organization): - with self.session_maker() as session: - if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0: - raise ValueError(f"Organization with id {organization.id} already exists") - session.add(OrganizationModel(**vars(organization))) - session.commit() - @enforce_types def create_block(self, block: Block): with self.session_maker() as session: @@ -698,16 +674,6 @@ class MetadataStore: session.commit() - @enforce_types - def delete_organization(self, org_id: str): - with self.session_maker() as session: - # delete from organizations table - session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete() - - # TODO: delete associated data - - session.commit() - @enforce_types def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]: with self.session_maker() as session: @@ -762,30 +728,6 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" return results[0].to_record() - @enforce_types - def get_organization(self, org_id: str) -> Optional[Organization]: - with self.session_maker() as session: - results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - @enforce_types - def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50): - with self.session_maker() as session: - query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id)) - if cursor: - query = query.filter(OrganizationModel.id < cursor) - results = query.limit(limit).all() - if not results: - return None, [] - organization_records = [r.to_record() for r in results] - next_cursor = organization_records[-1].id - assert isinstance(next_cursor, str) - - return next_cursor, organization_records - @enforce_types def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50): with self.session_maker() as session: diff --git a/letta/orm/__all__.py b/letta/orm/__all__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/orm/base.py b/letta/orm/base.py new file mode 100644 index 00000000..61f7575d --- /dev/null +++ b/letta/orm/base.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from sqlalchemy import UUID as SQLUUID +from sqlalchemy import Boolean, DateTime, func, text +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + declarative_mixin, + declared_attr, + mapped_column, +) + + +class Base(DeclarativeBase): + """absolute base for sqlalchemy classes""" + + +@declarative_mixin +class CommonSqlalchemyMetaMixins(Base): + __abstract__ = True + + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now()) + is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE")) + + @declared_attr + def _created_by_id(cls): + return cls._user_by_id() + + @declared_attr + def _last_updated_by_id(cls): + return cls._user_by_id() + + @classmethod + def _user_by_id(cls): + """a flexible non-constrained record of a user. + This way users can get added, deleted etc without history freaking out + """ + return mapped_column(SQLUUID(), nullable=True) + + @property + def last_updated_by_id(self) -> Optional[str]: + return self._user_id_getter("last_updated") + + @last_updated_by_id.setter + def last_updated_by_id(self, value: str) -> None: + self._user_id_setter("last_updated", value) + + @property + def created_by_id(self) -> Optional[str]: + return self._user_id_getter("created") + + @created_by_id.setter + def created_by_id(self, value: str) -> None: + self._user_id_setter("created", value) + + def _user_id_getter(self, prop: str) -> Optional[str]: + """returns the user id for the specified property""" + full_prop = f"_{prop}_by_id" + prop_value = getattr(self, full_prop, None) + if not prop_value: + return + return f"user-{prop_value}" + + def _user_id_setter(self, prop: str, value: str) -> None: + """returns the user id for the specified property""" + full_prop = f"_{prop}_by_id" + if not value: + setattr(self, full_prop, None) + return + prefix, id_ = value.split("-", 1) + assert prefix == "user", f"{prefix} is not a valid id prefix for a user id" + setattr(self, full_prop, UUID(id_)) diff --git a/letta/orm/enums.py b/letta/orm/enums.py new file mode 100644 index 00000000..c9a7b060 --- /dev/null +++ b/letta/orm/enums.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class ToolSourceType(str, Enum): + """Defines what a tool was derived from""" + + python = "python" + json = "json" diff --git a/letta/orm/errors.py b/letta/orm/errors.py new file mode 100644 index 00000000..d1bcf4ab --- /dev/null +++ b/letta/orm/errors.py @@ -0,0 +1,2 @@ +class NoResultFound(Exception): + """A record or records cannot be found given the provided search params""" diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py new file mode 100644 index 00000000..c6348957 --- /dev/null +++ b/letta/orm/mixins.py @@ -0,0 +1,40 @@ +from typing import Optional, Type +from uuid import UUID + +from letta.orm.base import Base + + +class MalformedIdError(Exception): + pass + + +def _relation_getter(instance: "Base", prop: str) -> Optional[str]: + prefix = prop.replace("_", "") + formatted_prop = f"_{prop}_id" + try: + uuid_ = getattr(instance, formatted_prop) + return f"{prefix}-{uuid_}" + except AttributeError: + return None + + +def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None: + formatted_prop = f"_{prop}_id" + prefix = prop.replace("_", "") + if not value: + setattr(instance, formatted_prop, None) + return + try: + found_prefix, id_ = value.split("-", 1) + except ValueError as e: + raise MalformedIdError(f"{value} is not a valid ID.") from e + assert ( + # TODO: should be able to get this from the Mapped typing, not sure how though + # prefix = getattr(?, "prefix") + found_prefix + == prefix + ), f"{found_prefix} is not a valid id prefix, expecting {prefix}" + try: + setattr(instance, formatted_prop, UUID(id_)) + except ValueError as e: + raise MalformedIdError("Hash segment of {value} is not a valid UUID") from e diff --git a/letta/orm/organization.py b/letta/orm/organization.py new file mode 100644 index 00000000..394cb436 --- /dev/null +++ b/letta/orm/organization.py @@ -0,0 +1,35 @@ +from typing import TYPE_CHECKING + +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.organization import Organization as PydanticOrganization + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + +class Organization(SqlalchemyBase): + """The highest level of the object tree. All Entities belong to one and only one Organization.""" + + __tablename__ = "organizations" + __pydantic_model__ = PydanticOrganization + + name: Mapped[str] = mapped_column(doc="The display name of the organization.") + + # TODO: Map these relationships later when we actually make these models + # below is just a suggestion + # users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") + # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") + # sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") + # tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") + # documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan") + + @classmethod + def default(cls, db_session: "Session") -> "Organization": + """Get the default org, or create it if it doesn't exist.""" + try: + return db_session.query(cls).filter(cls.name == "Default Organization").one() + except NoResultFound: + return cls(name="Default Organization").create(db_session) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py new file mode 100644 index 00000000..557593e1 --- /dev/null +++ b/letta/orm/sqlalchemy_base.py @@ -0,0 +1,214 @@ +from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union +from uuid import UUID, uuid4 + +from humps import depascalize +from sqlalchemy import Boolean, String, select +from sqlalchemy.orm import Mapped, mapped_column + +from letta.log import get_logger +from letta.orm.base import Base, CommonSqlalchemyMetaMixins +from letta.orm.errors import NoResultFound + +if TYPE_CHECKING: + from pydantic import BaseModel + from sqlalchemy.orm import Session + + # from letta.orm.user import User + +logger = get_logger(__name__) + + +class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): + __abstract__ = True + + __order_by_default__ = "created_at" + + _id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid4()}") + + deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.") + + @classmethod + def __prefix__(cls) -> str: + return depascalize(cls.__name__) + + @property + def id(self) -> Optional[str]: + if self._id: + return f"{self.__prefix__()}-{self._id}" + + @id.setter + def id(self, value: str) -> None: + if not value: + return + prefix, id_ = value.split("-", 1) + assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}" + assert SqlalchemyBase.is_valid_uuid4(id_), f"{id_} is not a valid uuid4" + self._id = id_ + + @classmethod + def list( + cls, *, db_session: "Session", cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs + ) -> List[Type["SqlalchemyBase"]]: + """List records with optional cursor (for pagination) and limit.""" + with db_session as session: + # Start with the base query filtered by kwargs + query = select(cls).filter_by(**kwargs) + + # Add a cursor condition if provided + if cursor: + cursor_uuid = cls.get_uid_from_identifier(cursor) # Assuming the cursor is an _id value + query = query.where(cls._id > cursor_uuid) + + # Add a limit to the query if provided + query = query.order_by(cls._id).limit(limit) + + # Handle soft deletes if the class has the 'is_deleted' attribute + if hasattr(cls, "is_deleted"): + query = query.where(cls.is_deleted == False) + + # Execute the query and return the results as a list of model instances + return list(session.execute(query).scalars()) + + @classmethod + def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] = False) -> str: + """converts the id into a uuid object + Args: + identifier: the string identifier, such as `organization-xxxx-xx...` + indifferent: if True, will not enforce the prefix check + """ + try: + uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "") + assert SqlalchemyBase.is_valid_uuid4(uuid_string) + return uuid_string + except ValueError as e: + raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e + + @classmethod + def is_valid_uuid4(cls, uuid_string: str) -> bool: + try: + # Try to create a UUID object from the string + uuid_obj = UUID(uuid_string) + # Check if the UUID is version 4 + return uuid_obj.version == 4 + except ValueError: + # Raised if the string is not a valid UUID + return False + + @classmethod + def read( + cls, + db_session: "Session", + identifier: Union[str, UUID], + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + **kwargs, + ) -> Type["SqlalchemyBase"]: + """The primary accessor for an ORM record. + Args: + db_session: the database session to use when retrieving the record + identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility + actor: if specified, results will be scoped only to records the user is able to access + access: if actor is specified, records will be filtered to the minimum permission level for the actor + kwargs: additional arguments to pass to the read, used for more complex objects + Returns: + The matching object + Raises: + NoResultFound: if the object is not found + """ + del kwargs # arity for more complex reads + identifier = cls.get_uid_from_identifier(identifier) + query = select(cls).where(cls._id == identifier) + # if actor: + # query = cls.apply_access_predicate(query, actor, access) + if hasattr(cls, "is_deleted"): + query = query.where(cls.is_deleted == False) + if found := db_session.execute(query).scalar(): + return found + raise NoResultFound(f"{cls.__name__} with id {identifier} not found") + + def create(self, db_session: "Session") -> Type["SqlalchemyBase"]: + # self._infer_organization(db_session) + + with db_session as session: + session.add(self) + session.commit() + session.refresh(self) + return self + + def delete(self, db_session: "Session") -> Type["SqlalchemyBase"]: + self.is_deleted = True + return self.update(db_session) + + def update(self, db_session: "Session") -> Type["SqlalchemyBase"]: + with db_session as session: + session.add(self) + session.commit() + session.refresh(self) + return self + + @classmethod + def read_or_create(cls, *, db_session: "Session", **kwargs) -> Type["SqlalchemyBase"]: + """get an instance by search criteria or create it if it doesn't exist""" + try: + return cls.read(db_session=db_session, identifier=kwargs.get("id", None)) + except NoResultFound: + clean_kwargs = {k: v for k, v in kwargs.items() if k in cls.__table__.columns} + return cls(**clean_kwargs).create(db_session=db_session) + + # TODO: Add back later when access predicates are actually important + # The idea behind this is that you can add a WHERE clause restricting the actions you can take, e.g. R/W + # @classmethod + # def apply_access_predicate( + # cls, + # query: "Select", + # actor: "User", + # access: List[Literal["read", "write", "admin"]], + # ) -> "Select": + # """applies a WHERE clause restricting results to the given actor and access level + # Args: + # query: The initial sqlalchemy select statement + # actor: The user acting on the query. **Note**: this is called 'actor' to identify the + # person or system acting. Users can act on users, making naming very sticky otherwise. + # access: + # what mode of access should the query restrict to? This will be used with granular permissions, + # but because of how it will impact every query we want to be explicitly calling access ahead of time. + # Returns: + # the sqlalchemy select statement restricted to the given access. + # """ + # del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment + # org_uid = getattr(actor, "_organization_id", getattr(actor.organization, "_id", None)) + # if not org_uid: + # raise ValueError("object %s has no organization accessor", actor) + # return query.where(cls._organization_id == org_uid, cls.is_deleted == False) + + @property + def __pydantic_model__(self) -> Type["BaseModel"]: + raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.") + + def to_pydantic(self) -> Type["BaseModel"]: + """converts to the basic pydantic model counterpart""" + return self.__pydantic_model__.model_validate(self) + + def to_record(self) -> Type["BaseModel"]: + """Deprecated accessor for to_pydantic""" + logger.warning("to_record is deprecated, use to_pydantic instead.") + return self.to_pydantic() + + # TODO: Look into this later and maybe add back? + # def _infer_organization(self, db_session: "Session") -> None: + # """🪄 MAGIC ALERT! 🪄 + # Because so much of the original API is centered around user scopes, + # this allows us to continue with that scope and then infer the org from the creating user. + # + # IF a created_by_id is set, we will use that to infer the organization and magic set it at create time! + # If not do nothing to the object. Mutates in place. + # """ + # if self.created_by_id and hasattr(self, "_organization_id"): + # try: + # from letta.orm.user import User # to avoid circular import + # + # created_by = User.read(db_session, self.created_by_id) + # except NoResultFound: + # logger.warning(f"User {self.created_by_id} not found, unable to infer organization.") + # return + # self._organization_id = created_by._organization_id diff --git a/letta/schemas/organization.py b/letta/schemas/organization.py index 8d9b7da5..cc969e15 100644 --- a/letta/schemas/organization.py +++ b/letta/schemas/organization.py @@ -7,13 +7,13 @@ from letta.schemas.letta_base import LettaBase class OrganizationBase(LettaBase): - __id_prefix__ = "org" + __id_prefix__ = "organization" class Organization(OrganizationBase): - id: str = OrganizationBase.generate_id_field() + id: str = Field(..., description="The id of the organization.") name: str = Field(..., description="The name of the organization.") - created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") + created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the organization.") class OrganizationCreate(OrganizationBase): diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index 29dddbd3..efe9882a 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -22,7 +22,7 @@ def get_all_orgs( Get a list of all orgs in the database """ try: - next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit) + next_cursor, orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit) except HTTPException: raise except Exception as e: @@ -38,8 +38,7 @@ def create_org( """ Create a new org in the database """ - - org = server.create_organization(request) + org = server.organization_manager.create_organization(request) return org @@ -50,10 +49,10 @@ def delete_org( ): # TODO make a soft deletion, instead of a hard deletion try: - org = server.ms.get_organization(org_id=org_id) + org = server.organization_manager.get_organization_by_id(org_id=org_id) if org is None: raise HTTPException(status_code=404, detail=f"Organization does not exist") - server.ms.delete_organization(org_id=org_id) + server.organization_manager.delete_organization(org_id=org_id) except HTTPException: raise except Exception as e: diff --git a/letta/server/server.py b/letta/server/server.py index 283f55db..7bddb188 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -44,6 +44,7 @@ from letta.log import get_logger from letta.memory import get_memory_functions from letta.metadata import Base, MetadataStore from letta.o1_agent import O1Agent +from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.providers import ( AnthropicProvider, @@ -80,12 +81,12 @@ from letta.schemas.memory import ( RecallMemorySummary, ) from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage -from letta.schemas.organization import Organization, OrganizationCreate from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User, UserCreate +from letta.services.organization_manager import OrganizationManager from letta.utils import create_random_username, json_dumps, json_loads # from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin @@ -245,6 +246,9 @@ class SyncServer(Server): self.config = config self.ms = MetadataStore(self.config) + # Managers that interface with data models + self.organization_manager = OrganizationManager() + # TODO: this should be removed # add global default tools (for admin) self.add_default_tools(module_name="base") @@ -773,20 +777,6 @@ class SyncServer(Server): return user - def create_organization(self, request: OrganizationCreate) -> Organization: - """Create a new org using a config""" - if not request.name: - # auto-generate a name - request.name = create_random_username() - org = Organization(name=request.name) - self.ms.create_organization(org) - logger.info(f"Created new org from config: {org}") - - # add default for the org - # TODO: add default data - - return org - def create_agent( self, request: CreateAgent, @@ -2125,18 +2115,13 @@ class SyncServer(Server): def get_default_user(self) -> User: - from letta.constants import ( - DEFAULT_ORG_ID, - DEFAULT_ORG_NAME, - DEFAULT_USER_ID, - DEFAULT_USER_NAME, - ) + from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME # check if default org exists - default_org = self.ms.get_organization(DEFAULT_ORG_ID) - if not default_org: - org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID) - self.ms.create_organization(org) + try: + self.organization_manager.get_organization_by_id(DEFAULT_ORG_ID) + except NoResultFound: + self.organization_manager.create_default_organization() # check if default user exists try: diff --git a/letta/services/__init__.py b/letta/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py new file mode 100644 index 00000000..e6976849 --- /dev/null +++ b/letta/services/organization_manager.py @@ -0,0 +1,66 @@ +from typing import List, Optional + +from sqlalchemy.exc import NoResultFound + +from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME +from letta.orm.organization import Organization +from letta.schemas.organization import Organization as PydanticOrganization +from letta.utils import create_random_username + + +class OrganizationManager: + """Manager class to handle business logic related to Organizations.""" + + def __init__(self): + # This is probably horrible but we reuse this technique from metadata.py + # TODO: Please refactor this out + # I am currently working on a ORM refactor and would like to make a more minimal set of changes + # - Matt + from letta.server.server import db_context + + self.session_maker = db_context + + def get_organization_by_id(self, org_id: str) -> PydanticOrganization: + """Fetch an organization by ID.""" + with self.session_maker() as session: + try: + organization = Organization.read(db_session=session, identifier=org_id) + return organization.to_pydantic() + except NoResultFound: + raise ValueError(f"Organization with id {org_id} not found.") + + def create_organization(self, name: Optional[str] = None) -> PydanticOrganization: + """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated.""" + with self.session_maker() as session: + org = Organization(name=name if name else create_random_username()) + org.create(session) + return org.to_pydantic() + + def create_default_organization(self) -> PydanticOrganization: + """Create the default organization.""" + with self.session_maker() as session: + org = Organization(name=DEFAULT_ORG_NAME) + org.id = DEFAULT_ORG_ID + org.create(session) + return org.to_pydantic() + + def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: + """Update an organization.""" + with self.session_maker() as session: + organization = Organization.read(db_session=session, identifier=org_id) + if name: + organization.name = name + organization.update(session) + return organization.to_pydantic() + + def delete_organization(self, org_id: str): + """Delete an organization by marking it as deleted.""" + with self.session_maker() as session: + organization = Organization.read(db_session=session, identifier=org_id) + organization.delete(session) + + def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: + """List organizations with pagination based on cursor (org_id) and limit.""" + with self.session_maker() as session: + results = Organization.list(db_session=session, cursor=cursor, limit=limit) + return [org.to_pydantic() for org in results] diff --git a/poetry.lock b/poetry.lock index 0038c033..8827a702 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5742,6 +5742,17 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyhumps" +version = "3.8.0" +description = "🐫 Convert strings (and dictionary keys) between snake case, camel case and pascal case in Python. Inspired by Humps for Node" +optional = false +python-versions = "*" +files = [ + {file = "pyhumps-3.8.0-py3-none-any.whl", hash = "sha256:060e1954d9069f428232a1adda165db0b9d8dfdce1d265d36df7fbff540acfd6"}, + {file = "pyhumps-3.8.0.tar.gz", hash = "sha256:498026258f7ee1a8e447c2e28526c0bea9407f9a59c03260aee4bd6c04d681a3"}, +] + [[package]] name = "pylance" version = "0.9.18" @@ -8423,4 +8434,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "357ad0382673050758dd4f98ba71d574cdebea385eefc9481b9c8bab743eafd3" +content-hash = "5c05bb8ee0f17e149be1482f6295fb2dcac41d8a23a27b890a81d2e9fa30b4e8" diff --git a/pyproject.toml b/pyproject.toml index f3d69bf9..b77de021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ langchain-community = {version = "^0.2.17", optional = true} composio-langchain = "^0.5.28" composio-core = "^0.5.34" alembic = "^1.13.3" +pyhumps = "^3.8.0" [tool.poetry.extras] #local = ["llama-index-embeddings-huggingface"] diff --git a/tests/test_admin_client.py b/tests/test_admin_client.py deleted file mode 100644 index 11aef20a..00000000 --- a/tests/test_admin_client.py +++ /dev/null @@ -1,146 +0,0 @@ -import threading -import time - -import pytest - -from letta import Admin - -test_base_url = "http://localhost:8283" - -# admin credentials -test_server_token = "test_server_token" - - -def run_server(): - from letta.server.rest_api.app import start_server - - print("Starting server...") - start_server(debug=True) - - -@pytest.fixture(scope="session", autouse=True) -def start_uvicorn_server(): - """Starts Uvicorn server in a background thread.""" - - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - print("Starting server...") - time.sleep(5) - yield - - -@pytest.fixture(scope="module") -def admin_client(): - # Setup: Create a user via the client before the tests - - admin = Admin(test_base_url, test_server_token) - admin._reset_server() - yield admin - - -@pytest.fixture(scope="module") -def organization(admin_client): - # create an organization - org_name = "test_org" - org = admin_client.create_organization(org_name) - assert org_name == org.name, f"Expected {org_name}, got {org.name}" - - # test listing - orgs = admin_client.get_organizations() - assert len(orgs) > 0, f"Expected 1 org, got {orgs}" - - yield org - admin_client.delete_organization(org.id) - - -def test_admin_client(admin_client, organization): - - # create a user - user_name = "test_user" - user1 = admin_client.create_user(user_name, organization.id) - assert user_name == user1.name, f"Expected {user_name}, got {user1.name}" - - # create another user - user2 = admin_client.create_user() - - # create keys - key1_name = "test_key1" - key2_name = "test_key2" - api_key1 = admin_client.create_key(user1.id, key1_name) - admin_client.create_key(user2.id, key2_name) - - # list users - users = admin_client.get_users() - assert len(users) == 2 - assert user1.id in [user.id for user in users] - assert user2.id in [user.id for user in users] - - # list keys - user1_keys = admin_client.get_keys(user1.id) - assert len(user1_keys) == 1, f"Expected 1 keys, got {user1_keys}" - assert api_key1.key == user1_keys[0].key - - # delete key - deleted_key1 = admin_client.delete_key(api_key1.key) - assert deleted_key1.key == api_key1.key - assert len(admin_client.get_keys(user1.id)) == 0 - - # delete users - deleted_user1 = admin_client.delete_user(user1.id) - assert deleted_user1.id == user1.id - deleted_user2 = admin_client.delete_user(user2.id) - assert deleted_user2.id == user2.id - - # list users - users = admin_client.get_users() - assert len(users) == 0, f"Expected 0 users, got {users}" - - -# def test_get_users_pagination(admin_client): -# -# page_size = 5 -# num_users = 7 -# expected_users_remainder = num_users - page_size -# -# # create users -# all_user_ids = [] -# for i in range(num_users): -# -# user_id = uuid.uuid4() -# all_user_ids.append(user_id) -# key_name = "test_key" + f"{i}" -# -# create_user_response = admin_client.create_user(user_id) -# admin_client.create_key(create_user_response.user_id, key_name) -# -# # list users in page 1 -# get_all_users_response1 = admin_client.get_users(limit=page_size) -# cursor1 = get_all_users_response1.cursor -# user_list1 = get_all_users_response1.user_list -# assert len(user_list1) == page_size -# -# # list users in page 2 using cursor -# get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size) -# cursor2 = get_all_users_response2.cursor -# user_list2 = get_all_users_response2.user_list -# -# assert len(user_list2) == expected_users_remainder -# assert cursor1 != cursor2 -# -# # delete users -# clean_up_users_and_keys(all_user_ids) -# -# # list users to check pagination with no users -# users = admin_client.get_users() -# assert len(users.user_list) == 0, f"Expected 0 users, got {users}" - - -def clean_up_users_and_keys(user_id_list): - admin_client = Admin(test_base_url, test_server_token) - - # clean up all keys and users - for user_id in user_id_list: - keys_list = admin_client.get_keys(user_id) - for key in keys_list: - admin_client.delete_key(key) - admin_client.delete_user(user_id) diff --git a/tests/test_client.py b/tests/test_client.py index 0a5e5620..8ede7b7f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -523,7 +523,6 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat def test_organization(client: RESTClient): if isinstance(client, LocalClient): pytest.skip("Skipping test_organization because LocalClient does not support organizations") - client.base_url def test_model_configs(client: Union[LocalClient, RESTClient]): diff --git a/tests/test_server.py b/tests/test_server.py index 9285b25e..69f55bac 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,9 +3,17 @@ import uuid import warnings import pytest +from sqlalchemy import delete import letta.utils as utils -from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import ( + BASE_TOOLS, + DEFAULT_MESSAGE_TOOL, + DEFAULT_MESSAGE_TOOL_KWARG, + DEFAULT_ORG_ID, + DEFAULT_ORG_NAME, +) +from letta.orm.organization import Organization from letta.schemas.enums import MessageRole utils.DEBUG = True @@ -31,6 +39,14 @@ from letta.server.server import SyncServer from .utils import DummyDataConnector +@pytest.fixture(autouse=True) +def clear_organization_table(server: SyncServer): + """Fixture to clear the organization table before each test.""" + with server.organization_manager.session_maker() as session: + session.execute(delete(Organization)) # Clear all records from the organization table + session.commit() # Commit the deletion + + @pytest.fixture(scope="module") def server(): # if os.getenv("OPENAI_API_KEY"): @@ -547,3 +563,47 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: + overview.num_tokens_functions_definitions + overview.num_tokens_external_memory_summary ) + + +def test_list_organizations(server: SyncServer): + # Create a new org and confirm that it is created correctly + org_name = "test" + org = server.organization_manager.create_organization(name=org_name) + + orgs = server.organization_manager.list_organizations() + assert len(orgs) == 1 + assert orgs[0].name == org_name + + # Delete it after + server.organization_manager.delete_organization(org.id) + assert len(server.organization_manager.list_organizations()) == 0 + + +def test_create_default_organization(server: SyncServer): + server.organization_manager.create_default_organization() + retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID) + assert retrieved.name == DEFAULT_ORG_NAME + + +def test_update_organization_name(server: SyncServer): + org_name_a = "a" + org_name_b = "b" + org = server.organization_manager.create_organization(name=org_name_a) + assert org.name == org_name_a + org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b) + assert org.name == org_name_b + + +def test_list_organizations_pagination(server: SyncServer): + server.organization_manager.create_organization(name="a") + server.organization_manager.create_organization(name="b") + + orgs_x = server.organization_manager.list_organizations(limit=1) + assert len(orgs_x) == 1 + + orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1) + assert len(orgs_y) == 1 + assert orgs_y[0].name != orgs_x[0].name + + orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1) + assert len(orgs) == 0