feat: Add ORM for organization model (#1914)

This commit is contained in:
Matthew Zhou
2024-10-22 14:47:09 -07:00
committed by GitHub
parent a2e1cfd9e5
commit 1be576a28e
23 changed files with 541 additions and 248 deletions

View File

@@ -4,8 +4,8 @@ from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool from sqlalchemy import engine_from_config, pool
from alembic import context from alembic import context
from letta.base import Base
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.orm.base import Base
from letta.settings import settings from letta.settings import settings
letta_config = LettaConfig.load() letta_config = LettaConfig.load()

View File

@@ -25,10 +25,10 @@ from sqlalchemy_json import MutableJson
from tqdm import tqdm from tqdm import tqdm
from letta.agent_store.storage import StorageConnector, TableType from letta.agent_store.storage import StorageConnector, TableType
from letta.base import Base
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
from letta.orm.base import Base
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message from letta.schemas.message import Message
@@ -509,8 +509,10 @@ class SQLLiteStorageConnector(SQLStorageConnector):
self.session_maker = db_context self.session_maker = db_context
# Need this in order to allow UUIDs to be stored successfully in the sqlite database
# import sqlite3 # import sqlite3
# import uuid
#
# sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le) # sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b)) # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

View File

@@ -1,3 +0,0 @@
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()

View File

@@ -5,7 +5,10 @@ LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
# Defaults # Defaults
DEFAULT_USER_ID = "user-00000000" DEFAULT_USER_ID = "user-00000000"
DEFAULT_ORG_ID = "org-00000000" # This UUID follows the UUID4 rules:
# The 13th character (4) indicates it's version 4.
# The first character of the third segment (8) ensures the variant is correctly set.
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
DEFAULT_USER_NAME = "default" DEFAULT_USER_NAME = "default"
DEFAULT_ORG_NAME = "default" DEFAULT_ORG_NAME = "default"

View File

@@ -20,8 +20,8 @@ from sqlalchemy import (
) )
from sqlalchemy.sql import func from sqlalchemy.sql import func
from letta.base import Base
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.orm.base import Base
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.api_key import APIKey from letta.schemas.api_key import APIKey
from letta.schemas.block import Block, Human, Persona from letta.schemas.block import Block, Human, Persona
@@ -34,7 +34,6 @@ from letta.schemas.memory import Memory
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.organization import Organization
from letta.schemas.source import Source from letta.schemas.source import Source
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
from letta.schemas.user import User from letta.schemas.user import User
@@ -174,21 +173,6 @@ class UserModel(Base):
return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id) return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)
class OrganizationModel(Base):
__tablename__ = "organizations"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))
def __repr__(self) -> str:
return f"<Organization(id='{self.id}' name='{self.name}')>"
def to_record(self) -> Organization:
return Organization(id=self.id, name=self.name, created_at=self.created_at)
# TODO: eventually store providers? # TODO: eventually store providers?
# class Provider(Base): # class Provider(Base):
# __tablename__ = "providers" # __tablename__ = "providers"
@@ -551,14 +535,6 @@ class MetadataStore:
session.add(UserModel(**vars(user))) session.add(UserModel(**vars(user)))
session.commit() session.commit()
@enforce_types
def create_organization(self, organization: Organization):
with self.session_maker() as session:
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
raise ValueError(f"Organization with id {organization.id} already exists")
session.add(OrganizationModel(**vars(organization)))
session.commit()
@enforce_types @enforce_types
def create_block(self, block: Block): def create_block(self, block: Block):
with self.session_maker() as session: with self.session_maker() as session:
@@ -698,16 +674,6 @@ class MetadataStore:
session.commit() session.commit()
@enforce_types
def delete_organization(self, org_id: str):
with self.session_maker() as session:
# delete from organizations table
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()
# TODO: delete associated data
session.commit()
@enforce_types @enforce_types
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]: def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
with self.session_maker() as session: with self.session_maker() as session:
@@ -762,30 +728,6 @@ class MetadataStore:
assert len(results) == 1, f"Expected 1 result, got {len(results)}" assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record() return results[0].to_record()
@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
with self.session_maker() as session:
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
if cursor:
query = query.filter(OrganizationModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
organization_records = [r.to_record() for r in results]
next_cursor = organization_records[-1].id
assert isinstance(next_cursor, str)
return next_cursor, organization_records
@enforce_types @enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50): def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session: with self.session_maker() as session:

