From ff4be4576b45a993037285a1824a61285cd005d2 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 23 Oct 2024 10:28:00 -0700 Subject: [PATCH] feat: Add ORM for user model (#1924) --- .github/workflows/tests.yml | 11 ++ letta/constants.py | 6 +- letta/metadata.py | 86 ------------ letta/orm/errors.py | 4 + letta/orm/mixins.py | 59 +++++--- letta/orm/organization.py | 21 +-- letta/orm/sqlalchemy_base.py | 16 +-- letta/orm/user.py | 25 ++++ letta/schemas/user.py | 19 ++- .../rest_api/routers/v1/organizations.py | 4 +- letta/server/rest_api/routers/v1/users.py | 13 +- letta/server/server.py | 103 ++++++-------- letta/services/organization_manager.py | 24 +++- letta/services/user_manager.py | 99 +++++++++++++ tests/test_managers.py | 132 ++++++++++++++++++ tests/test_server.py | 82 +++-------- 16 files changed, 422 insertions(+), 282 deletions(-) create mode 100644 letta/orm/user.py create mode 100644 letta/services/user_manager.py create mode 100644 tests/test_managers.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 33ab5ff5..cdf3edaa 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,6 +48,17 @@ jobs: run: | poetry run pytest -s -vv tests/test_server.py + - name: Run server manager tests + env: + LETTA_PG_PORT: 8888 + LETTA_PG_USER: letta + LETTA_PG_PASSWORD: letta + LETTA_PG_DB: letta + LETTA_PG_HOST: localhost + LETTA_SERVER_PASS: test_server_token + run: | + poetry run pytest -s -vv tests/test_managers.py + - name: Run tools tests env: LETTA_PG_PORT: 8888 diff --git a/letta/constants.py b/letta/constants.py index f317a0e1..a581f7b3 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -4,13 +4,13 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") # Defaults -DEFAULT_USER_ID = "user-00000000" +DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000" # This UUID follows the UUID4 rules: # The 13th character (4) indicates it's version 4. # The first character of the third segment (8) ensures the variant is correctly set. DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000" -DEFAULT_USER_NAME = "default" -DEFAULT_ORG_NAME = "default" +DEFAULT_USER_NAME = "default_user" +DEFAULT_ORG_NAME = "default_org" # String in the error message for when the context window is too large diff --git a/letta/metadata.py b/letta/metadata.py index 328a829d..e36fbad6 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -15,7 +15,6 @@ from sqlalchemy import ( String, TypeDecorator, asc, - desc, or_, ) from sqlalchemy.sql import func @@ -31,8 +30,6 @@ from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory - -# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.schemas.source import Source from letta.schemas.tool import Tool @@ -154,25 +151,6 @@ class ToolCallColumn(TypeDecorator): return value -class UserModel(Base): - __tablename__ = "users" - __table_args__ = {"extend_existing": True} - - id = Column(String, primary_key=True) - org_id = Column(String) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True)) - - # TODO: what is this? - policies_accepted = Column(Boolean, nullable=False, default=False) - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> User: - return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id) - - # TODO: eventually store providers? # class Provider(Base): # __tablename__ = "providers" @@ -497,15 +475,6 @@ class MetadataStore: tokens = [r.to_record() for r in results] return tokens - @enforce_types - def get_user_from_api_key(self, api_key: str) -> Optional[User]: - """Get the user associated with a given API key""" - token = self.get_api_key(api_key=api_key) - if token is None: - raise ValueError(f"Provided token does not exist") - else: - return self.get_user(user_id=token.user_id) - @enforce_types def create_agent(self, agent: AgentState): # insert into agent table @@ -527,14 +496,6 @@ class MetadataStore: session.add(SourceModel(**vars(source))) session.commit() - @enforce_types - def create_user(self, user: User): - with self.session_maker() as session: - if session.query(UserModel).filter(UserModel.id == user.id).count() > 0: - raise ValueError(f"User with id {user.id} already exists") - session.add(UserModel(**vars(user))) - session.commit() - @enforce_types def create_block(self, block: Block): with self.session_maker() as session: @@ -573,12 +534,6 @@ class MetadataStore: session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields) session.commit() - @enforce_types - def update_user(self, user: User): - with self.session_maker() as session: - session.query(UserModel).filter(UserModel.id == user.id).update(vars(user)) - session.commit() - @enforce_types def update_source(self, source: Source): with self.session_maker() as session: @@ -657,23 +612,6 @@ class MetadataStore: session.commit() - @enforce_types - def delete_user(self, user_id: str): - with self.session_maker() as session: - # delete from users table - session.query(UserModel).filter(UserModel.id == user_id).delete() - - # delete associated agents - session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() - - # delete associated sources - session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() - - # delete associated mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() - - session.commit() - @enforce_types def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]: with self.session_maker() as session: @@ -719,30 +657,6 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result return results[0].to_record() - @enforce_types - def get_user(self, user_id: str) -> Optional[User]: - with self.session_maker() as session: - results = session.query(UserModel).filter(UserModel.id == user_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - @enforce_types - def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50): - with self.session_maker() as session: - query = session.query(UserModel).order_by(desc(UserModel.id)) - if cursor: - query = query.filter(UserModel.id < cursor) - results = query.limit(limit).all() - if not results: - return None, [] - user_records = [r.to_record() for r in results] - next_cursor = user_records[-1].id - assert isinstance(next_cursor, str) - - return next_cursor, user_records - @enforce_types def get_source( self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None diff --git a/letta/orm/errors.py b/letta/orm/errors.py index d1bcf4ab..12e5b16b 100644 --- a/letta/orm/errors.py +++ b/letta/orm/errors.py @@ -1,2 +1,6 @@ class NoResultFound(Exception): """A record or records cannot be found given the provided search params""" + + +class MalformedIdError(Exception): + """An id not in the right format, most likely violating uuid4 format.""" diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index c6348957..71845b6e 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -1,24 +1,35 @@ -from typing import Optional, Type +from typing import Optional from uuid import UUID +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + from letta.orm.base import Base +from letta.orm.errors import MalformedIdError -class MalformedIdError(Exception): - pass +def is_valid_uuid4(uuid_string: str) -> bool: + """Check if a string is a valid UUID4.""" + try: + uuid_obj = UUID(uuid_string) + return uuid_obj.version == 4 + except ValueError: + return False def _relation_getter(instance: "Base", prop: str) -> Optional[str]: + """Get relation and return id with prefix as a string.""" prefix = prop.replace("_", "") formatted_prop = f"_{prop}_id" try: - uuid_ = getattr(instance, formatted_prop) - return f"{prefix}-{uuid_}" + id_ = getattr(instance, formatted_prop) # Get the string id directly + return f"{prefix}-{id_}" except AttributeError: return None -def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None: +def _relation_setter(instance: "Base", prop: str, value: str) -> None: + """Set relation using the id with prefix, ensuring the id is a valid UUIDv4.""" formatted_prop = f"_{prop}_id" prefix = prop.replace("_", "") if not value: @@ -28,13 +39,29 @@ def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None: found_prefix, id_ = value.split("-", 1) except ValueError as e: raise MalformedIdError(f"{value} is not a valid ID.") from e - assert ( - # TODO: should be able to get this from the Mapped typing, not sure how though - # prefix = getattr(?, "prefix") - found_prefix - == prefix - ), f"{found_prefix} is not a valid id prefix, expecting {prefix}" - try: - setattr(instance, formatted_prop, UUID(id_)) - except ValueError as e: - raise MalformedIdError("Hash segment of {value} is not a valid UUID") from e + + # Ensure prefix matches + assert found_prefix == prefix, f"{found_prefix} is not a valid id prefix, expecting {prefix}" + + # Validate that the id is a valid UUID4 string + if not is_valid_uuid4(id_): + raise MalformedIdError(f"Hash segment of {value} is not a valid UUID4") + + setattr(instance, formatted_prop, id_) # Store id as a string + + +class OrganizationMixin(Base): + """Mixin for models that belong to an organization.""" + + __abstract__ = True + + # Changed _organization_id to store string (still a valid UUID4 string) + _organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id")) + + @property + def organization_id(self) -> str: + return _relation_getter(self, "organization") + + @organization_id.setter + def organization_id(self, value: str) -> None: + _relation_setter(self, "organization", value) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 394cb436..244c49e0 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -1,35 +1,28 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List -from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.organization import Organization as PydanticOrganization if TYPE_CHECKING: - from sqlalchemy.orm import Session + + from letta.orm.user import User class Organization(SqlalchemyBase): """The highest level of the object tree. All Entities belong to one and only one Organization.""" - __tablename__ = "organizations" + __tablename__ = "organization" __pydantic_model__ = PydanticOrganization name: Mapped[str] = mapped_column(doc="The display name of the organization.") + users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") + # TODO: Map these relationships later when we actually make these models # below is just a suggestion - # users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") # sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") # tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") # documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan") - - @classmethod - def default(cls, db_session: "Session") -> "Organization": - """Get the default org, or create it if it doesn't exist.""" - try: - return db_session.query(cls).filter(cls.name == "Default Organization").one() - except NoResultFound: - return cls(name="Default Organization").create(db_session) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 557593e1..0e0a3821 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import Mapped, mapped_column from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins from letta.orm.errors import NoResultFound +from letta.orm.mixins import is_valid_uuid4 if TYPE_CHECKING: from pydantic import BaseModel @@ -42,7 +43,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): return prefix, id_ = value.split("-", 1) assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}" - assert SqlalchemyBase.is_valid_uuid4(id_), f"{id_} is not a valid uuid4" + assert is_valid_uuid4(id_), f"{id_} is not a valid uuid4" self._id = id_ @classmethod @@ -78,22 +79,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): """ try: uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "") - assert SqlalchemyBase.is_valid_uuid4(uuid_string) + assert is_valid_uuid4(uuid_string) return uuid_string except ValueError as e: raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e - @classmethod - def is_valid_uuid4(cls, uuid_string: str) -> bool: - try: - # Try to create a UUID object from the string - uuid_obj = UUID(uuid_string) - # Check if the UUID is version 4 - return uuid_obj.version == 4 - except ValueError: - # Raised if the string is not a valid UUID - return False - @classmethod def read( cls, diff --git a/letta/orm/user.py b/letta/orm/user.py new file mode 100644 index 00000000..bb555721 --- /dev/null +++ b/letta/orm/user.py @@ -0,0 +1,25 @@ +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.organization import Organization +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.user import User as PydanticUser + + +class User(SqlalchemyBase, OrganizationMixin): + """User ORM class""" + + __tablename__ = "user" + __pydantic_model__ = PydanticUser + + name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="users") + + # TODO: Add this back later potentially + # agents: Mapped[List["Agent"]] = relationship( + # "Agent", secondary="users_agents", back_populates="users", doc="the agents associated with this user." + # ) + # tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.") + # jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.") diff --git a/letta/schemas/user.py b/letta/schemas/user.py index d8cfd48d..b499e126 100644 --- a/letta/schemas/user.py +++ b/letta/schemas/user.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import Field +from letta.constants import DEFAULT_ORG_ID from letta.schemas.letta_base import LettaBase @@ -20,14 +21,20 @@ class User(UserBase): created_at (datetime): The creation date of the user. """ - id: str = UserBase.generate_id_field() - org_id: Optional[str] = Field( - ..., description="The organization id of the user" - ) # TODO: dont make optional, and pass in default org ID + id: str = Field(..., description="The id of the user.") + organization_id: Optional[str] = Field(DEFAULT_ORG_ID, description="The organization id of the user") name: str = Field(..., description="The name of the user.") created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") + updated_at: datetime = Field(default_factory=datetime.utcnow, description="The update date of the user.") + is_deleted: bool = Field(False, description="Whether this user is deleted or not.") class UserCreate(UserBase): - name: Optional[str] = Field(None, description="The name of the user.") - org_id: Optional[str] = Field(None, description="The organization id of the user.") + name: str = Field(..., description="The name of the user.") + organization_id: str = Field(..., description="The organization id of the user.") + + +class UserUpdate(UserBase): + id: str = Field(..., description="The id of the user to update.") + name: Optional[str] = Field(None, description="The new name of the user.") + organization_id: Optional[str] = Field(None, description="The new organization id of the user.") diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index efe9882a..c4ac9f2c 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -42,7 +42,7 @@ def create_org( return org -@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization") +@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization_by_id") def delete_org( org_id: str = Query(..., description="The org_id key to be deleted."), server: "SyncServer" = Depends(get_letta_server), @@ -52,7 +52,7 @@ def delete_org( org = server.organization_manager.get_organization_by_id(org_id=org_id) if org is None: raise HTTPException(status_code=404, detail=f"Organization does not exist") - server.organization_manager.delete_organization(org_id=org_id) + server.organization_manager.delete_organization_by_id(org_id=org_id) except HTTPException: raise except Exception as e: diff --git a/letta/server/rest_api/routers/v1/users.py b/letta/server/rest_api/routers/v1/users.py index 8e40c7b3..80e9f24f 100644 --- a/letta/server/rest_api/routers/v1/users.py +++ b/letta/server/rest_api/routers/v1/users.py @@ -26,7 +26,7 @@ router = APIRouter(prefix="/users", tags=["users", "admin"]) @router.get("/", tags=["admin"], response_model=List[User], operation_id="list_users") -def get_all_users( +def list_users( cursor: Optional[str] = Query(None), limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), @@ -35,8 +35,7 @@ def get_all_users( Get a list of all users in the database """ try: - next_cursor, users = server.ms.get_all_users(cursor=cursor, limit=limit) - # processed_users = [{"user_id": user.id} for user in users] + next_cursor, users = server.user_manager.list_users(cursor=cursor, limit=limit) except HTTPException: raise except Exception as e: @@ -53,7 +52,7 @@ def create_user( Create a new user in the database """ - user = server.create_user(request) + user = server.user_manager.create_user(request) return user @@ -64,10 +63,10 @@ def delete_user( ): # TODO make a soft deletion, instead of a hard deletion try: - user = server.ms.get_user(user_id=user_id) + user = server.user_manager.get_user_by_id(user_id=user_id) if user is None: raise HTTPException(status_code=404, detail=f"User does not exist") - server.ms.delete_user(user_id=user_id) + server.user_manager.delete_user_by_id(user_id=user_id) except HTTPException: raise except Exception as e: @@ -95,7 +94,7 @@ def get_api_keys( """ Get a list of all API keys for a user """ - if server.ms.get_user(user_id=user_id) is None: + if server.user_manager.get_user_by_id(user_id=user_id) is None: raise HTTPException(status_code=404, detail=f"User does not exist") api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id) return api_keys diff --git a/letta/server/server.py b/letta/server/server.py index 7bddb188..082fb474 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -44,7 +44,6 @@ from letta.log import get_logger from letta.memory import get_memory_functions from letta.metadata import Base, MetadataStore from letta.o1_agent import O1Agent -from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.providers import ( AnthropicProvider, @@ -87,6 +86,7 @@ from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User, UserCreate from letta.services.organization_manager import OrganizationManager +from letta.services.user_manager import UserManager from letta.utils import create_random_username, json_dumps, json_loads # from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin @@ -248,6 +248,7 @@ class SyncServer(Server): # Managers that interface with data models self.organization_manager = OrganizationManager() + self.user_manager = UserManager() # TODO: this should be removed # add global default tools (for admin) @@ -576,7 +577,7 @@ class SyncServer(Server): timestamp: Optional[datetime] = None, ) -> LettaUsageStatistics: """Process an incoming user message and feed it through the Letta agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -625,7 +626,7 @@ class SyncServer(Server): timestamp: Optional[datetime] = None, ) -> LettaUsageStatistics: """Process an incoming system message and feed it through the Letta agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -694,7 +695,7 @@ class SyncServer(Server): Otherwise, we can pass them in directly. """ - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -743,7 +744,7 @@ class SyncServer(Server): # @LockingServer.agent_lock_decorator def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: """Run a command on the agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -754,19 +755,12 @@ class SyncServer(Server): command = command[1:] # strip the prefix return self._command(user_id=user_id, agent_id=agent_id, command=command) - def list_users_paginated(self, cursor: str, limit: int) -> List[User]: - """List all users""" - # TODO: make this paginated - next_cursor, users = self.ms.get_all_users(cursor, limit) - return next_cursor, users - def create_user(self, request: UserCreate) -> User: """Create a new user using a config""" if not request.name: # auto-generate a name request.name = create_random_username() - user = User(name=request.name, org_id=request.org_id) - self.ms.create_user(user) + user = self.user_manager.create_user(request) logger.debug(f"Created new user from config: {user}") # add default for the user @@ -785,7 +779,7 @@ class SyncServer(Server): interface: Union[AgentInterface, None] = None, ) -> AgentState: """Create a new agent using a config""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if interface is None: @@ -809,7 +803,7 @@ class SyncServer(Server): raise ValueError(f"Invalid agent type: {request.agent_type}") logger.debug(f"Attempting to find user: {user_id}") - user = self.ms.get_user(user_id=user_id) + user = self.user_manager.get_user_by_id(user_id=user_id) if not user: raise ValueError(f"cannot find user with associated client id: {user_id}") @@ -912,7 +906,7 @@ class SyncServer(Server): user_id: str, ): """Update the agents core memory block, return the new state""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=request.id) is None: raise ValueError(f"Agent agent_id={request.id} does not exist") @@ -974,7 +968,7 @@ class SyncServer(Server): def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]: """Get tools from an existing agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -990,7 +984,7 @@ class SyncServer(Server): user_id: str, ): """Add tools from an existing agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1029,7 +1023,7 @@ class SyncServer(Server): user_id: str, ): """Remove tools from an existing agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1075,7 +1069,7 @@ class SyncServer(Server): user_id: str, ) -> List[AgentState]: """List all available agents to a user""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") agents_states = self.ms.list_agents(user_id=user_id) @@ -1091,7 +1085,7 @@ class SyncServer(Server): if user_id is None: agents_states = self.ms.list_all_agents() else: - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") agents_states = self.ms.list_agents(user_id=user_id) @@ -1231,13 +1225,13 @@ class SyncServer(Server): """Get the agent state""" return self.ms.get_agent(agent_id=agent_id, user_id=user_id) - def get_user(self, user_id: str) -> User: - """Get the user""" - user = self.ms.get_user(user_id=user_id) - if user is None: - raise ValueError(f"User with user_id {user_id} does not exist") - else: - return user + # def get_user(self, user_id: str) -> User: + # """Get the user""" + # user = self.user_manager.get_user_by_id(user_id=user_id) + # if user is None: + # raise ValueError(f"User with user_id {user_id} does not exist") + # else: + # return user def get_agent_memory(self, agent_id: str) -> Memory: """Return the memory of an agent (core memory)""" @@ -1328,7 +1322,7 @@ class SyncServer(Server): def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]: """Paginated query of all messages in agent archival memory""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1353,7 +1347,7 @@ class SyncServer(Server): order_by: Optional[str] = "created_at", reverse: Optional[bool] = False, ) -> List[Passage]: - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1368,7 +1362,7 @@ class SyncServer(Server): return records def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[Passage]: - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1383,7 +1377,7 @@ class SyncServer(Server): return [letta_agent.persistence_manager.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids] def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str): - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1414,7 +1408,7 @@ class SyncServer(Server): assistant_message_function_name: str = constants.DEFAULT_MESSAGE_TOOL, assistant_message_function_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[List[Message], List[LettaMessage]]: - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1456,7 +1450,7 @@ class SyncServer(Server): def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]: """Return the config of an agent""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if agent_id: if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: @@ -1497,7 +1491,7 @@ class SyncServer(Server): def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> Memory: """Update the agents core memory block, return the new state""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1528,7 +1522,7 @@ class SyncServer(Server): def rename_agent(self, user_id: str, agent_id: str, new_agent_name: str) -> AgentState: """Update the name of the agent in the database""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1550,13 +1544,9 @@ class SyncServer(Server): assert isinstance(letta_agent.agent_state.id, str) return letta_agent.agent_state - def delete_user(self, user_id: str): - # TODO: delete user - pass - def delete_agent(self, user_id: str, agent_id: str): """Delete an agent in the database""" - if self.ms.get_user(user_id=user_id) is None: + if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1585,7 +1575,8 @@ class SyncServer(Server): def api_key_to_user(self, api_key: str) -> str: """Decode an API key to a user""" - user = self.ms.get_user_from_api_key(api_key=api_key) + token = self.ms.get_api_key(api_key=api_key) + user = self.user_manager.get_user_by_id(token.user_id) if user is None: raise HTTPException(status_code=403, detail="Invalid credentials") else: @@ -2113,29 +2104,15 @@ class SyncServer(Server): letta_agent = self._get_or_load_agent(agent_id=agent_id) return letta_agent.retry_message() + # TODO: Move a lot of this default logic to the ORM def get_default_user(self) -> User: + self.organization_manager.create_default_organization() + user = self.user_manager.create_default_user() - from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME + self.add_default_blocks(user.id) + self.add_default_tools(module_name="base", user_id=user.id) - # check if default org exists - try: - self.organization_manager.get_organization_by_id(DEFAULT_ORG_ID) - except NoResultFound: - self.organization_manager.create_default_organization() - - # check if default user exists - try: - self.get_user(DEFAULT_USER_ID) - except ValueError: - user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID) - self.ms.create_user(user) - - # add default data (TODO: move to org) - self.add_default_blocks(user.id) - self.add_default_tools(module_name="base", user_id=user.id) - - # check if default org exists - return self.get_user(DEFAULT_USER_ID) + return user def get_user_or_default(self, user_id: Optional[str]) -> User: """Get the user object for user_id if it exists, otherwise return the default user object""" @@ -2143,7 +2120,7 @@ class SyncServer(Server): return self.get_default_user() else: try: - return self.get_user(user_id=user_id) + return self.user_manager.get_user_by_id(user_id=user_id) except ValueError: raise HTTPException(status_code=404, detail=f"User with id {user_id} not found") diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index e6976849..7e90602e 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -1,11 +1,10 @@ from typing import List, Optional -from sqlalchemy.exc import NoResultFound - from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME +from letta.orm.errors import NoResultFound from letta.orm.organization import Organization from letta.schemas.organization import Organization as PydanticOrganization -from letta.utils import create_random_username +from letta.utils import create_random_username, enforce_types class OrganizationManager: @@ -20,6 +19,7 @@ class OrganizationManager: self.session_maker = db_context + @enforce_types def get_organization_by_id(self, org_id: str) -> PydanticOrganization: """Fetch an organization by ID.""" with self.session_maker() as session: @@ -29,6 +29,7 @@ class OrganizationManager: except NoResultFound: raise ValueError(f"Organization with id {org_id} not found.") + @enforce_types def create_organization(self, name: Optional[str] = None) -> PydanticOrganization: """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated.""" with self.session_maker() as session: @@ -36,14 +37,21 @@ class OrganizationManager: org.create(session) return org.to_pydantic() + @enforce_types def create_default_organization(self) -> PydanticOrganization: """Create the default organization.""" with self.session_maker() as session: - org = Organization(name=DEFAULT_ORG_NAME) - org.id = DEFAULT_ORG_ID - org.create(session) + # Try to get it first + try: + org = Organization.read(db_session=session, identifier=DEFAULT_ORG_ID) + # If it doesn't exist, make it + except NoResultFound: + org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID) + org.create(session) + return org.to_pydantic() + @enforce_types def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: """Update an organization.""" with self.session_maker() as session: @@ -53,12 +61,14 @@ class OrganizationManager: organization.update(session) return organization.to_pydantic() - def delete_organization(self, org_id: str): + @enforce_types + def delete_organization_by_id(self, org_id: str): """Delete an organization by marking it as deleted.""" with self.session_maker() as session: organization = Organization.read(db_session=session, identifier=org_id) organization.delete(session) + @enforce_types def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: """List organizations with pagination based on cursor (org_id) and limit.""" with self.session_maker() as session: diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py new file mode 100644 index 00000000..ddd1fc5b --- /dev/null +++ b/letta/services/user_manager.py @@ -0,0 +1,99 @@ +from typing import List, Optional, Tuple + +from letta.constants import DEFAULT_ORG_ID, DEFAULT_USER_ID, DEFAULT_USER_NAME + +# TODO: Remove this once we translate all of these to the ORM +from letta.metadata import AgentModel, AgentSourceMappingModel, SourceModel +from letta.orm.errors import NoResultFound +from letta.orm.organization import Organization as OrganizationModel +from letta.orm.user import User as UserModel +from letta.schemas.user import User as PydanticUser +from letta.schemas.user import UserCreate, UserUpdate +from letta.utils import enforce_types + + +class UserManager: + """Manager class to handle business logic related to Users.""" + + def __init__(self): + # Fetching the db_context similarly as in OrganizationManager + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser: + """Create the default user.""" + with self.session_maker() as session: + # Make sure the org id exists + try: + OrganizationModel.read(db_session=session, identifier=org_id) + except NoResultFound: + raise ValueError(f"No organization with {org_id} exists in the organization table.") + + # Try to retrieve the user + try: + user = UserModel.read(db_session=session, identifier=DEFAULT_USER_ID) + except NoResultFound: + # If it doesn't exist, make it + user = UserModel(id=DEFAULT_USER_ID, name=DEFAULT_USER_NAME, organization_id=org_id) + user.create(session) + + return user.to_pydantic() + + @enforce_types + def create_user(self, user_create: UserCreate) -> PydanticUser: + """Create a new user if it doesn't already exist.""" + with self.session_maker() as session: + new_user = UserModel(**user_create.model_dump()) + new_user.create(session) + return new_user.to_pydantic() + + @enforce_types + def update_user(self, user_update: UserUpdate) -> PydanticUser: + """Update user details.""" + with self.session_maker() as session: + # Retrieve the existing user by ID + existing_user = UserModel.read(db_session=session, identifier=user_update.id) + + # Update only the fields that are provided in UserUpdate + update_data = user_update.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(existing_user, key, value) + + # Commit the updated user + existing_user.update(session) + return existing_user.to_pydantic() + + @enforce_types + def delete_user_by_id(self, user_id: str): + """Delete a user and their associated records (agents, sources, mappings).""" + with self.session_maker() as session: + # Delete from user table + user = UserModel.read(db_session=session, identifier=user_id) + user.delete(session) + + # TODO: Remove this once we have ORM models for the Agent, Source, and AgentSourceMapping + # Cascade delete for related models: Agent, Source, AgentSourceMapping + session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() + session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() + session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() + + session.commit() + + @enforce_types + def get_user_by_id(self, user_id: str) -> PydanticUser: + """Fetch a user by ID.""" + with self.session_maker() as session: + try: + user = UserModel.read(db_session=session, identifier=user_id) + return user.to_pydantic() + except NoResultFound: + raise ValueError(f"User with id {user_id} not found.") + + @enforce_types + def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]: + """List users with pagination using cursor (id) and limit.""" + with self.session_maker() as session: + results = UserModel.list(db_session=session, cursor=cursor, limit=limit) + return [user.to_pydantic() for user in results] diff --git a/tests/test_managers.py b/tests/test_managers.py new file mode 100644 index 00000000..1f9e5616 --- /dev/null +++ b/tests/test_managers.py @@ -0,0 +1,132 @@ +import pytest +from sqlalchemy import delete + +import letta.utils as utils +from letta.constants import ( + DEFAULT_ORG_ID, + DEFAULT_ORG_NAME, + DEFAULT_USER_ID, + DEFAULT_USER_NAME, +) +from letta.orm.organization import Organization +from letta.orm.user import User + +utils.DEBUG = True +from letta.config import LettaConfig +from letta.schemas.user import UserCreate, UserUpdate +from letta.server.server import SyncServer + + +@pytest.fixture(autouse=True) +def clear_organization_and_user_table(server: SyncServer): + """Fixture to clear the organization table before each test.""" + with server.organization_manager.session_maker() as session: + session.execute(delete(User)) # Clear all records from the user table + session.execute(delete(Organization)) # Clear all records from the organization table + session.commit() # Commit the deletion + + +@pytest.fixture(scope="module") +def server(): + config = LettaConfig.load() + + config.save() + + server = SyncServer() + return server + + +# ====================================================================================================================== +# Organization Manager Tests +# ====================================================================================================================== +def test_list_organizations(server: SyncServer): + # Create a new org and confirm that it is created correctly + org_name = "test" + org = server.organization_manager.create_organization(name=org_name) + + orgs = server.organization_manager.list_organizations() + assert len(orgs) == 1 + assert orgs[0].name == org_name + + # Delete it after + server.organization_manager.delete_organization_by_id(org.id) + assert len(server.organization_manager.list_organizations()) == 0 + + +def test_create_default_organization(server: SyncServer): + server.organization_manager.create_default_organization() + retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID) + assert retrieved.name == DEFAULT_ORG_NAME + + +def test_update_organization_name(server: SyncServer): + org_name_a = "a" + org_name_b = "b" + org = server.organization_manager.create_organization(name=org_name_a) + assert org.name == org_name_a + org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b) + assert org.name == org_name_b + + +def test_list_organizations_pagination(server: SyncServer): + server.organization_manager.create_organization(name="a") + server.organization_manager.create_organization(name="b") + + orgs_x = server.organization_manager.list_organizations(limit=1) + assert len(orgs_x) == 1 + + orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1) + assert len(orgs_y) == 1 + assert orgs_y[0].name != orgs_x[0].name + + orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1) + assert len(orgs) == 0 + + +# ====================================================================================================================== +# User Manager Tests +# ====================================================================================================================== +def test_list_users(server: SyncServer): + # Create default organization + org = server.organization_manager.create_default_organization() + + user_name = "user" + user = server.user_manager.create_user(UserCreate(name=user_name, organization_id=org.id)) + + users = server.user_manager.list_users() + assert len(users) == 1 + assert users[0].name == user_name + + # Delete it after + server.user_manager.delete_user_by_id(user.id) + assert len(server.user_manager.list_users()) == 0 + + +def test_create_default_user(server: SyncServer): + org = server.organization_manager.create_default_organization() + server.user_manager.create_default_user(org_id=org.id) + retrieved = server.user_manager.get_user_by_id(DEFAULT_USER_ID) + assert retrieved.name == DEFAULT_USER_NAME + + +def test_update_user(server: SyncServer): + # Create default organization + default_org = server.organization_manager.create_default_organization() + test_org = server.organization_manager.create_organization(name="test_org") + + user_name_a = "a" + user_name_b = "b" + + # Assert it's been created + user = server.user_manager.create_user(UserCreate(name=user_name_a, organization_id=default_org.id)) + assert user.name == user_name_a + + # Adjust name + user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b)) + assert user.name == user_name_b + assert user.organization_id == DEFAULT_ORG_ID + + # Adjust org id + user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id)) + assert user.name == user_name_b + assert user.organization_id == test_org.id diff --git a/tests/test_server.py b/tests/test_server.py index 69f55bac..0622d50e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,17 +3,9 @@ import uuid import warnings import pytest -from sqlalchemy import delete import letta.utils as utils -from letta.constants import ( - BASE_TOOLS, - DEFAULT_MESSAGE_TOOL, - DEFAULT_MESSAGE_TOOL_KWARG, - DEFAULT_ORG_ID, - DEFAULT_ORG_NAME, -) -from letta.orm.organization import Organization +from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.enums import MessageRole utils.DEBUG = True @@ -39,14 +31,6 @@ from letta.server.server import SyncServer from .utils import DummyDataConnector -@pytest.fixture(autouse=True) -def clear_organization_table(server: SyncServer): - """Fixture to clear the organization table before each test.""" - with server.organization_manager.session_maker() as session: - session.execute(delete(Organization)) # Clear all records from the organization table - session.commit() # Commit the deletion - - @pytest.fixture(scope="module") def server(): # if os.getenv("OPENAI_API_KEY"): @@ -76,15 +60,27 @@ def server(): @pytest.fixture(scope="module") -def user_id(server): +def org_id(server): + # create org + org = server.organization_manager.create_organization(name="test_org") + print(f"Created org\n{org.id}") + + yield org.id + + # cleanup + server.organization_manager.delete_organization_by_id(org.id) + + +@pytest.fixture(scope="module") +def user_id(server, org_id): # create user - user = server.create_user(UserCreate(name="test_user")) + user = server.create_user(UserCreate(name="test_user", organization_id=org_id)) print(f"Created user\n{user.id}") yield user.id # cleanup - server.delete_user(user.id) + server.user_manager.delete_user_by_id(user.id) @pytest.fixture(scope="module") @@ -183,7 +179,7 @@ def test_user_message(server, user_id, agent_id): @pytest.mark.order(5) -def test_get_recall_memory(server, user_id, agent_id): +def test_get_recall_memory(server, org_id, user_id, agent_id): # test recall memory cursor pagination messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2) cursor1 = messages_1[-1].id @@ -563,47 +559,3 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: + overview.num_tokens_functions_definitions + overview.num_tokens_external_memory_summary ) - - -def test_list_organizations(server: SyncServer): - # Create a new org and confirm that it is created correctly - org_name = "test" - org = server.organization_manager.create_organization(name=org_name) - - orgs = server.organization_manager.list_organizations() - assert len(orgs) == 1 - assert orgs[0].name == org_name - - # Delete it after - server.organization_manager.delete_organization(org.id) - assert len(server.organization_manager.list_organizations()) == 0 - - -def test_create_default_organization(server: SyncServer): - server.organization_manager.create_default_organization() - retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID) - assert retrieved.name == DEFAULT_ORG_NAME - - -def test_update_organization_name(server: SyncServer): - org_name_a = "a" - org_name_b = "b" - org = server.organization_manager.create_organization(name=org_name_a) - assert org.name == org_name_a - org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b) - assert org.name == org_name_b - - -def test_list_organizations_pagination(server: SyncServer): - server.organization_manager.create_organization(name="a") - server.organization_manager.create_organization(name="b") - - orgs_x = server.organization_manager.list_organizations(limit=1) - assert len(orgs_x) == 1 - - orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1) - assert len(orgs_y) == 1 - assert orgs_y[0].name != orgs_x[0].name - - orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1) - assert len(orgs) == 0