feat: Add ORM for organization model (#1914)
This commit is contained in:
@@ -4,8 +4,8 @@ from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
from letta.base import Base
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm.base import Base
|
||||
from letta.settings import settings
|
||||
|
||||
letta_config = LettaConfig.load()
|
||||
|
||||
@@ -25,10 +25,10 @@ from sqlalchemy_json import MutableJson
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.base import Base
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
|
||||
from letta.orm.base import Base
|
||||
|
||||
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
|
||||
from letta.schemas.message import Message
|
||||
@@ -509,8 +509,10 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
# Need this in order to allow UUIDs to be stored successfully in the sqlite database
|
||||
# import sqlite3
|
||||
|
||||
# import uuid
|
||||
#
|
||||
# sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
|
||||
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
@@ -5,7 +5,10 @@ LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
|
||||
|
||||
# Defaults
|
||||
DEFAULT_USER_ID = "user-00000000"
|
||||
DEFAULT_ORG_ID = "org-00000000"
|
||||
# This UUID follows the UUID4 rules:
|
||||
# The 13th character (4) indicates it's version 4.
|
||||
# The first character of the third segment (8) ensures the variant is correctly set.
|
||||
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_USER_NAME = "default"
|
||||
DEFAULT_ORG_NAME = "default"
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from letta.base import Base
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm.base import Base
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.api_key import APIKey
|
||||
from letta.schemas.block import Block, Human, Persona
|
||||
@@ -34,7 +34,6 @@ from letta.schemas.memory import Memory
|
||||
|
||||
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
@@ -174,21 +173,6 @@ class UserModel(Base):
|
||||
return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)
|
||||
|
||||
|
||||
class OrganizationModel(Base):
|
||||
__tablename__ = "organizations"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Organization(id='{self.id}' name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> Organization:
|
||||
return Organization(id=self.id, name=self.name, created_at=self.created_at)
|
||||
|
||||
|
||||
# TODO: eventually store providers?
|
||||
# class Provider(Base):
|
||||
# __tablename__ = "providers"
|
||||
@@ -551,14 +535,6 @@ class MetadataStore:
|
||||
session.add(UserModel(**vars(user)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_organization(self, organization: Organization):
|
||||
with self.session_maker() as session:
|
||||
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
|
||||
raise ValueError(f"Organization with id {organization.id} already exists")
|
||||
session.add(OrganizationModel(**vars(organization)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_block(self, block: Block):
|
||||
with self.session_maker() as session:
|
||||
@@ -698,16 +674,6 @@ class MetadataStore:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_organization(self, org_id: str):
|
||||
with self.session_maker() as session:
|
||||
# delete from organizations table
|
||||
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()
|
||||
|
||||
# TODO: delete associated data
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
@@ -762,30 +728,6 @@ class MetadataStore:
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_organization(self, org_id: str) -> Optional[Organization]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
||||
with self.session_maker() as session:
|
||||
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
|
||||
if cursor:
|
||||
query = query.filter(OrganizationModel.id < cursor)
|
||||
results = query.limit(limit).all()
|
||||
if not results:
|
||||
return None, []
|
||||
organization_records = [r.to_record() for r in results]
|
||||
next_cursor = organization_records[-1].id
|
||||
assert isinstance(next_cursor, str)
|
||||
|
||||
return next_cursor, organization_records
|
||||
|
||||
@enforce_types
|
||||
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
||||
with self.session_maker() as session:
|
||||
|
||||
0
letta/orm/__all__.py
Normal file
0
letta/orm/__all__.py
Normal file
0
letta/orm/__init__.py
Normal file
0
letta/orm/__init__.py
Normal file
75
letta/orm/base.py
Normal file
75
letta/orm/base.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import UUID as SQLUUID
|
||||
from sqlalchemy import Boolean, DateTime, func, text
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
declarative_mixin,
|
||||
declared_attr,
|
||||
mapped_column,
|
||||
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""absolute base for sqlalchemy classes"""
|
||||
|
||||
|
||||
@declarative_mixin
|
||||
class CommonSqlalchemyMetaMixins(Base):
|
||||
__abstract__ = True
|
||||
|
||||
created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now())
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE"))
|
||||
|
||||
@declared_attr
|
||||
def _created_by_id(cls):
|
||||
return cls._user_by_id()
|
||||
|
||||
@declared_attr
|
||||
def _last_updated_by_id(cls):
|
||||
return cls._user_by_id()
|
||||
|
||||
@classmethod
|
||||
def _user_by_id(cls):
|
||||
"""a flexible non-constrained record of a user.
|
||||
This way users can get added, deleted etc without history freaking out
|
||||
"""
|
||||
return mapped_column(SQLUUID(), nullable=True)
|
||||
|
||||
@property
|
||||
def last_updated_by_id(self) -> Optional[str]:
|
||||
return self._user_id_getter("last_updated")
|
||||
|
||||
@last_updated_by_id.setter
|
||||
def last_updated_by_id(self, value: str) -> None:
|
||||
self._user_id_setter("last_updated", value)
|
||||
|
||||
@property
|
||||
def created_by_id(self) -> Optional[str]:
|
||||
return self._user_id_getter("created")
|
||||
|
||||
@created_by_id.setter
|
||||
def created_by_id(self, value: str) -> None:
|
||||
self._user_id_setter("created", value)
|
||||
|
||||
def _user_id_getter(self, prop: str) -> Optional[str]:
|
||||
"""returns the user id for the specified property"""
|
||||
full_prop = f"_{prop}_by_id"
|
||||
prop_value = getattr(self, full_prop, None)
|
||||
if not prop_value:
|
||||
return
|
||||
return f"user-{prop_value}"
|
||||
|
||||
def _user_id_setter(self, prop: str, value: str) -> None:
|
||||
"""returns the user id for the specified property"""
|
||||
full_prop = f"_{prop}_by_id"
|
||||
if not value:
|
||||
setattr(self, full_prop, None)
|
||||
return
|
||||
prefix, id_ = value.split("-", 1)
|
||||
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
|
||||
setattr(self, full_prop, UUID(id_))
|
||||
8
letta/orm/enums.py
Normal file
8
letta/orm/enums.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ToolSourceType(str, Enum):
|
||||
"""Defines what a tool was derived from"""
|
||||
|
||||
python = "python"
|
||||
json = "json"
|
||||
2
letta/orm/errors.py
Normal file
2
letta/orm/errors.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class NoResultFound(Exception):
|
||||
"""A record or records cannot be found given the provided search params"""
|
||||
40
letta/orm/mixins.py
Normal file
40
letta/orm/mixins.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class MalformedIdError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _relation_getter(instance: "Base", prop: str) -> Optional[str]:
|
||||
prefix = prop.replace("_", "")
|
||||
formatted_prop = f"_{prop}_id"
|
||||
try:
|
||||
uuid_ = getattr(instance, formatted_prop)
|
||||
return f"{prefix}-{uuid_}"
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None:
|
||||
formatted_prop = f"_{prop}_id"
|
||||
prefix = prop.replace("_", "")
|
||||
if not value:
|
||||
setattr(instance, formatted_prop, None)
|
||||
return
|
||||
try:
|
||||
found_prefix, id_ = value.split("-", 1)
|
||||
except ValueError as e:
|
||||
raise MalformedIdError(f"{value} is not a valid ID.") from e
|
||||
assert (
|
||||
# TODO: should be able to get this from the Mapped typing, not sure how though
|
||||
# prefix = getattr(?, "prefix")
|
||||
found_prefix
|
||||
== prefix
|
||||
), f"{found_prefix} is not a valid id prefix, expecting {prefix}"
|
||||
try:
|
||||
setattr(instance, formatted_prop, UUID(id_))
|
||||
except ValueError as e:
|
||||
raise MalformedIdError("Hash segment of {value} is not a valid UUID") from e
|
||||
35
letta/orm/organization.py
Normal file
35
letta/orm/organization.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class Organization(SqlalchemyBase):
|
||||
"""The highest level of the object tree. All Entities belong to one and only one Organization."""
|
||||
|
||||
__tablename__ = "organizations"
|
||||
__pydantic_model__ = PydanticOrganization
|
||||
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
||||
|
||||
# TODO: Map these relationships later when we actually make these models
|
||||
# below is just a suggestion
|
||||
# users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
# sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
|
||||
# tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
# documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@classmethod
|
||||
def default(cls, db_session: "Session") -> "Organization":
|
||||
"""Get the default org, or create it if it doesn't exist."""
|
||||
try:
|
||||
return db_session.query(cls).filter(cls.name == "Default Organization").one()
|
||||
except NoResultFound:
|
||||
return cls(name="Default Organization").create(db_session)
|
||||
214
letta/orm/sqlalchemy_base.py
Normal file
214
letta/orm/sqlalchemy_base.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from humps import depascalize
|
||||
from sqlalchemy import Boolean, String, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import NoResultFound
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# from letta.orm.user import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
__abstract__ = True
|
||||
|
||||
__order_by_default__ = "created_at"
|
||||
|
||||
_id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid4()}")
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.")
|
||||
|
||||
@classmethod
|
||||
def __prefix__(cls) -> str:
|
||||
return depascalize(cls.__name__)
|
||||
|
||||
@property
|
||||
def id(self) -> Optional[str]:
|
||||
if self._id:
|
||||
return f"{self.__prefix__()}-{self._id}"
|
||||
|
||||
@id.setter
|
||||
def id(self, value: str) -> None:
|
||||
if not value:
|
||||
return
|
||||
prefix, id_ = value.split("-", 1)
|
||||
assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}"
|
||||
assert SqlalchemyBase.is_valid_uuid4(id_), f"{id_} is not a valid uuid4"
|
||||
self._id = id_
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
cls, *, db_session: "Session", cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs
|
||||
) -> List[Type["SqlalchemyBase"]]:
|
||||
"""List records with optional cursor (for pagination) and limit."""
|
||||
with db_session as session:
|
||||
# Start with the base query filtered by kwargs
|
||||
query = select(cls).filter_by(**kwargs)
|
||||
|
||||
# Add a cursor condition if provided
|
||||
if cursor:
|
||||
cursor_uuid = cls.get_uid_from_identifier(cursor) # Assuming the cursor is an _id value
|
||||
query = query.where(cls._id > cursor_uuid)
|
||||
|
||||
# Add a limit to the query if provided
|
||||
query = query.order_by(cls._id).limit(limit)
|
||||
|
||||
# Handle soft deletes if the class has the 'is_deleted' attribute
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
|
||||
# Execute the query and return the results as a list of model instances
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@classmethod
|
||||
def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] = False) -> str:
|
||||
"""converts the id into a uuid object
|
||||
Args:
|
||||
identifier: the string identifier, such as `organization-xxxx-xx...`
|
||||
indifferent: if True, will not enforce the prefix check
|
||||
"""
|
||||
try:
|
||||
uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "")
|
||||
assert SqlalchemyBase.is_valid_uuid4(uuid_string)
|
||||
return uuid_string
|
||||
except ValueError as e:
|
||||
raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e
|
||||
|
||||
@classmethod
|
||||
def is_valid_uuid4(cls, uuid_string: str) -> bool:
|
||||
try:
|
||||
# Try to create a UUID object from the string
|
||||
uuid_obj = UUID(uuid_string)
|
||||
# Check if the UUID is version 4
|
||||
return uuid_obj.version == 4
|
||||
except ValueError:
|
||||
# Raised if the string is not a valid UUID
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def read(
|
||||
cls,
|
||||
db_session: "Session",
|
||||
identifier: Union[str, UUID],
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
**kwargs,
|
||||
) -> Type["SqlalchemyBase"]:
|
||||
"""The primary accessor for an ORM record.
|
||||
Args:
|
||||
db_session: the database session to use when retrieving the record
|
||||
identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility
|
||||
actor: if specified, results will be scoped only to records the user is able to access
|
||||
access: if actor is specified, records will be filtered to the minimum permission level for the actor
|
||||
kwargs: additional arguments to pass to the read, used for more complex objects
|
||||
Returns:
|
||||
The matching object
|
||||
Raises:
|
||||
NoResultFound: if the object is not found
|
||||
"""
|
||||
del kwargs # arity for more complex reads
|
||||
identifier = cls.get_uid_from_identifier(identifier)
|
||||
query = select(cls).where(cls._id == identifier)
|
||||
# if actor:
|
||||
# query = cls.apply_access_predicate(query, actor, access)
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
if found := db_session.execute(query).scalar():
|
||||
return found
|
||||
raise NoResultFound(f"{cls.__name__} with id {identifier} not found")
|
||||
|
||||
def create(self, db_session: "Session") -> Type["SqlalchemyBase"]:
|
||||
# self._infer_organization(db_session)
|
||||
|
||||
with db_session as session:
|
||||
session.add(self)
|
||||
session.commit()
|
||||
session.refresh(self)
|
||||
return self
|
||||
|
||||
def delete(self, db_session: "Session") -> Type["SqlalchemyBase"]:
|
||||
self.is_deleted = True
|
||||
return self.update(db_session)
|
||||
|
||||
def update(self, db_session: "Session") -> Type["SqlalchemyBase"]:
|
||||
with db_session as session:
|
||||
session.add(self)
|
||||
session.commit()
|
||||
session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def read_or_create(cls, *, db_session: "Session", **kwargs) -> Type["SqlalchemyBase"]:
|
||||
"""get an instance by search criteria or create it if it doesn't exist"""
|
||||
try:
|
||||
return cls.read(db_session=db_session, identifier=kwargs.get("id", None))
|
||||
except NoResultFound:
|
||||
clean_kwargs = {k: v for k, v in kwargs.items() if k in cls.__table__.columns}
|
||||
return cls(**clean_kwargs).create(db_session=db_session)
|
||||
|
||||
# TODO: Add back later when access predicates are actually important
|
||||
# The idea behind this is that you can add a WHERE clause restricting the actions you can take, e.g. R/W
|
||||
# @classmethod
|
||||
# def apply_access_predicate(
|
||||
# cls,
|
||||
# query: "Select",
|
||||
# actor: "User",
|
||||
# access: List[Literal["read", "write", "admin"]],
|
||||
# ) -> "Select":
|
||||
# """applies a WHERE clause restricting results to the given actor and access level
|
||||
# Args:
|
||||
# query: The initial sqlalchemy select statement
|
||||
# actor: The user acting on the query. **Note**: this is called 'actor' to identify the
|
||||
# person or system acting. Users can act on users, making naming very sticky otherwise.
|
||||
# access:
|
||||
# what mode of access should the query restrict to? This will be used with granular permissions,
|
||||
# but because of how it will impact every query we want to be explicitly calling access ahead of time.
|
||||
# Returns:
|
||||
# the sqlalchemy select statement restricted to the given access.
|
||||
# """
|
||||
# del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
|
||||
# org_uid = getattr(actor, "_organization_id", getattr(actor.organization, "_id", None))
|
||||
# if not org_uid:
|
||||
# raise ValueError("object %s has no organization accessor", actor)
|
||||
# return query.where(cls._organization_id == org_uid, cls.is_deleted == False)
|
||||
|
||||
@property
|
||||
def __pydantic_model__(self) -> Type["BaseModel"]:
|
||||
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
|
||||
|
||||
def to_pydantic(self) -> Type["BaseModel"]:
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
return self.__pydantic_model__.model_validate(self)
|
||||
|
||||
def to_record(self) -> Type["BaseModel"]:
|
||||
"""Deprecated accessor for to_pydantic"""
|
||||
logger.warning("to_record is deprecated, use to_pydantic instead.")
|
||||
return self.to_pydantic()
|
||||
|
||||
# TODO: Look into this later and maybe add back?
|
||||
# def _infer_organization(self, db_session: "Session") -> None:
|
||||
# """🪄 MAGIC ALERT! 🪄
|
||||
# Because so much of the original API is centered around user scopes,
|
||||
# this allows us to continue with that scope and then infer the org from the creating user.
|
||||
#
|
||||
# IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
|
||||
# If not do nothing to the object. Mutates in place.
|
||||
# """
|
||||
# if self.created_by_id and hasattr(self, "_organization_id"):
|
||||
# try:
|
||||
# from letta.orm.user import User # to avoid circular import
|
||||
#
|
||||
# created_by = User.read(db_session, self.created_by_id)
|
||||
# except NoResultFound:
|
||||
# logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
|
||||
# return
|
||||
# self._organization_id = created_by._organization_id
|
||||
@@ -7,13 +7,13 @@ from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class OrganizationBase(LettaBase):
|
||||
__id_prefix__ = "org"
|
||||
__id_prefix__ = "organization"
|
||||
|
||||
|
||||
class Organization(OrganizationBase):
|
||||
id: str = OrganizationBase.generate_id_field()
|
||||
id: str = Field(..., description="The id of the organization.")
|
||||
name: str = Field(..., description="The name of the organization.")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the organization.")
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_all_orgs(
|
||||
Get a list of all orgs in the database
|
||||
"""
|
||||
try:
|
||||
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit)
|
||||
next_cursor, orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -38,8 +38,7 @@ def create_org(
|
||||
"""
|
||||
Create a new org in the database
|
||||
"""
|
||||
|
||||
org = server.create_organization(request)
|
||||
org = server.organization_manager.create_organization(request)
|
||||
return org
|
||||
|
||||
|
||||
@@ -50,10 +49,10 @@ def delete_org(
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
org = server.ms.get_organization(org_id=org_id)
|
||||
org = server.organization_manager.get_organization_by_id(org_id=org_id)
|
||||
if org is None:
|
||||
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||
server.ms.delete_organization(org_id=org_id)
|
||||
server.organization_manager.delete_organization(org_id=org_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -44,6 +44,7 @@ from letta.log import get_logger
|
||||
from letta.memory import get_memory_functions
|
||||
from letta.metadata import Base, MetadataStore
|
||||
from letta.o1_agent import O1Agent
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.providers import (
|
||||
AnthropicProvider,
|
||||
@@ -80,12 +81,12 @@ from letta.schemas.memory import (
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
||||
from letta.schemas.organization import Organization, OrganizationCreate
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User, UserCreate
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.utils import create_random_username, json_dumps, json_loads
|
||||
|
||||
# from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
@@ -245,6 +246,9 @@ class SyncServer(Server):
|
||||
self.config = config
|
||||
self.ms = MetadataStore(self.config)
|
||||
|
||||
# Managers that interface with data models
|
||||
self.organization_manager = OrganizationManager()
|
||||
|
||||
# TODO: this should be removed
|
||||
# add global default tools (for admin)
|
||||
self.add_default_tools(module_name="base")
|
||||
@@ -773,20 +777,6 @@ class SyncServer(Server):
|
||||
|
||||
return user
|
||||
|
||||
def create_organization(self, request: OrganizationCreate) -> Organization:
|
||||
"""Create a new org using a config"""
|
||||
if not request.name:
|
||||
# auto-generate a name
|
||||
request.name = create_random_username()
|
||||
org = Organization(name=request.name)
|
||||
self.ms.create_organization(org)
|
||||
logger.info(f"Created new org from config: {org}")
|
||||
|
||||
# add default for the org
|
||||
# TODO: add default data
|
||||
|
||||
return org
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
@@ -2125,18 +2115,13 @@ class SyncServer(Server):
|
||||
|
||||
def get_default_user(self) -> User:
|
||||
|
||||
from letta.constants import (
|
||||
DEFAULT_ORG_ID,
|
||||
DEFAULT_ORG_NAME,
|
||||
DEFAULT_USER_ID,
|
||||
DEFAULT_USER_NAME,
|
||||
)
|
||||
from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME
|
||||
|
||||
# check if default org exists
|
||||
default_org = self.ms.get_organization(DEFAULT_ORG_ID)
|
||||
if not default_org:
|
||||
org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
|
||||
self.ms.create_organization(org)
|
||||
try:
|
||||
self.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
|
||||
except NoResultFound:
|
||||
self.organization_manager.create_default_organization()
|
||||
|
||||
# check if default user exists
|
||||
try:
|
||||
|
||||
0
letta/services/__init__.py
Normal file
0
letta/services/__init__.py
Normal file
66
letta/services/organization_manager.py
Normal file
66
letta/services/organization_manager.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
|
||||
from letta.orm.organization import Organization
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.utils import create_random_username
|
||||
|
||||
|
||||
class OrganizationManager:
|
||||
"""Manager class to handle business logic related to Organizations."""
|
||||
|
||||
def __init__(self):
|
||||
# This is probably horrible but we reuse this technique from metadata.py
|
||||
# TODO: Please refactor this out
|
||||
# I am currently working on a ORM refactor and would like to make a more minimal set of changes
|
||||
# - Matt
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
def get_organization_by_id(self, org_id: str) -> PydanticOrganization:
|
||||
"""Fetch an organization by ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
organization = Organization.read(db_session=session, identifier=org_id)
|
||||
return organization.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Organization with id {org_id} not found.")
|
||||
|
||||
def create_organization(self, name: Optional[str] = None) -> PydanticOrganization:
|
||||
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
|
||||
with self.session_maker() as session:
|
||||
org = Organization(name=name if name else create_random_username())
|
||||
org.create(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
def create_default_organization(self) -> PydanticOrganization:
|
||||
"""Create the default organization."""
|
||||
with self.session_maker() as session:
|
||||
org = Organization(name=DEFAULT_ORG_NAME)
|
||||
org.id = DEFAULT_ORG_ID
|
||||
org.create(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
|
||||
"""Update an organization."""
|
||||
with self.session_maker() as session:
|
||||
organization = Organization.read(db_session=session, identifier=org_id)
|
||||
if name:
|
||||
organization.name = name
|
||||
organization.update(session)
|
||||
return organization.to_pydantic()
|
||||
|
||||
def delete_organization(self, org_id: str):
|
||||
"""Delete an organization by marking it as deleted."""
|
||||
with self.session_maker() as session:
|
||||
organization = Organization.read(db_session=session, identifier=org_id)
|
||||
organization.delete(session)
|
||||
|
||||
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
|
||||
"""List organizations with pagination based on cursor (org_id) and limit."""
|
||||
with self.session_maker() as session:
|
||||
results = Organization.list(db_session=session, cursor=cursor, limit=limit)
|
||||
return [org.to_pydantic() for org in results]
|
||||
13
poetry.lock
generated
13
poetry.lock
generated
@@ -5742,6 +5742,17 @@ files = [
|
||||
[package.extras]
|
||||
windows-terminal = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyhumps"
|
||||
version = "3.8.0"
|
||||
description = "🐫 Convert strings (and dictionary keys) between snake case, camel case and pascal case in Python. Inspired by Humps for Node"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pyhumps-3.8.0-py3-none-any.whl", hash = "sha256:060e1954d9069f428232a1adda165db0b9d8dfdce1d265d36df7fbff540acfd6"},
|
||||
{file = "pyhumps-3.8.0.tar.gz", hash = "sha256:498026258f7ee1a8e447c2e28526c0bea9407f9a59c03260aee4bd6c04d681a3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pylance"
|
||||
version = "0.9.18"
|
||||
@@ -8423,4 +8434,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "357ad0382673050758dd4f98ba71d574cdebea385eefc9481b9c8bab743eafd3"
|
||||
content-hash = "5c05bb8ee0f17e149be1482f6295fb2dcac41d8a23a27b890a81d2e9fa30b4e8"
|
||||
|
||||
@@ -77,6 +77,7 @@ langchain-community = {version = "^0.2.17", optional = true}
|
||||
composio-langchain = "^0.5.28"
|
||||
composio-core = "^0.5.34"
|
||||
alembic = "^1.13.3"
|
||||
pyhumps = "^3.8.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
#local = ["llama-index-embeddings-huggingface"]
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from letta import Admin
|
||||
|
||||
test_base_url = "http://localhost:8283"
|
||||
|
||||
# admin credentials
|
||||
test_server_token = "test_server_token"
|
||||
|
||||
|
||||
def run_server():
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def start_uvicorn_server():
|
||||
"""Starts Uvicorn server in a background thread."""
|
||||
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
print("Starting server...")
|
||||
time.sleep(5)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def admin_client():
|
||||
# Setup: Create a user via the client before the tests
|
||||
|
||||
admin = Admin(test_base_url, test_server_token)
|
||||
admin._reset_server()
|
||||
yield admin
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def organization(admin_client):
|
||||
# create an organization
|
||||
org_name = "test_org"
|
||||
org = admin_client.create_organization(org_name)
|
||||
assert org_name == org.name, f"Expected {org_name}, got {org.name}"
|
||||
|
||||
# test listing
|
||||
orgs = admin_client.get_organizations()
|
||||
assert len(orgs) > 0, f"Expected 1 org, got {orgs}"
|
||||
|
||||
yield org
|
||||
admin_client.delete_organization(org.id)
|
||||
|
||||
|
||||
def test_admin_client(admin_client, organization):
|
||||
|
||||
# create a user
|
||||
user_name = "test_user"
|
||||
user1 = admin_client.create_user(user_name, organization.id)
|
||||
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
|
||||
|
||||
# create another user
|
||||
user2 = admin_client.create_user()
|
||||
|
||||
# create keys
|
||||
key1_name = "test_key1"
|
||||
key2_name = "test_key2"
|
||||
api_key1 = admin_client.create_key(user1.id, key1_name)
|
||||
admin_client.create_key(user2.id, key2_name)
|
||||
|
||||
# list users
|
||||
users = admin_client.get_users()
|
||||
assert len(users) == 2
|
||||
assert user1.id in [user.id for user in users]
|
||||
assert user2.id in [user.id for user in users]
|
||||
|
||||
# list keys
|
||||
user1_keys = admin_client.get_keys(user1.id)
|
||||
assert len(user1_keys) == 1, f"Expected 1 keys, got {user1_keys}"
|
||||
assert api_key1.key == user1_keys[0].key
|
||||
|
||||
# delete key
|
||||
deleted_key1 = admin_client.delete_key(api_key1.key)
|
||||
assert deleted_key1.key == api_key1.key
|
||||
assert len(admin_client.get_keys(user1.id)) == 0
|
||||
|
||||
# delete users
|
||||
deleted_user1 = admin_client.delete_user(user1.id)
|
||||
assert deleted_user1.id == user1.id
|
||||
deleted_user2 = admin_client.delete_user(user2.id)
|
||||
assert deleted_user2.id == user2.id
|
||||
|
||||
# list users
|
||||
users = admin_client.get_users()
|
||||
assert len(users) == 0, f"Expected 0 users, got {users}"
|
||||
|
||||
|
||||
# def test_get_users_pagination(admin_client):
|
||||
#
|
||||
# page_size = 5
|
||||
# num_users = 7
|
||||
# expected_users_remainder = num_users - page_size
|
||||
#
|
||||
# # create users
|
||||
# all_user_ids = []
|
||||
# for i in range(num_users):
|
||||
#
|
||||
# user_id = uuid.uuid4()
|
||||
# all_user_ids.append(user_id)
|
||||
# key_name = "test_key" + f"{i}"
|
||||
#
|
||||
# create_user_response = admin_client.create_user(user_id)
|
||||
# admin_client.create_key(create_user_response.user_id, key_name)
|
||||
#
|
||||
# # list users in page 1
|
||||
# get_all_users_response1 = admin_client.get_users(limit=page_size)
|
||||
# cursor1 = get_all_users_response1.cursor
|
||||
# user_list1 = get_all_users_response1.user_list
|
||||
# assert len(user_list1) == page_size
|
||||
#
|
||||
# # list users in page 2 using cursor
|
||||
# get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
|
||||
# cursor2 = get_all_users_response2.cursor
|
||||
# user_list2 = get_all_users_response2.user_list
|
||||
#
|
||||
# assert len(user_list2) == expected_users_remainder
|
||||
# assert cursor1 != cursor2
|
||||
#
|
||||
# # delete users
|
||||
# clean_up_users_and_keys(all_user_ids)
|
||||
#
|
||||
# # list users to check pagination with no users
|
||||
# users = admin_client.get_users()
|
||||
# assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
||||
|
||||
|
||||
def clean_up_users_and_keys(user_id_list):
|
||||
admin_client = Admin(test_base_url, test_server_token)
|
||||
|
||||
# clean up all keys and users
|
||||
for user_id in user_id_list:
|
||||
keys_list = admin_client.get_keys(user_id)
|
||||
for key in keys_list:
|
||||
admin_client.delete_key(key)
|
||||
admin_client.delete_user(user_id)
|
||||
@@ -523,7 +523,6 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
|
||||
def test_organization(client: RESTClient):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||
client.base_url
|
||||
|
||||
|
||||
def test_model_configs(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
@@ -3,9 +3,17 @@ import uuid
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
DEFAULT_MESSAGE_TOOL,
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
DEFAULT_ORG_ID,
|
||||
DEFAULT_ORG_NAME,
|
||||
)
|
||||
from letta.orm.organization import Organization
|
||||
from letta.schemas.enums import MessageRole
|
||||
|
||||
utils.DEBUG = True
|
||||
@@ -31,6 +39,14 @@ from letta.server.server import SyncServer
|
||||
from .utils import DummyDataConnector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_organization_table(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||
session.commit() # Commit the deletion
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
@@ -547,3 +563,47 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
|
||||
+ overview.num_tokens_functions_definitions
|
||||
+ overview.num_tokens_external_memory_summary
|
||||
)
|
||||
|
||||
|
||||
def test_list_organizations(server: SyncServer):
|
||||
# Create a new org and confirm that it is created correctly
|
||||
org_name = "test"
|
||||
org = server.organization_manager.create_organization(name=org_name)
|
||||
|
||||
orgs = server.organization_manager.list_organizations()
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == org_name
|
||||
|
||||
# Delete it after
|
||||
server.organization_manager.delete_organization(org.id)
|
||||
assert len(server.organization_manager.list_organizations()) == 0
|
||||
|
||||
|
||||
def test_create_default_organization(server: SyncServer):
|
||||
server.organization_manager.create_default_organization()
|
||||
retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
|
||||
assert retrieved.name == DEFAULT_ORG_NAME
|
||||
|
||||
|
||||
def test_update_organization_name(server: SyncServer):
|
||||
org_name_a = "a"
|
||||
org_name_b = "b"
|
||||
org = server.organization_manager.create_organization(name=org_name_a)
|
||||
assert org.name == org_name_a
|
||||
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
|
||||
assert org.name == org_name_b
|
||||
|
||||
|
||||
def test_list_organizations_pagination(server: SyncServer):
|
||||
server.organization_manager.create_organization(name="a")
|
||||
server.organization_manager.create_organization(name="b")
|
||||
|
||||
orgs_x = server.organization_manager.list_organizations(limit=1)
|
||||
assert len(orgs_x) == 1
|
||||
|
||||
orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1)
|
||||
assert len(orgs_y) == 1
|
||||
assert orgs_y[0].name != orgs_x[0].name
|
||||
|
||||
orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1)
|
||||
assert len(orgs) == 0
|
||||
|
||||
Reference in New Issue
Block a user