0
letta/orm/__all__.py Normal file
View File

0
letta/orm/__init__.py Normal file
View File

75
letta/orm/base.py Normal file
View File

@@ -0,0 +1,75 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
from sqlalchemy import UUID as SQLUUID
from sqlalchemy import Boolean, DateTime, func, text
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
declarative_mixin,
declared_attr,
mapped_column,
)
class Base(DeclarativeBase):
"""absolute base for sqlalchemy classes"""
@declarative_mixin
class CommonSqlalchemyMetaMixins(Base):
__abstract__ = True
created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE"))
@declared_attr
def _created_by_id(cls):
return cls._user_by_id()
@declared_attr
def _last_updated_by_id(cls):
return cls._user_by_id()
@classmethod
def _user_by_id(cls):
"""a flexible non-constrained record of a user.
This way users can get added, deleted etc without history freaking out
"""
return mapped_column(SQLUUID(), nullable=True)
@property
def last_updated_by_id(self) -> Optional[str]:
return self._user_id_getter("last_updated")
@last_updated_by_id.setter
def last_updated_by_id(self, value: str) -> None:
self._user_id_setter("last_updated", value)
@property
def created_by_id(self) -> Optional[str]:
return self._user_id_getter("created")
@created_by_id.setter
def created_by_id(self, value: str) -> None:
self._user_id_setter("created", value)
def _user_id_getter(self, prop: str) -> Optional[str]:
"""returns the user id for the specified property"""
full_prop = f"_{prop}_by_id"
prop_value = getattr(self, full_prop, None)
if not prop_value:
return
return f"user-{prop_value}"
def _user_id_setter(self, prop: str, value: str) -> None:
"""returns the user id for the specified property"""
full_prop = f"_{prop}_by_id"
if not value:
setattr(self, full_prop, None)
return
prefix, id_ = value.split("-", 1)
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
setattr(self, full_prop, UUID(id_))

8
letta/orm/enums.py Normal file
View File

@@ -0,0 +1,8 @@
from enum import Enum
class ToolSourceType(str, Enum):
"""Defines what a tool was derived from"""
python = "python"
json = "json"

2
letta/orm/errors.py Normal file
View File

@@ -0,0 +1,2 @@
class NoResultFound(Exception):
"""A record or records cannot be found given the provided search params"""

40
letta/orm/mixins.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Optional, Type
from uuid import UUID
from letta.orm.base import Base
class MalformedIdError(Exception):
pass
def _relation_getter(instance: "Base", prop: str) -> Optional[str]:
prefix = prop.replace("_", "")
formatted_prop = f"_{prop}_id"
try:
uuid_ = getattr(instance, formatted_prop)
return f"{prefix}-{uuid_}"
except AttributeError:
return None
def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None:
formatted_prop = f"_{prop}_id"
prefix = prop.replace("_", "")
if not value:
setattr(instance, formatted_prop, None)
return
try:
found_prefix, id_ = value.split("-", 1)
except ValueError as e:
raise MalformedIdError(f"{value} is not a valid ID.") from e
assert (
# TODO: should be able to get this from the Mapped typing, not sure how though
# prefix = getattr(?, "prefix")
found_prefix
== prefix
), f"{found_prefix} is not a valid id prefix, expecting {prefix}"
try:
setattr(instance, formatted_prop, UUID(id_))
except ValueError as e:
raise MalformedIdError("Hash segment of {value} is not a valid UUID") from e

35
letta/orm/organization.py Normal file
View File

