chore: Move ID generation logic out of the ORM layer and into the Pydantic model layer (#1981)

This commit is contained in:
Matthew Zhou
2024-11-05 17:05:10 -08:00
committed by GitHub
parent b9f772f196
commit b3f86fe4cd
29 changed files with 270 additions and 380 deletions

View File

@@ -131,4 +131,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_single_path_agent_tool_call_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
poetry run pytest -s -vv -k "not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests

View File

@@ -1,58 +0,0 @@
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# Revision identifiers, used by Alembic.
revision: str = "0c315956709d"
down_revision: Union[str, None] = "9a505cc7eca9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Rename the organizations/tools/users tables to singular names, switch
    their primary-key column from ``id`` to ``_id``, and add a per-organization
    unique constraint on tool names.

    Statement order matters: tables must be renamed before their columns are
    altered, and ``_organization_id`` must exist before the unique constraint
    that references it is created.
    """
    # Step 1: Rename tables first
    op.rename_table("organizations", "organization")
    op.rename_table("tools", "tool")
    op.rename_table("users", "user")
    # Step 2: Rename `id` to `_id` in each table, keeping it as the primary key
    op.alter_column("organization", "id", new_column_name="_id", existing_type=sa.String, nullable=False)
    op.alter_column("tool", "id", new_column_name="_id", existing_type=sa.String, nullable=False)
    op.alter_column("user", "id", new_column_name="_id", existing_type=sa.String, nullable=False)
    # Step 3: Add `_organization_id` to `tool` table
    # This is required for the unique constraint below
    # (nullable=True — presumably so pre-existing tool rows don't violate
    # NOT NULL; confirm against the data backfill strategy)
    op.add_column("tool", sa.Column("_organization_id", sa.String, nullable=True))
    # Step 4: Modify nullable constraints on `tool` columns
    op.alter_column("tool", "tags", existing_type=sa.JSON, nullable=True)
    op.alter_column("tool", "source_type", existing_type=sa.String, nullable=True)
    op.alter_column("tool", "json_schema", existing_type=sa.JSON, nullable=True)
    # Step 5: Add unique constraint on `name` and `_organization_id` in `tool` table
    # An organization should not have multiple tools with the same name.
    op.create_unique_constraint("uq_tool_name_organization", "tool", ["name", "_organization_id"])
def downgrade() -> None:
    """Revert :func:`upgrade` step by step, in strictly reverse order:
    drop the unique constraint, restore NOT NULL on the tool columns,
    drop ``_organization_id``, rename ``_id`` back to ``id``, and finally
    rename the tables back to their plural forms.
    """
    # Reverse unique constraint first
    op.drop_constraint("uq_tool_name_organization", "tool", type_="unique")
    # Reverse nullable constraints on `tool` columns
    # (fails if any row holds NULL in these columns — acceptable for a downgrade)
    op.alter_column("tool", "tags", existing_type=sa.JSON, nullable=False)
    op.alter_column("tool", "source_type", existing_type=sa.String, nullable=False)
    op.alter_column("tool", "json_schema", existing_type=sa.JSON, nullable=False)
    # Remove `_organization_id` column from `tool` table
    op.drop_column("tool", "_organization_id")
    # Reverse the column renaming from `_id` back to `id`
    op.alter_column("organization", "_id", new_column_name="id", existing_type=sa.String, nullable=False)
    op.alter_column("tool", "_id", new_column_name="id", existing_type=sa.String, nullable=False)
    op.alter_column("user", "_id", new_column_name="id", existing_type=sa.String, nullable=False)
    # Reverse table renaming last
    op.rename_table("organization", "organizations")
    op.rename_table("tool", "tools")
    op.rename_table("user", "users")

View File

@@ -0,0 +1,95 @@
"""Move organizations users tools to orm
Revision ID: d14ae606614c
Revises: 9a505cc7eca9
Create Date: 2024-11-05 15:03:12.350096
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import letta
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d14ae606614c"
down_revision: Union[str, None] = "9a505cc7eca9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def deprecated_tool():
    """Return the placeholder message substituted for tools that were removed
    by this migration."""
    placeholder = "this is a deprecated tool, please remove it from your tools list"
    return placeholder
def upgrade() -> None:
    """Move organizations/users/tools onto the ORM base schema: add audit
    columns (created_at/updated_at/is_deleted/_created_by_id/_last_updated_by_id),
    replace the legacy ``org_id``/``user_id`` columns with ``organization_id``
    foreign keys, rename ``block.name`` to ``block.template_name``, and add
    ``agents.tool_rules``.
    """
    # Delete all tools
    # (the new NOT NULL `organization_id` column below has no backfill for
    # existing tool rows, so the table is wiped first)
    op.execute("DELETE FROM tools")
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True))
    op.alter_column("block", "name", new_column_name="template_name", nullable=True)
    op.add_column("organizations", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("organizations", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("organizations", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("organizations", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.add_column("tools", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("tools", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("tools", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("tools", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("tools", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    # NOT NULL with no default is safe only because tools was emptied above
    op.add_column("tools", sa.Column("organization_id", sa.String(), nullable=False))
    op.alter_column("tools", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.alter_column("tools", "source_type", existing_type=sa.VARCHAR(), nullable=False)
    op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.create_unique_constraint("uix_name_organization", "tools", ["name", "organization_id"])
    # name=None lets the dialect generate a default FK name; the downgrade must
    # drop the constraint by that generated name
    op.create_foreign_key(None, "tools", "organizations", ["organization_id"], ["id"])
    op.drop_column("tools", "user_id")
    op.add_column("users", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("users", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("users", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("users", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.add_column("users", sa.Column("organization_id", sa.String(), nullable=True))
    # Copy each user's legacy `org_id` value into the new `organization_id`
    # column in a single set-based UPDATE
    op.execute('UPDATE "users" SET organization_id = org_id')
    # Now that every row has a value, make `organization_id` NOT NULL
    op.alter_column("users", "organization_id", existing_type=sa.String(), nullable=False)
    op.create_foreign_key(None, "users", "organizations", ["organization_id"], ["id"])
    op.drop_column("users", "org_id")
    op.drop_column("users", "policies_accepted")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert :func:`upgrade`: restore the legacy ``org_id``/``policies_accepted``/
    ``user_id`` columns, drop the ORM audit columns, and remove the constraints
    added on ``tools`` and ``users``.

    Fix over the autogenerated code: ``upgrade`` created its foreign keys with
    ``name=None``, so the database generated default names. Alembic's
    ``op.drop_constraint`` requires an explicit name (passing ``None`` raises),
    so the PostgreSQL default-convention names ``<table>_<column>_fkey`` are
    used here.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): NOT NULL with no server_default will fail on a non-empty
    # users table — confirm whether a default of FALSE should be supplied.
    op.add_column("users", sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False))
    op.add_column("users", sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    # Drop the FK by its PostgreSQL-generated default name (upgrade passed name=None)
    op.drop_constraint("users_organization_id_fkey", "users", type_="foreignkey")
    op.drop_column("users", "organization_id")
    op.drop_column("users", "_last_updated_by_id")
    op.drop_column("users", "_created_by_id")
    op.drop_column("users", "is_deleted")
    op.drop_column("users", "updated_at")
    op.add_column("tools", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    # Same default-name convention for the tools FK
    op.drop_constraint("tools_organization_id_fkey", "tools", type_="foreignkey")
    op.drop_constraint("uix_name_organization", "tools", type_="unique")
    op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.alter_column("tools", "source_type", existing_type=sa.VARCHAR(), nullable=True)
    op.alter_column("tools", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.drop_column("tools", "organization_id")
    op.drop_column("tools", "_last_updated_by_id")
    op.drop_column("tools", "_created_by_id")
    op.drop_column("tools", "is_deleted")
    op.drop_column("tools", "updated_at")
    op.drop_column("tools", "created_at")
    op.drop_column("organizations", "_last_updated_by_id")
    op.drop_column("organizations", "_created_by_id")
    op.drop_column("organizations", "is_deleted")
    op.drop_column("organizations", "updated_at")
    op.add_column("block", sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True))
    op.drop_column("block", "template_name")
    op.drop_column("agents", "tool_rules")
    # ### end Alembic commands ###

View File

@@ -1,89 +0,0 @@
"""Sync migration with model changes
Revision ID: ee50a967e090
Revises: 0c315956709d
Create Date: 2024-11-01 19:18:39.399950
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import letta
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "ee50a967e090"
down_revision: Union[str, None] = "0c315956709d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Sync the singular-name tables (organization/tool/user) with the ORM
    models: add soft-delete and audit columns, tighten nullability, replace the
    legacy ``user_id``/``org_id`` columns with ``_organization_id`` foreign
    keys, and add ``agents.tool_rules``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True))
    # NOTE(review): both `deleted` and `is_deleted` are added to each table —
    # looks redundant; confirm whether `deleted` is intentional.
    op.add_column("organization", sa.Column("deleted", sa.Boolean(), nullable=False))
    op.add_column("organization", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("organization", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("organization", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("organization", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.add_column("tool", sa.Column("deleted", sa.Boolean(), nullable=False))
    op.add_column("tool", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("tool", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("tool", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("tool", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("tool", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.alter_column("tool", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.alter_column("tool", "source_type", existing_type=sa.VARCHAR(), nullable=False)
    op.alter_column("tool", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    # `_organization_id` was created by revision 0c315956709d; tighten it here
    op.alter_column("tool", "_organization_id", existing_type=sa.VARCHAR(), nullable=False)
    # Replace the old constraint with the ORM's naming
    op.drop_constraint("uq_tool_name_organization", "tool", type_="unique")
    op.create_unique_constraint("uix_name_organization", "tool", ["name", "_organization_id"])
    # name=None lets the dialect generate a default FK name; the downgrade must
    # drop the constraint by that generated name
    op.create_foreign_key(None, "tool", "organization", ["_organization_id"], ["_id"])
    op.drop_column("tool", "user_id")
    op.add_column("user", sa.Column("deleted", sa.Boolean(), nullable=False))
    op.add_column("user", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("user", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("user", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("user", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    # NOTE(review): NOT NULL with no default/backfill — fails on a non-empty
    # user table; confirm this migration only ran on empty deployments.
    op.add_column("user", sa.Column("_organization_id", sa.String(), nullable=False))
    op.create_foreign_key(None, "user", "organization", ["_organization_id"], ["_id"])
    op.drop_column("user", "policies_accepted")
    op.drop_column("user", "org_id")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert :func:`upgrade`: restore ``org_id``/``policies_accepted``/
    ``user_id``, drop the audit and soft-delete columns, and restore the
    previous unique constraint on ``tool``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("user", sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    op.add_column("user", sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False))
    # NOTE(review): `None` is an autogenerate placeholder — op.drop_constraint
    # requires the actual (dialect-generated) FK name; TODO fill it in.
    op.drop_constraint(None, "user", type_="foreignkey")
    op.drop_column("user", "_organization_id")
    op.drop_column("user", "_last_updated_by_id")
    op.drop_column("user", "_created_by_id")
    op.drop_column("user", "is_deleted")
    op.drop_column("user", "updated_at")
    op.drop_column("user", "deleted")
    op.add_column("tool", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    # NOTE(review): same placeholder-name issue as above.
    op.drop_constraint(None, "tool", type_="foreignkey")
    op.drop_constraint("uix_name_organization", "tool", type_="unique")
    op.create_unique_constraint("uq_tool_name_organization", "tool", ["name", "_organization_id"])
    op.alter_column("tool", "_organization_id", existing_type=sa.VARCHAR(), nullable=True)
    op.alter_column("tool", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.alter_column("tool", "source_type", existing_type=sa.VARCHAR(), nullable=True)
    op.alter_column("tool", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.drop_column("tool", "_last_updated_by_id")
    op.drop_column("tool", "_created_by_id")
    op.drop_column("tool", "is_deleted")
    op.drop_column("tool", "updated_at")
    op.drop_column("tool", "created_at")
    op.drop_column("tool", "deleted")
    op.drop_column("organization", "_last_updated_by_id")
    op.drop_column("organization", "_created_by_id")
    op.drop_column("organization", "is_deleted")
    op.drop_column("organization", "updated_at")
    op.drop_column("organization", "deleted")
    op.drop_column("agents", "tool_rules")
    # ### end Alembic commands ###

View File

@@ -1,36 +0,0 @@
"""Rename block.name to block.template_name
Revision ID: eff245f340f9
Revises: 0c315956709d
Create Date: 2024-10-31 18:09:08.819371
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "eff245f340f9"
down_revision: Union[str, None] = "ee50a967e090"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Rename ``block.name`` to ``block.template_name`` in place, preserving
    existing row data."""
    # ### commands auto generated by Alembic - please adjust! ###
    # A single in-place rename keeps existing values; the add/drop alternative
    # left commented out below would have discarded them.
    op.alter_column("block", "name", new_column_name="template_name", existing_type=sa.String(), nullable=True)
    # op.add_column('block', sa.Column('template_name', sa.String(), nullable=True))
    # op.drop_column('block', 'name')
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert :func:`upgrade`: rename ``block.template_name`` back to
    ``block.name`` in place."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column("block", "template_name", new_column_name="name", existing_type=sa.String(), nullable=True)
    # op.add_column('block', sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True))
    # op.drop_column('block', 'template_name')
    # ### end Alembic commands ###

View File

@@ -3,6 +3,7 @@ import inspect
import traceback
import warnings
from abc import ABC, abstractmethod
from lib2to3.fixer_util import is_list
from typing import List, Literal, Optional, Tuple, Union
from tqdm import tqdm
@@ -250,6 +251,9 @@ class Agent(BaseAgent):
# if there are tool rules, print out a warning
warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.")
# add default rule for having send_message be a terminal tool
if not is_list(agent_state.tool_rules):
agent_state.tool_rules = []
agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message"))
self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)

View File

@@ -358,26 +358,26 @@ class PostgresStorageConnector(SQLStorageConnector):
# construct URI from environment variables
if settings.pg_uri:
self.uri = settings.pg_uri
# use config URI
# TODO: remove this eventually (config should NOT contain URI)
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
self.uri = self.config.archival_storage_uri
self.db_model = PassageModel
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
self.db_model = MessageModel
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
elif table_type == TableType.FILES:
self.uri = self.config.metadata_storage_uri
self.db_model = FileMetadataModel
if self.config.metadata_storage_uri is None:
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
else:
# use config URI
# TODO: remove this eventually (config should NOT contain URI)
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
self.uri = self.config.archival_storage_uri
self.db_model = PassageModel
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
self.db_model = MessageModel
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
elif table_type == TableType.FILES:
self.uri = self.config.metadata_storage_uri
self.db_model = FileMetadataModel
if self.config.metadata_storage_uri is None:
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
raise ValueError(f"Table type {table_type} not implemented")
for c in self.db_model.__table__.columns:
if c.name == "embedding":

View File

@@ -2267,18 +2267,18 @@ class LocalClient(AbstractClient):
langchain_tool=langchain_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_crewai(
crewai_tool=crewai_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
def load_composio_tool(self, action: "ActionType") -> Tool:
tool_create = ToolCreate.from_composio(action=action)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
# TODO: Use the above function `add_tool` here as there is duplicate logic
def create_tool(
@@ -2310,7 +2310,7 @@ class LocalClient(AbstractClient):
# call server function
return self.server.tool_manager.create_or_update_tool(
ToolCreate(
Tool(
source_type=source_type,
source_code=source_code,
name=name,
@@ -2738,7 +2738,7 @@ class LocalClient(AbstractClient):
return self.server.list_embedding_models()
def create_org(self, name: Optional[str] = None) -> Organization:
return self.server.organization_manager.create_organization(name=name)
return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name))
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)

View File

@@ -67,7 +67,7 @@ class CommonSqlalchemyMetaMixins(Base):
prop_value = getattr(self, full_prop, None)
if not prop_value:
return
return f"user-{prop_value}"
return prop_value
def _user_id_setter(self, prop: str, value: str) -> None:
"""returns the user id for the specified property"""
@@ -75,6 +75,9 @@ class CommonSqlalchemyMetaMixins(Base):
if not value:
setattr(self, full_prop, None)
return
# Safety check
prefix, id_ = value.split("-", 1)
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
setattr(self, full_prop, id_)
# Set the full value
setattr(self, full_prop, value)

View File

@@ -1,11 +1,9 @@
from typing import Optional
from uuid import UUID
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
from letta.orm.errors import MalformedIdError
def is_valid_uuid4(uuid_string: str) -> bool:
@@ -17,53 +15,12 @@ def is_valid_uuid4(uuid_string: str) -> bool:
return False
def _relation_getter(instance: "Base", prop: str) -> Optional[str]:
"""Get relation and return id with prefix as a string."""
prefix = prop.replace("_", "")
formatted_prop = f"_{prop}_id"
try:
id_ = getattr(instance, formatted_prop) # Get the string id directly
return f"{prefix}-{id_}"
except AttributeError:
return None
def _relation_setter(instance: "Base", prop: str, value: str) -> None:
"""Set relation using the id with prefix, ensuring the id is a valid UUIDv4."""
formatted_prop = f"_{prop}_id"
prefix = prop.replace("_", "")
if not value:
setattr(instance, formatted_prop, None)
return
try:
found_prefix, id_ = value.split("-", 1)
except ValueError as e:
raise MalformedIdError(f"{value} is not a valid ID.") from e
# Ensure prefix matches
assert found_prefix == prefix, f"{found_prefix} is not a valid id prefix, expecting {prefix}"
# Validate that the id is a valid UUID4 string
if not is_valid_uuid4(id_):
raise MalformedIdError(f"Hash segment of {value} is not a valid UUID4")
setattr(instance, formatted_prop, id_) # Store id as a string
class OrganizationMixin(Base):
"""Mixin for models that belong to an organization."""
__abstract__ = True
_organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id"))
@property
def organization_id(self) -> str:
return _relation_getter(self, "organization")
@organization_id.setter
def organization_id(self, value: str) -> None:
_relation_setter(self, "organization", value)
organization_id: Mapped[str] = mapped_column(String, ForeignKey("organizations.id"))
class UserMixin(Base):
@@ -71,12 +28,4 @@ class UserMixin(Base):
__abstract__ = True
_user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id"))
@property
def user_id(self) -> str:
return _relation_getter(self, "user")
@user_id.setter
def user_id(self, value: str) -> None:
_relation_setter(self, "user", value)
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, List
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
@@ -14,9 +15,10 @@ if TYPE_CHECKING:
class Organization(SqlalchemyBase):
"""The highest level of the object tree. All Entities belong to one and only one Organization."""
__tablename__ = "organization"
__tablename__ = "organizations"
__pydantic_model__ = PydanticOrganization
id: Mapped[str] = mapped_column(String, primary_key=True)
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")

View File

@@ -1,14 +1,11 @@
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from uuid import uuid4
from humps import depascalize
from sqlalchemy import Boolean, String, select
from sqlalchemy import String, select
from sqlalchemy.orm import Mapped, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import NoResultFound
from letta.orm.mixins import is_valid_uuid4
if TYPE_CHECKING:
from pydantic import BaseModel
@@ -24,27 +21,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
__order_by_default__ = "created_at"
_id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid4()}")
deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.")
@classmethod
def __prefix__(cls) -> str:
return depascalize(cls.__name__)
@property
def id(self) -> Optional[str]:
if self._id:
return f"{self.__prefix__()}-{self._id}"
@id.setter
def id(self, value: str) -> None:
if not value:
return
prefix, id_ = value.split("-", 1)
assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}"
assert is_valid_uuid4(id_), f"{id_} is not a valid uuid4"
self._id = id_
id: Mapped[str] = mapped_column(String, primary_key=True)
@classmethod
def list(
@@ -57,11 +34,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Add a cursor condition if provided
if cursor:
cursor_uuid = cls.get_uid_from_identifier(cursor) # Assuming the cursor is an _id value
query = query.where(cls._id > cursor_uuid)
query = query.where(cls.id > cursor)
# Add a limit to the query if provided
query = query.order_by(cls._id).limit(limit)
query = query.order_by(cls.id).limit(limit)
# Handle soft deletes if the class has the 'is_deleted' attribute
if hasattr(cls, "is_deleted"):
@@ -70,20 +46,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Execute the query and return the results as a list of model instances
return list(session.execute(query).scalars())
@classmethod
def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] = False) -> str:
"""converts the id into a uuid object
Args:
identifier: the string identifier, such as `organization-xxxx-xx...`
indifferent: if True, will not enforce the prefix check
"""
try:
uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "")
assert is_valid_uuid4(uuid_string)
return uuid_string
except ValueError as e:
raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e
@classmethod
def read(
cls,
@@ -112,8 +74,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# If an identifier is provided, add it to the query conditions
if identifier is not None:
identifier = cls.get_uid_from_identifier(identifier)
query = query.where(cls._id == identifier)
query = query.where(cls.id == identifier)
query_conditions.append(f"id='{identifier}'")
if kwargs:
@@ -183,7 +144,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
org_id = getattr(actor, "organization_id", None)
if not org_id:
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls._organization_id == cls.get_uid_from_identifier(org_id, indifferent=True), cls.is_deleted == False)
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
@property
def __pydantic_model__(self) -> Type["BaseModel"]:

View File

@@ -21,13 +21,14 @@ class Tool(SqlalchemyBase, OrganizationMixin):
more granular permissions.
"""
__tablename__ = "tool"
__tablename__ = "tools"
__pydantic_model__ = PydanticTool
# Add unique constraint on (name, _organization_id)
# An organization should not have multiple tools with the same name
__table_args__ = (UniqueConstraint("name", "_organization_id", name="uix_name_organization"),)
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
id: Mapped[str] = mapped_column(String, primary_key=True)
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@@ -13,9 +14,10 @@ if TYPE_CHECKING:
class User(SqlalchemyBase, OrganizationMixin):
"""User ORM class"""
__tablename__ = "user"
__tablename__ = "users"
__pydantic_model__ = PydanticUser
id: Mapped[str] = mapped_column(String, primary_key=True)
name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.")
# relationships

View File

@@ -106,6 +106,10 @@ class Memory(BaseModel, validate_assignment=True):
# New format
obj.prompt_template = state["prompt_template"]
for key, value in state["memory"].items():
# TODO: This is migration code, please take a look at a later time to get rid of this
if "name" in value:
value["template_name"] = value["name"]
value.pop("name")
obj.memory[key] = Block(**value)
else:
# Old format (pre-template)

View File

@@ -4,16 +4,16 @@ from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.utils import get_utc_time
from letta.utils import create_random_username, get_utc_time
class OrganizationBase(LettaBase):
__id_prefix__ = "organization"
__id_prefix__ = "org"
class Organization(OrganizationBase):
id: str = Field(..., description="The id of the organization.")
name: str = Field(..., description="The name of the organization.")
id: str = OrganizationBase.generate_id_field()
name: str = Field(create_random_username(), description="The name of the organization.")
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")

View File

@@ -33,21 +33,21 @@ class Tool(BaseTool):
"""
id: str = Field(..., description="The id of the tool.")
id: str = BaseTool.generate_id_field()
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
module: Optional[str] = Field(None, description="The module of the function.")
organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.")
name: str = Field(..., description="The name of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.")
name: Optional[str] = Field(None, description="The name of the function.")
tags: List[str] = Field([], description="Metadata tags.")
# code
source_code: str = Field(..., description="The source code of the function.")
json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
# metadata fields
created_by_id: str = Field(..., description="The id of the user that made this Tool.")
last_updated_by_id: str = Field(..., description="The id of the user that made this Tool.")
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
def to_dict(self):
"""

View File

@@ -21,7 +21,7 @@ class User(UserBase):
created_at (datetime): The creation date of the user.
"""
id: str = Field(..., description="The id of the user.")
id: str = UserBase.generate_id_field()
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
name: str = Field(..., description="The name of the user.")
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the user.")

View File

@@ -38,7 +38,8 @@ def create_org(
"""
Create a new org in the database
"""
org = server.organization_manager.create_organization(name=request.name)
org = Organization(**request.model_dump())
org = server.organization_manager.create_organization(pydantic_org=org)
return org

View File

@@ -89,7 +89,8 @@ def create_tool(
actor = server.get_user_or_default(user_id=user_id)
# Send request to create the tool
return server.tool_manager.create_or_update_tool(tool_create=request, actor=actor)
tool = Tool(**request.model_dump())
return server.tool_manager.create_or_update_tool(pydantic_tool=tool, actor=actor)
@router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool")

View File

@@ -51,8 +51,8 @@ def create_user(
"""
Create a new user in the database
"""
user = server.user_manager.create_user(request)
user = User(**request.model_dump())
user = server.user_manager.create_user(user)
return user

View File

@@ -824,7 +824,7 @@ class SyncServer(Server):
source_type = "python"
tags = ["memory", "memgpt-base"]
tool = self.tool_manager.create_or_update_tool(
ToolCreate(
Tool(
source_code=source_code,
source_type=source_type,
tags=tags,
@@ -1766,7 +1766,7 @@ class SyncServer(Server):
tool_creates += ToolCreate.load_default_composio_tools()
for tool_create in tool_creates:
try:
self.tool_manager.create_or_update_tool(tool_create, actor=actor)
self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor)
except Exception as e:
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
warnings.warn(traceback.format_exc())

View File

@@ -3,13 +3,13 @@ from typing import List, Optional
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization
from letta.utils import create_random_username, enforce_types
from letta.utils import enforce_types
class OrganizationManager:
"""Manager class to handle business logic related to Organizations."""
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_NAME = "default_org"
def __init__(self):
@@ -37,10 +37,10 @@ class OrganizationManager:
raise ValueError(f"Organization with id {org_id} not found.")
@enforce_types
def create_organization(self, name: Optional[str] = None) -> PydanticOrganization:
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
with self.session_maker() as session:
org = OrganizationModel(name=name if name else create_random_username())
org = OrganizationModel(**pydantic_org.model_dump())
org.create(session)
return org.to_pydantic()

View File

@@ -7,10 +7,9 @@ from letta.functions.functions import derive_openai_json_schema, load_function_s
# TODO: Remove this once we translate all of these to the ORM
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.tool import Tool as ToolModel
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.tool import ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types, printd
@@ -33,20 +32,20 @@ class ToolManager:
self.session_maker = db_context
@enforce_types
def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Derive json_schema
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(
source_code=tool_create.source_code, name=tool_create.name
derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(
source_code=pydantic_tool.source_code, name=pydantic_tool.name
)
derived_name = tool_create.name or derived_json_schema["name"]
derived_name = pydantic_tool.name or derived_json_schema["name"]
try:
# NOTE: We use the organization id here
# This is important, because even if it's a different user, adding the same tool to the org should not happen
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
# Put to dict and remove fields that should not be reset
update_data = tool_create.model_dump(exclude={"module"}, exclude_unset=True)
update_data = pydantic_tool.model_dump(exclude={"module"}, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
@@ -55,22 +54,24 @@ class ToolManager:
self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor)
else:
printd(
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={tool_create.name}, but found existing tool with nothing to update."
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
)
except NoResultFound:
tool_create.json_schema = derived_json_schema
tool_create.name = derived_name
tool = self.create_tool(tool_create, actor=actor)
pydantic_tool.json_schema = derived_json_schema
pydantic_tool.name = derived_name
tool = self.create_tool(pydantic_tool, actor=actor)
return tool
@enforce_types
def create_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Create the tool
with self.session_maker() as session:
create_data = tool_create.model_dump()
tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
tool_data = pydantic_tool.model_dump()
tool = ToolModel(**tool_data)
tool.create(session, actor=actor)
return tool.to_pydantic()
@@ -99,7 +100,7 @@ class ToolManager:
db_session=session,
cursor=cursor,
limit=limit,
_organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id),
organization_id=actor.organization_id,
)
return [tool.to_pydantic() for tool in tools]
@@ -176,7 +177,7 @@ class ToolManager:
# create to tool
tools.append(
self.create_or_update_tool(
ToolCreate(
PydanticTool(
name=name,
tags=tags,
source_type="python",

View File

@@ -4,7 +4,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.user import User as UserModel
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserCreate, UserUpdate
from letta.schemas.user import UserUpdate
from letta.services.organization_manager import OrganizationManager
from letta.utils import enforce_types
@@ -42,10 +42,10 @@ class UserManager:
return user.to_pydantic()
@enforce_types
def create_user(self, user_create: UserCreate) -> PydanticUser:
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist."""
with self.session_maker() as session:
new_user = UserModel(**user_create.model_dump())
new_user = UserModel(**pydantic_user.model_dump())
new_user.create(session)
return new_user.to_pydantic()

52
scripts/migrate_tools.py Normal file
View File

@@ -0,0 +1,52 @@
from letta.functions.functions import parse_source_code
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
def deprecated_tool():
return "this is a deprecated tool, please remove it from your tools list"
orgs = OrganizationManager().list_organizations(cursor=None, limit=5000)
for org in orgs:
if org.name != "default":
fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id)
ToolManager().add_base_tools(actor=fake_user)
source_code = parse_source_code(deprecated_tool)
source_type = "python"
description = "deprecated"
tags = ["deprecated"]
ToolManager().create_or_update_tool(
Tool(
name="core_memory_append",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
ToolManager().create_or_update_tool(
Tool(
name="core_memory_replace",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
ToolManager().create_or_update_tool(
Tool(
name="pause_heartbeats",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)

View File

@@ -420,7 +420,7 @@ def test_tools_from_langchain(client: LocalClient):
exec(source_code, {}, local_scope)
func = local_scope[tool.name]
expected_content = "Albert Einstein ( EYEN-styne; German:"
expected_content = "Albert Einstein"
assert expected_content in func(query="Albert Einstein")

View File

@@ -6,12 +6,15 @@ from letta.functions.functions import derive_openai_json_schema, parse_source_co
from letta.orm.organization import Organization
from letta.orm.tool import Tool
from letta.orm.user import User
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate
from letta.services.organization_manager import OrganizationManager
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.user import UserCreate, UserUpdate
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.server import SyncServer
@@ -47,17 +50,17 @@ def tool_fixture(server: SyncServer):
org = server.organization_manager.create_default_organization()
user = server.user_manager.create_default_user()
other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id))
tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
other_user = server.user_manager.create_user(PydanticUser(name="other", organization_id=org.id))
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
tool_create.json_schema = derived_json_schema
tool_create.name = derived_name
tool.json_schema = derived_json_schema
tool.name = derived_name
tool = server.tool_manager.create_tool(tool_create, actor=user)
tool = server.tool_manager.create_tool(tool, actor=user)
# Yield the created tool, organization, and user for use in tests
yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool_create}
yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool}
@pytest.fixture(scope="module")
@@ -76,7 +79,7 @@ def server():
def test_list_organizations(server: SyncServer):
# Create a new org and confirm that it is created correctly
org_name = "test"
org = server.organization_manager.create_organization(name=org_name)
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
orgs = server.organization_manager.list_organizations()
assert len(orgs) == 1
@@ -96,15 +99,15 @@ def test_create_default_organization(server: SyncServer):
def test_update_organization_name(server: SyncServer):
org_name_a = "a"
org_name_b = "b"
org = server.organization_manager.create_organization(name=org_name_a)
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name_a))
assert org.name == org_name_a
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
assert org.name == org_name_b
def test_list_organizations_pagination(server: SyncServer):
server.organization_manager.create_organization(name="a")
server.organization_manager.create_organization(name="b")
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
orgs_x = server.organization_manager.list_organizations(limit=1)
assert len(orgs_x) == 1
@@ -125,7 +128,7 @@ def test_list_users(server: SyncServer):
org = server.organization_manager.create_default_organization()
user_name = "user"
user = server.user_manager.create_user(UserCreate(name=user_name, organization_id=org.id))
user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id))
users = server.user_manager.list_users()
assert len(users) == 1
@@ -146,13 +149,13 @@ def test_create_default_user(server: SyncServer):
def test_update_user(server: SyncServer):
# Create default organization
default_org = server.organization_manager.create_default_organization()
test_org = server.organization_manager.create_organization(name="test_org")
test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org"))
user_name_a = "a"
user_name_b = "b"
# Assert it's been created
user = server.user_manager.create_user(UserCreate(name=user_name_a, organization_id=default_org.id))
user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id))
assert user.name == user_name_a
# Adjust name
@@ -340,7 +343,6 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture):
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=other_user)
# Check that the created_by and last_updated_by fields are correct
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)

View File

@@ -14,6 +14,7 @@ from letta.constants import DEFAULT_PRESET
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.services.tool_manager import ToolManager
test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_preset_name = "test_preset"
@@ -93,15 +94,9 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
return message
tools = client.list_tools()
assert sorted([t.name for t in tools]) == sorted(
[
"archival_memory_search",
"send_message",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
]
)
tool_names = [t.name for t in tools]
for tool in ToolManager.BASE_TOOL_NAMES:
assert tool in tool_names
tool = client.create_tool(print_tool, name="my_name", tags=["extras"])