@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.organization import Organization as PydanticOrganization
if TYPE_CHECKING:
from sqlalchemy.orm import Session
class Organization(SqlalchemyBase):
"""The highest level of the object tree. All Entities belong to one and only one Organization."""
__tablename__ = "organizations"
__pydantic_model__ = PydanticOrganization
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
# TODO: Map these relationships later when we actually make these models
# below is just a suggestion
# users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
# sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
# tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
# documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")
@classmethod
def default(cls, db_session: "Session") -> "Organization":
"""Get the default org, or create it if it doesn't exist."""
try:
return db_session.query(cls).filter(cls.name == "Default Organization").one()
except NoResultFound:
return cls(name="Default Organization").create(db_session)

View File

@@ -0,0 +1,214 @@
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
from uuid import UUID, uuid4
from humps import depascalize
from sqlalchemy import Boolean, String, select
from sqlalchemy.orm import Mapped, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import NoResultFound
if TYPE_CHECKING:
from pydantic import BaseModel
from sqlalchemy.orm import Session
# from letta.orm.user import User
logger = get_logger(__name__)
class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
__abstract__ = True
__order_by_default__ = "created_at"
_id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid4()}")
deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.")
@classmethod
def __prefix__(cls) -> str:
return depascalize(cls.__name__)
@property
def id(self) -> Optional[str]:
if self._id:
return f"{self.__prefix__()}-{self._id}"
@id.setter
def id(self, value: str) -> None:
if not value:
return
prefix, id_ = value.split("-", 1)
assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}"
assert SqlalchemyBase.is_valid_uuid4(id_), f"{id_} is not a valid uuid4"
self._id = id_
@classmethod
def list(
cls, *, db_session: "Session", cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs
) -> List[Type["SqlalchemyBase"]]:
"""List records with optional cursor (for pagination) and limit."""
with db_session as session:
# Start with the base query filtered by kwargs
query = select(cls).filter_by(**kwargs)
# Add a cursor condition if provided
if cursor:
cursor_uuid = cls.get_uid_from_identifier(cursor) # Assuming the cursor is an _id value
query = query.where(cls._id > cursor_uuid)
# Add a limit to the query if provided
query = query.order_by(cls._id).limit(limit)
# Handle soft deletes if the class has the 'is_deleted' attribute
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
# Execute the query and return the results as a list of model instances
return list(session.execute(query).scalars())
@classmethod
def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] = False) -> str:
"""converts the id into a uuid object
Args:
identifier: the string identifier, such as `organization-xxxx-xx...`
indifferent: if True, will not enforce the prefix check
"""
try:
uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "")
assert SqlalchemyBase.is_valid_uuid4(uuid_string)
return uuid_string
except ValueError as e:
raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e
@classmethod
def is_valid_uuid4(cls, uuid_string: str) -> bool:
try:
# Try to create a UUID object from the string
uuid_obj = UUID(uuid_string)
# Check if the UUID is version 4
return uuid_obj.version == 4
except ValueError:
# Raised if the string is not a valid UUID
return False
@classmethod
def read(
cls,
db_session: "Session",
identifier: Union[str, UUID],
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
**kwargs,
) -> Type["SqlalchemyBase"]:
"""The primary accessor for an ORM record.
Args:
db_session: the database session to use when retrieving the record
identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility
actor: if specified, results will be scoped only to records the user is able to access
access: if actor is specified, records will be filtered to the minimum permission level for the actor
kwargs: additional arguments to pass to the read, used for more complex objects
Returns:
The matching object
Raises:
NoResultFound: if the object is not found
"""
del kwargs # arity for more complex reads
identifier = cls.get_uid_from_identifier(identifier)
query = select(cls).where(cls._id == identifier)
# if actor:
# query = cls.apply_access_predicate(query, actor, access)
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
if found := db_session.execute(query).scalar():
return found
raise NoResultFound(f"{cls.__name__} with id {identifier} not found")
def create(self, db_session: "Session") -> Type["SqlalchemyBase"]:
# self._infer_organization(db_session)
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
def delete(self, db_session: "Session") -> Type["SqlalchemyBase"]:
self.is_deleted = True
return self.update(db_session)
def update(self, db_session: "Session") -> Type["SqlalchemyBase"]:
with db_session as session:
session.add(self)
session.commit()
session.refresh(self)
return self
@classmethod
def read_or_create(cls, *, db_session: "Session", **kwargs) -> Type["SqlalchemyBase"]:
"""get an instance by search criteria or create it if it doesn't exist"""
try:
return cls.read(db_session=db_session, identifier=kwargs.get("id", None))
except NoResultFound:
clean_kwargs = {k: v for k, v in kwargs.items() if k in cls.__table__.columns}
return cls(**clean_kwargs).create(db_session=db_session)
# TODO: Add back later when access predicates are actually important
# The idea behind this is that you can add a WHERE clause restricting the actions you can take, e.g. R/W
# @classmethod
# def apply_access_predicate(
# cls,
# query: "Select",
# actor: "User",
# access: List[Literal["read", "write", "admin"]],
# ) -> "Select":
# """applies a WHERE clause restricting results to the given actor and access level
# Args:
# query: The initial sqlalchemy select statement
# actor: The user acting on the query. **Note**: this is called 'actor' to identify the
# person or system acting. Users can act on users, making naming very sticky otherwise.
# access:
# what mode of access should the query restrict to? This will be used with granular permissions,
# but because of how it will impact every query we want to be explicitly calling access ahead of time.
# Returns:
# the sqlalchemy select statement restricted to the given access.
# """
# del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
# org_uid = getattr(actor, "_organization_id", getattr(actor.organization, "_id", None))
# if not org_uid:
# raise ValueError("object %s has no organization accessor", actor)
# return query.where(cls._organization_id == org_uid, cls.is_deleted == False)
@property
def __pydantic_model__(self) -> Type["BaseModel"]:
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
def to_pydantic(self) -> Type["BaseModel"]:
"""converts to the basic pydantic model counterpart"""
return self.__pydantic_model__.model_validate(self)
def to_record(self) -> Type["BaseModel"]:
"""Deprecated accessor for to_pydantic"""
logger.warning("to_record is deprecated, use to_pydantic instead.")
return self.to_pydantic()
# TODO: Look into this later and maybe add back?
# def _infer_organization(self, db_session: "Session") -> None:
# """🪄 MAGIC ALERT! 🪄
# Because so much of the original API is centered around user scopes,
# this allows us to continue with that scope and then infer the org from the creating user.
#
# IF a created_by_id is set, we will use that to infer the organization and magic set it at create time!
# If not do nothing to the object. Mutates in place.
# """
# if self.created_by_id and hasattr(self, "_organization_id"):
# try:
# from letta.orm.user import User # to avoid circular import
#
# created_by = User.read(db_session, self.created_by_id)
# except NoResultFound:
# logger.warning(f"User {self.created_by_id} not found, unable to infer organization.")
# return
# self._organization_id = created_by._organization_id

View File

@@ -7,13 +7,13 @@ from letta.schemas.letta_base import LettaBase
class OrganizationBase(LettaBase): class OrganizationBase(LettaBase):
__id_prefix__ = "org" __id_prefix__ = "organization"
class Organization(OrganizationBase): class Organization(OrganizationBase):
id: str = OrganizationBase.generate_id_field() id: str = Field(..., description="The id of the organization.")
name: str = Field(..., description="The name of the organization.") name: str = Field(..., description="The name of the organization.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the organization.")
class OrganizationCreate(OrganizationBase): class OrganizationCreate(OrganizationBase):

View File

@@ -22,7 +22,7 @@ def get_all_orgs(
Get a list of all orgs in the database Get a list of all orgs in the database
""" """
try: try:
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit) next_cursor, orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -38,8 +38,7 @@ def create_org(
""" """
Create a new org in the database Create a new org in the database
""" """
org = server.organization_manager.create_organization(request)
org = server.create_organization(request)
return org return org
@@ -50,10 +49,10 @@ def delete_org(
): ):
# TODO make a soft deletion, instead of a hard deletion # TODO make a soft deletion, instead of a hard deletion
try: try:
org = server.ms.get_organization(org_id=org_id) org = server.organization_manager.get_organization_by_id(org_id=org_id)
if org is None: if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist") raise HTTPException(status_code=404, detail=f"Organization does not exist")
server.ms.delete_organization(org_id=org_id) server.organization_manager.delete_organization(org_id=org_id)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:

View File

@@ -44,6 +44,7 @@ from letta.log import get_logger
from letta.memory import get_memory_functions from letta.memory import get_memory_functions
from letta.metadata import Base, MetadataStore from letta.metadata import Base, MetadataStore
from letta.o1_agent import O1Agent from letta.o1_agent import O1Agent
from letta.orm.errors import NoResultFound
from letta.prompts import gpt_system from letta.prompts import gpt_system
from letta.providers import ( from letta.providers import (
AnthropicProvider, AnthropicProvider,
@@ -80,12 +81,12 @@ from letta.schemas.memory import (
RecallMemorySummary, RecallMemorySummary,
) )
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.schemas.usage import LettaUsageStatistics from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User, UserCreate from letta.schemas.user import User, UserCreate
from letta.services.organization_manager import OrganizationManager
from letta.utils import create_random_username, json_dumps, json_loads from letta.utils import create_random_username, json_dumps, json_loads
# from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin # from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
@@ -245,6 +246,9 @@ class SyncServer(Server):
self.config = config self.config = config
self.ms = MetadataStore(self.config) self.ms = MetadataStore(self.config)
# Managers that interface with data models
self.organization_manager = OrganizationManager()
# TODO: this should be removed # TODO: this should be removed
# add global default tools (for admin) # add global default tools (for admin)
self.add_default_tools(module_name="base") self.add_default_tools(module_name="base")
@@ -773,20 +777,6 @@ class SyncServer(Server):
return user return user
def create_organization(self, request: OrganizationCreate) -> Organization:
"""Create a new org using a config"""
if not request.name:
# auto-generate a name
request.name = create_random_username()
org = Organization(name=request.name)
self.ms.create_organization(org)
logger.info(f"Created new org from config: {org}")
# add default for the org
# TODO: add default data
return org
def create_agent( def create_agent(
self, self,
request: CreateAgent, request: CreateAgent,
@@ -2125,18 +2115,13 @@ class SyncServer(Server):
def get_default_user(self) -> User: def get_default_user(self) -> User:
from letta.constants import ( from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
DEFAULT_USER_ID,
DEFAULT_USER_NAME,
)
# check if default org exists # check if default org exists
default_org = self.ms.get_organization(DEFAULT_ORG_ID) try:
if not default_org: self.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID) except NoResultFound:
self.ms.create_organization(org) self.organization_manager.create_default_organization()
# check if default user exists # check if default user exists
try: try:

View File

View File

@@ -0,0 +1,66 @@
from typing import List, Optional
from sqlalchemy.exc import NoResultFound
from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
from letta.orm.organization import Organization
from letta.schemas.organization import Organization as PydanticOrganization
from letta.utils import create_random_username
class OrganizationManager:
"""Manager class to handle business logic related to Organizations."""
def __init__(self):
# This is probably horrible but we reuse this technique from metadata.py
# TODO: Please refactor this out
# I am currently working on a ORM refactor and would like to make a more minimal set of changes
# - Matt
from letta.server.server import db_context
self.session_maker = db_context
def get_organization_by_id(self, org_id: str) -> PydanticOrganization:
"""Fetch an organization by ID."""
with self.session_maker() as session:
try:
organization = Organization.read(db_session=session, identifier=org_id)
return organization.to_pydantic()
except NoResultFound:
raise ValueError(f"Organization with id {org_id} not found.")
def create_organization(self, name: Optional[str] = None) -> PydanticOrganization:
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
with self.session_maker() as session:
org = Organization(name=name if name else create_random_username())
org.create(session)
return org.to_pydantic()
def create_default_organization(self) -> PydanticOrganization:
"""Create the default organization."""
with self.session_maker() as session:
org = Organization(name=DEFAULT_ORG_NAME)
org.id = DEFAULT_ORG_ID
org.create(session)
return org.to_pydantic()
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
organization = Organization.read(db_session=session, identifier=org_id)
if name:
organization.name = name
organization.update(session)
return organization.to_pydantic()
def delete_organization(self, org_id: str):
"""Delete an organization by marking it as deleted."""
with self.session_maker() as session:
organization = Organization.read(db_session=session, identifier=org_id)
organization.delete(session)
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
"""List organizations with pagination based on cursor (org_id) and limit."""
with self.session_maker() as session:
results = Organization.list(db_session=session, cursor=cursor, limit=limit)
return [org.to_pydantic() for org in results]

13
poetry.lock generated
View File

@@ -5742,6 +5742,17 @@ files = [
[package.extras] [package.extras]
windows-terminal = ["colorama (>=0.4.6)"] windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "pyhumps"
version = "3.8.0"
description = "🐫 Convert strings (and dictionary keys) between snake case, camel case and pascal case in Python. Inspired by Humps for Node"
optional = false
python-versions = "*"
files = [
{file = "pyhumps-3.8.0-py3-none-any.whl", hash = "sha256:060e1954d9069f428232a1adda165db0b9d8dfdce1d265d36df7fbff540acfd6"},
{file = "pyhumps-3.8.0.tar.gz", hash = "sha256:498026258f7ee1a8e447c2e28526c0bea9407f9a59c03260aee4bd6c04d681a3"},
]
[[package]] [[package]]
name = "pylance" name = "pylance"
version = "0.9.18" version = "0.9.18"
@@ -8423,4 +8434,4 @@ tests = ["wikipedia"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "<3.13,>=3.10" python-versions = "<3.13,>=3.10"
content-hash = "357ad0382673050758dd4f98ba71d574cdebea385eefc9481b9c8bab743eafd3" content-hash = "5c05bb8ee0f17e149be1482f6295fb2dcac41d8a23a27b890a81d2e9fa30b4e8"

View File

@@ -77,6 +77,7 @@ langchain-community = {version = "^0.2.17", optional = true}
composio-langchain = "^0.5.28" composio-langchain = "^0.5.28"
composio-core = "^0.5.34" composio-core = "^0.5.34"
alembic = "^1.13.3" alembic = "^1.13.3"
pyhumps = "^3.8.0"
[tool.poetry.extras] [tool.poetry.extras]
#local = ["llama-index-embeddings-huggingface"] #local = ["llama-index-embeddings-huggingface"]

View File

@@ -1,146 +0,0 @@
import threading
import time
import pytest
from letta import Admin
test_base_url = "http://localhost:8283"
# admin credentials
test_server_token = "test_server_token"
def run_server():
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(scope="session", autouse=True)
def start_uvicorn_server():
"""Starts Uvicorn server in a background thread."""
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
print("Starting server...")
time.sleep(5)
yield
@pytest.fixture(scope="module")
def admin_client():
# Setup: Create a user via the client before the tests
admin = Admin(test_base_url, test_server_token)
admin._reset_server()
yield admin
@pytest.fixture(scope="module")
def organization(admin_client):
# create an organization
org_name = "test_org"
org = admin_client.create_organization(org_name)
assert org_name == org.name, f"Expected {org_name}, got {org.name}"
# test listing
orgs = admin_client.get_organizations()
assert len(orgs) > 0, f"Expected 1 org, got {orgs}"
yield org
admin_client.delete_organization(org.id)
def test_admin_client(admin_client, organization):
# create a user
user_name = "test_user"
user1 = admin_client.create_user(user_name, organization.id)
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
# create another user
user2 = admin_client.create_user()
# create keys
key1_name = "test_key1"
key2_name = "test_key2"
api_key1 = admin_client.create_key(user1.id, key1_name)
admin_client.create_key(user2.id, key2_name)
# list users
users = admin_client.get_users()
assert len(users) == 2
assert user1.id in [user.id for user in users]
assert user2.id in [user.id for user in users]
# list keys
user1_keys = admin_client.get_keys(user1.id)
assert len(user1_keys) == 1, f"Expected 1 keys, got {user1_keys}"
assert api_key1.key == user1_keys[0].key
# delete key
deleted_key1 = admin_client.delete_key(api_key1.key)
assert deleted_key1.key == api_key1.key
assert len(admin_client.get_keys(user1.id)) == 0
# delete users
deleted_user1 = admin_client.delete_user(user1.id)
assert deleted_user1.id == user1.id
deleted_user2 = admin_client.delete_user(user2.id)
assert deleted_user2.id == user2.id
# list users
users = admin_client.get_users()
assert len(users) == 0, f"Expected 0 users, got {users}"
# def test_get_users_pagination(admin_client):
#
# page_size = 5
# num_users = 7
# expected_users_remainder = num_users - page_size
#
# # create users
# all_user_ids = []
# for i in range(num_users):
#
# user_id = uuid.uuid4()
# all_user_ids.append(user_id)
# key_name = "test_key" + f"{i}"
#
# create_user_response = admin_client.create_user(user_id)
# admin_client.create_key(create_user_response.user_id, key_name)
#
# # list users in page 1
# get_all_users_response1 = admin_client.get_users(limit=page_size)
# cursor1 = get_all_users_response1.cursor
# user_list1 = get_all_users_response1.user_list
# assert len(user_list1) == page_size
#
# # list users in page 2 using cursor
# get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
# cursor2 = get_all_users_response2.cursor
# user_list2 = get_all_users_response2.user_list
#
# assert len(user_list2) == expected_users_remainder
# assert cursor1 != cursor2
#
# # delete users
# clean_up_users_and_keys(all_user_ids)
#
# # list users to check pagination with no users
# users = admin_client.get_users()
# assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
def clean_up_users_and_keys(user_id_list):
admin_client = Admin(test_base_url, test_server_token)
# clean up all keys and users
for user_id in user_id_list:
keys_list = admin_client.get_keys(user_id)
for key in keys_list:
admin_client.delete_key(key)
admin_client.delete_user(user_id)

View File

@@ -523,7 +523,6 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
def test_organization(client: RESTClient): def test_organization(client: RESTClient):
if isinstance(client, LocalClient): if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations") pytest.skip("Skipping test_organization because LocalClient does not support organizations")
client.base_url
def test_model_configs(client: Union[LocalClient, RESTClient]): def test_model_configs(client: Union[LocalClient, RESTClient]):

View File

@@ -3,9 +3,17 @@ import uuid
import warnings import warnings
import pytest import pytest
from sqlalchemy import delete
import letta.utils as utils import letta.utils as utils
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.constants import (
BASE_TOOLS,
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
)
from letta.orm.organization import Organization
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
utils.DEBUG = True utils.DEBUG = True
@@ -31,6 +39,14 @@ from letta.server.server import SyncServer
from .utils import DummyDataConnector from .utils import DummyDataConnector
@pytest.fixture(autouse=True)
def clear_organization_table(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
# if os.getenv("OPENAI_API_KEY"): # if os.getenv("OPENAI_API_KEY"):
@@ -547,3 +563,47 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
+ overview.num_tokens_functions_definitions + overview.num_tokens_functions_definitions
+ overview.num_tokens_external_memory_summary + overview.num_tokens_external_memory_summary
) )
def test_list_organizations(server: SyncServer):
# Create a new org and confirm that it is created correctly
org_name = "test"
org = server.organization_manager.create_organization(name=org_name)
orgs = server.organization_manager.list_organizations()
assert len(orgs) == 1
assert orgs[0].name == org_name
# Delete it after
server.organization_manager.delete_organization(org.id)
assert len(server.organization_manager.list_organizations()) == 0
def test_create_default_organization(server: SyncServer):
server.organization_manager.create_default_organization()
retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
assert retrieved.name == DEFAULT_ORG_NAME
def test_update_organization_name(server: SyncServer):
org_name_a = "a"
org_name_b = "b"
org = server.organization_manager.create_organization(name=org_name_a)
assert org.name == org_name_a
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
assert org.name == org_name_b
def test_list_organizations_pagination(server: SyncServer):
server.organization_manager.create_organization(name="a")
server.organization_manager.create_organization(name="b")
orgs_x = server.organization_manager.list_organizations(limit=1)
assert len(orgs_x) == 1
orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1)
assert len(orgs_y) == 1
assert orgs_y[0].name != orgs_x[0].name
orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1)
assert len(orgs) == 0