diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 639d0b4d..a80bc4eb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -131,4 +131,4 @@ jobs: LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | - poetry run pytest -s -vv -k "not test_single_path_agent_tool_call_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests + poetry run pytest -s -vv -k "not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests diff --git a/alembic/versions/0c315956709d_.py b/alembic/versions/0c315956709d_.py deleted file mode 100644 index 704ab235..00000000 --- a/alembic/versions/0c315956709d_.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Sequence, Union - -import sqlalchemy as sa - -from alembic import op - -# Revision identifiers, used by Alembic. -revision: str = "0c315956709d" -down_revision: Union[str, None] = "9a505cc7eca9" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Step 1: Rename tables first - op.rename_table("organizations", "organization") - op.rename_table("tools", "tool") - op.rename_table("users", "user") - - # Step 2: Rename `id` to `_id` in each table, keeping it as the primary key - op.alter_column("organization", "id", new_column_name="_id", existing_type=sa.String, nullable=False) - op.alter_column("tool", "id", new_column_name="_id", existing_type=sa.String, nullable=False) - op.alter_column("user", "id", new_column_name="_id", existing_type=sa.String, nullable=False) - - # Step 3: Add `_organization_id` to `tool` table - # This is required for the unique constraint below - op.add_column("tool", sa.Column("_organization_id", sa.String, nullable=True)) - - # Step 4: Modify nullable constraints on `tool` columns - op.alter_column("tool", "tags", existing_type=sa.JSON, nullable=True) - op.alter_column("tool", "source_type", existing_type=sa.String, nullable=True) - op.alter_column("tool", "json_schema", existing_type=sa.JSON, nullable=True) - - # Step 5: Add unique constraint on `name` and `_organization_id` in `tool` table - op.create_unique_constraint("uq_tool_name_organization", "tool", ["name", "_organization_id"]) - - -def downgrade() -> None: - # Reverse unique constraint first - op.drop_constraint("uq_tool_name_organization", "tool", type_="unique") - - # Reverse nullable constraints on `tool` columns - op.alter_column("tool", "tags", existing_type=sa.JSON, nullable=False) - op.alter_column("tool", "source_type", existing_type=sa.String, nullable=False) - op.alter_column("tool", "json_schema", existing_type=sa.JSON, nullable=False) - - # Remove `_organization_id` column from `tool` table - op.drop_column("tool", "_organization_id") - - # Reverse the column renaming from `_id` back to `id` - op.alter_column("organization", "_id", new_column_name="id", existing_type=sa.String, nullable=False) - op.alter_column("tool", "_id", new_column_name="id", existing_type=sa.String, nullable=False) - op.alter_column("user", "_id", new_column_name="id", existing_type=sa.String, nullable=False) - - # Reverse table renaming last - op.rename_table("organization", "organizations") - op.rename_table("tool", "tools") - op.rename_table("user", "users") diff --git a/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py b/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py new file mode 100644 index 00000000..c05775eb --- /dev/null +++ b/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py @@ -0,0 +1,95 @@ +"""Move organizations users tools to orm + +Revision ID: d14ae606614c +Revises: 9a505cc7eca9 +Create Date: 2024-11-05 15:03:12.350096 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +import letta +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d14ae606614c" +down_revision: Union[str, None] = "9a505cc7eca9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def deprecated_tool(): + return "this is a deprecated tool, please remove it from your tools list" + + +def upgrade() -> None: + # Delete all tools + op.execute("DELETE FROM tools") + + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True)) + op.alter_column("block", "name", new_column_name="template_name", nullable=True) + op.add_column("organizations", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("organizations", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("organizations", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("organizations", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + op.add_column("tools", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("tools", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("tools", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("tools", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("tools", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + op.add_column("tools", sa.Column("organization_id", sa.String(), nullable=False)) + op.alter_column("tools", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.alter_column("tools", "source_type", existing_type=sa.VARCHAR(), nullable=False) + op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.create_unique_constraint("uix_name_organization", "tools", ["name", "organization_id"]) + op.create_foreign_key(None, "tools", "organizations", ["organization_id"], ["id"]) + op.drop_column("tools", "user_id") + op.add_column("users", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("users", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("users", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("users", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + op.add_column("users", sa.Column("organization_id", sa.String(), nullable=True)) + # loop through all rows in the user table and set the _organization_id column from organization_id + op.execute('UPDATE "users" SET organization_id = org_id') + # set the _organization_id column to not nullable + op.alter_column("users", "organization_id", existing_type=sa.String(), nullable=False) + op.create_foreign_key(None, "users", "organizations", ["organization_id"], ["id"]) + op.drop_column("users", "org_id") + op.drop_column("users", "policies_accepted") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("users", sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False)) + op.add_column("users", sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint(None, "users", type_="foreignkey") + op.drop_column("users", "organization_id") + op.drop_column("users", "_last_updated_by_id") + op.drop_column("users", "_created_by_id") + op.drop_column("users", "is_deleted") + op.drop_column("users", "updated_at") + op.add_column("tools", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint(None, "tools", type_="foreignkey") + op.drop_constraint("uix_name_organization", "tools", type_="unique") + op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.alter_column("tools", "source_type", existing_type=sa.VARCHAR(), nullable=True) + op.alter_column("tools", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.drop_column("tools", "organization_id") + op.drop_column("tools", "_last_updated_by_id") + op.drop_column("tools", "_created_by_id") + op.drop_column("tools", "is_deleted") + op.drop_column("tools", "updated_at") + op.drop_column("tools", "created_at") + op.drop_column("organizations", "_last_updated_by_id") + op.drop_column("organizations", "_created_by_id") + op.drop_column("organizations", "is_deleted") + op.drop_column("organizations", "updated_at") + op.add_column("block", sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_column("block", "template_name") + op.drop_column("agents", "tool_rules") + # ### end Alembic commands ### diff --git a/alembic/versions/ee50a967e090_sync_migration_with_model_changes.py b/alembic/versions/ee50a967e090_sync_migration_with_model_changes.py deleted file mode 100644 index 7f053ea6..00000000 --- a/alembic/versions/ee50a967e090_sync_migration_with_model_changes.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Sync migration with model changes - -Revision ID: ee50a967e090 -Revises: 0c315956709d -Create Date: 2024-11-01 19:18:39.399950 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -import letta -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "ee50a967e090" -down_revision: Union[str, None] = "0c315956709d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True)) - op.add_column("organization", sa.Column("deleted", sa.Boolean(), nullable=False)) - op.add_column("organization", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) - op.add_column("organization", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) - op.add_column("organization", sa.Column("_created_by_id", sa.String(), nullable=True)) - op.add_column("organization", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) - op.add_column("tool", sa.Column("deleted", sa.Boolean(), nullable=False)) - op.add_column("tool", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) - op.add_column("tool", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) - op.add_column("tool", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) - op.add_column("tool", sa.Column("_created_by_id", sa.String(), nullable=True)) - op.add_column("tool", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) - op.alter_column("tool", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) - op.alter_column("tool", "source_type", existing_type=sa.VARCHAR(), nullable=False) - op.alter_column("tool", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) - op.alter_column("tool", "_organization_id", existing_type=sa.VARCHAR(), nullable=False) - op.drop_constraint("uq_tool_name_organization", "tool", type_="unique") - op.create_unique_constraint("uix_name_organization", "tool", ["name", "_organization_id"]) - op.create_foreign_key(None, "tool", "organization", ["_organization_id"], ["_id"]) - op.drop_column("tool", "user_id") - op.add_column("user", sa.Column("deleted", sa.Boolean(), nullable=False)) - op.add_column("user", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) - op.add_column("user", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) - op.add_column("user", sa.Column("_created_by_id", sa.String(), nullable=True)) - op.add_column("user", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) - op.add_column("user", sa.Column("_organization_id", sa.String(), nullable=False)) - op.create_foreign_key(None, "user", "organization", ["_organization_id"], ["_id"]) - op.drop_column("user", "policies_accepted") - op.drop_column("user", "org_id") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("user", sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True)) - op.add_column("user", sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False)) - op.drop_constraint(None, "user", type_="foreignkey") - op.drop_column("user", "_organization_id") - op.drop_column("user", "_last_updated_by_id") - op.drop_column("user", "_created_by_id") - op.drop_column("user", "is_deleted") - op.drop_column("user", "updated_at") - op.drop_column("user", "deleted") - op.add_column("tool", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)) - op.drop_constraint(None, "tool", type_="foreignkey") - op.drop_constraint("uix_name_organization", "tool", type_="unique") - op.create_unique_constraint("uq_tool_name_organization", "tool", ["name", "_organization_id"]) - op.alter_column("tool", "_organization_id", existing_type=sa.VARCHAR(), nullable=True) - op.alter_column("tool", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) - op.alter_column("tool", "source_type", existing_type=sa.VARCHAR(), nullable=True) - op.alter_column("tool", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) - op.drop_column("tool", "_last_updated_by_id") - op.drop_column("tool", "_created_by_id") - op.drop_column("tool", "is_deleted") - op.drop_column("tool", "updated_at") - op.drop_column("tool", "created_at") - op.drop_column("tool", "deleted") - op.drop_column("organization", "_last_updated_by_id") - op.drop_column("organization", "_created_by_id") - op.drop_column("organization", "is_deleted") - op.drop_column("organization", "updated_at") - op.drop_column("organization", "deleted") - op.drop_column("agents", "tool_rules") - # ### end Alembic commands ### diff --git a/alembic/versions/eff245f340f9_rename_block_name_to_block_template_name.py b/alembic/versions/eff245f340f9_rename_block_name_to_block_template_name.py deleted file mode 100644 index a723b60a..00000000 --- a/alembic/versions/eff245f340f9_rename_block_name_to_block_template_name.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Rename block.name to block.template_name - -Revision ID: eff245f340f9 -Revises: 0c315956709d -Create Date: 2024-10-31 18:09:08.819371 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa - -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "eff245f340f9" -down_revision: Union[str, None] = "ee50a967e090" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("block", "name", new_column_name="template_name", existing_type=sa.String(), nullable=True) - - # op.add_column('block', sa.Column('template_name', sa.String(), nullable=True)) - # op.drop_column('block', 'name') - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("block", "template_name", new_column_name="name", existing_type=sa.String(), nullable=True) - # op.add_column('block', sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True)) - # op.drop_column('block', 'template_name') - # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 724ca5b5..03056ce2 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -3,6 +3,7 @@ import inspect import traceback import warnings from abc import ABC, abstractmethod +from lib2to3.fixer_util import is_list from typing import List, Literal, Optional, Tuple, Union from tqdm import tqdm @@ -250,6 +251,9 @@ class Agent(BaseAgent): # if there are tool rules, print out a warning warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.") # add default rule for having send_message be a terminal tool + + if not is_list(agent_state.tool_rules): + agent_state.tool_rules = [] agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message")) self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 840c03ce..0b00e2ad 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -358,26 +358,26 @@ class PostgresStorageConnector(SQLStorageConnector): # construct URI from enviornment variables if settings.pg_uri: self.uri = settings.pg_uri + + # use config URI + # TODO: remove this eventually (config should NOT contain URI) + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: + self.uri = self.config.archival_storage_uri + self.db_model = PassageModel + if self.config.archival_storage_uri is None: + raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}") + elif table_type == TableType.RECALL_MEMORY: + self.uri = self.config.recall_storage_uri + self.db_model = MessageModel + if self.config.recall_storage_uri is None: + raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") + elif table_type == TableType.FILES: + self.uri = self.config.metadata_storage_uri + self.db_model = FileMetadataModel + if self.config.metadata_storage_uri is None: + raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}") else: - # use config URI - # TODO: remove this eventually (config should NOT contain URI) - if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: - self.uri = self.config.archival_storage_uri - self.db_model = PassageModel - if self.config.archival_storage_uri is None: - raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}") - elif table_type == TableType.RECALL_MEMORY: - self.uri = self.config.recall_storage_uri - self.db_model = MessageModel - if self.config.recall_storage_uri is None: - raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") - elif table_type == TableType.FILES: - self.uri = self.config.metadata_storage_uri - self.db_model = FileMetadataModel - if self.config.metadata_storage_uri is None: - raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}") - else: - raise ValueError(f"Table type {table_type} not implemented") + raise ValueError(f"Table type {table_type} not implemented") for c in self.db_model.__table__.columns: if c.name == "embedding": diff --git a/letta/client/client.py b/letta/client/client.py index 3dd1a814..b8b0c70b 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2267,18 +2267,18 @@ class LocalClient(AbstractClient): langchain_tool=langchain_tool, additional_imports_module_attr_map=additional_imports_module_attr_map, ) - return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user) + return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user) def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: tool_create = ToolCreate.from_crewai( crewai_tool=crewai_tool, additional_imports_module_attr_map=additional_imports_module_attr_map, ) - return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user) + return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user) def load_composio_tool(self, action: "ActionType") -> Tool: tool_create = ToolCreate.from_composio(action=action) - return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user) + return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user) # TODO: Use the above function `add_tool` here as there is duplicate logic def create_tool( @@ -2310,7 +2310,7 @@ class LocalClient(AbstractClient): # call server function return self.server.tool_manager.create_or_update_tool( - ToolCreate( + Tool( source_type=source_type, source_code=source_code, name=name, @@ -2738,7 +2738,7 @@ class LocalClient(AbstractClient): return self.server.list_embedding_models() def create_org(self, name: Optional[str] = None) -> Organization: - return self.server.organization_manager.create_organization(name=name) + return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name)) def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit) diff --git a/letta/orm/base.py b/letta/orm/base.py index d8a84751..e9491c41 100644 --- a/letta/orm/base.py +++ b/letta/orm/base.py @@ -67,7 +67,7 @@ class CommonSqlalchemyMetaMixins(Base): prop_value = getattr(self, full_prop, None) if not prop_value: return - return f"user-{prop_value}" + return prop_value def _user_id_setter(self, prop: str, value: str) -> None: """returns the user id for the specified property""" @@ -75,6 +75,9 @@ class CommonSqlalchemyMetaMixins(Base): if not value: setattr(self, full_prop, None) return + # Safety check prefix, id_ = value.split("-", 1) assert prefix == "user", f"{prefix} is not a valid id prefix for a user id" - setattr(self, full_prop, id_) + + # Set the full value + setattr(self, full_prop, value) diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 6ff3ec19..57145475 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -1,11 +1,9 @@ -from typing import Optional from uuid import UUID from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column from letta.orm.base import Base -from letta.orm.errors import MalformedIdError def is_valid_uuid4(uuid_string: str) -> bool: @@ -17,53 +15,12 @@ def is_valid_uuid4(uuid_string: str) -> bool: return False -def _relation_getter(instance: "Base", prop: str) -> Optional[str]: - """Get relation and return id with prefix as a string.""" - prefix = prop.replace("_", "") - formatted_prop = f"_{prop}_id" - try: - id_ = getattr(instance, formatted_prop) # Get the string id directly - return f"{prefix}-{id_}" - except AttributeError: - return None - - -def _relation_setter(instance: "Base", prop: str, value: str) -> None: - """Set relation using the id with prefix, ensuring the id is a valid UUIDv4.""" - formatted_prop = f"_{prop}_id" - prefix = prop.replace("_", "") - if not value: - setattr(instance, formatted_prop, None) - return - try: - found_prefix, id_ = value.split("-", 1) - except ValueError as e: - raise MalformedIdError(f"{value} is not a valid ID.") from e - - # Ensure prefix matches - assert found_prefix == prefix, f"{found_prefix} is not a valid id prefix, expecting {prefix}" - - # Validate that the id is a valid UUID4 string - if not is_valid_uuid4(id_): - raise MalformedIdError(f"Hash segment of {value} is not a valid UUID4") - - setattr(instance, formatted_prop, id_) # Store id as a string - - class OrganizationMixin(Base): """Mixin for models that belong to an organization.""" __abstract__ = True - _organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id")) - - @property - def organization_id(self) -> str: - return _relation_getter(self, "organization") - - @organization_id.setter - def organization_id(self, value: str) -> None: - _relation_setter(self, "organization", value) + organization_id: Mapped[str] = mapped_column(String, ForeignKey("organizations.id")) class UserMixin(Base): @@ -71,12 +28,4 @@ class UserMixin(Base): __abstract__ = True - _user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id")) - - @property - def user_id(self) -> str: - return _relation_getter(self, "user") - - @user_id.setter - def user_id(self, value: str) -> None: - _relation_setter(self, "user", value) + user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id")) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 51e87e8a..88f8ea5d 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, List +from sqlalchemy import String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.sqlalchemy_base import SqlalchemyBase @@ -14,9 +15,10 @@ if TYPE_CHECKING: class Organization(SqlalchemyBase): """The highest level of the object tree. All Entities belong to one and only one Organization.""" - __tablename__ = "organization" + __tablename__ = "organizations" __pydantic_model__ = PydanticOrganization + id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(doc="The display name of the organization.") users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 2e5954e4..20728d7b 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -1,14 +1,11 @@ from typing import TYPE_CHECKING, List, Literal, Optional, Type -from uuid import uuid4 -from humps import depascalize -from sqlalchemy import Boolean, String, select +from sqlalchemy import String, select from sqlalchemy.orm import Mapped, mapped_column from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins from letta.orm.errors import NoResultFound -from letta.orm.mixins import is_valid_uuid4 if TYPE_CHECKING: from pydantic import BaseModel @@ -24,27 +21,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): __order_by_default__ = "created_at" - _id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid4()}") - - deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.") - - @classmethod - def __prefix__(cls) -> str: - return depascalize(cls.__name__) - - @property - def id(self) -> Optional[str]: - if self._id: - return f"{self.__prefix__()}-{self._id}" - - @id.setter - def id(self, value: str) -> None: - if not value: - return - prefix, id_ = value.split("-", 1) - assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}" - assert is_valid_uuid4(id_), f"{id_} is not a valid uuid4" - self._id = id_ + id: Mapped[str] = mapped_column(String, primary_key=True) @classmethod def list( @@ -57,11 +34,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # Add a cursor condition if provided if cursor: - cursor_uuid = cls.get_uid_from_identifier(cursor) # Assuming the cursor is an _id value - query = query.where(cls._id > cursor_uuid) + query = query.where(cls.id > cursor) # Add a limit to the query if provided - query = query.order_by(cls._id).limit(limit) + query = query.order_by(cls.id).limit(limit) # Handle soft deletes if the class has the 'is_deleted' attribute if hasattr(cls, "is_deleted"): @@ -70,20 +46,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # Execute the query and return the results as a list of model instances return list(session.execute(query).scalars()) - @classmethod - def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] = False) -> str: - """converts the id into a uuid object - Args: - identifier: the string identifier, such as `organization-xxxx-xx...` - indifferent: if True, will not enforce the prefix check - """ - try: - uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "") - assert is_valid_uuid4(uuid_string) - return uuid_string - except ValueError as e: - raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e - @classmethod def read( cls, @@ -112,8 +74,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # If an identifier is provided, add it to the query conditions if identifier is not None: - identifier = cls.get_uid_from_identifier(identifier) - query = query.where(cls._id == identifier) + query = query.where(cls.id == identifier) query_conditions.append(f"id='{identifier}'") if kwargs: @@ -183,7 +144,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): org_id = getattr(actor, "organization_id", None) if not org_id: raise ValueError(f"object {actor} has no organization accessor") - return query.where(cls._organization_id == cls.get_uid_from_identifier(org_id, indifferent=True), cls.is_deleted == False) + return query.where(cls.organization_id == org_id, cls.is_deleted == False) @property def __pydantic_model__(self) -> Type["BaseModel"]: diff --git a/letta/orm/tool.py b/letta/orm/tool.py index acf2761b..5e0ec0d0 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -21,13 +21,14 @@ class Tool(SqlalchemyBase, OrganizationMixin): more granular permissions. """ - __tablename__ = "tool" + __tablename__ = "tools" __pydantic_model__ = PydanticTool # Add unique constraint on (name, _organization_id) # An organization should not have multiple tools with the same name - __table_args__ = (UniqueConstraint("name", "_organization_id", name="uix_name_organization"),) + __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),) + id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(doc="The display name of the tool.") description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.") tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.") diff --git a/letta/orm/user.py b/letta/orm/user.py index dfa0acc9..05f69102 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from sqlalchemy import String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -13,9 +14,10 @@ if TYPE_CHECKING: class User(SqlalchemyBase, OrganizationMixin): """User ORM class""" - __tablename__ = "user" + __tablename__ = "users" __pydantic_model__ = PydanticUser + id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.") # relationships diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index ae3b34dd..1ce7b4c7 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -106,6 +106,10 @@ class Memory(BaseModel, validate_assignment=True): # New format obj.prompt_template = state["prompt_template"] for key, value in state["memory"].items(): + # TODO: This is migration code, please take a look at a later time to get rid of this + if "name" in value: + value["template_name"] = value["name"] + value.pop("name") obj.memory[key] = Block(**value) else: # Old format (pre-template) diff --git a/letta/schemas/organization.py b/letta/schemas/organization.py index 38fb9cef..35784ad0 100644 --- a/letta/schemas/organization.py +++ b/letta/schemas/organization.py @@ -4,16 +4,16 @@ from typing import Optional from pydantic import Field from letta.schemas.letta_base import LettaBase -from letta.utils import get_utc_time +from letta.utils import create_random_username, get_utc_time class OrganizationBase(LettaBase): - __id_prefix__ = "organization" + __id_prefix__ = "org" class Organization(OrganizationBase): - id: str = Field(..., description="The id of the organization.") - name: str = Field(..., description="The name of the organization.") + id: str = OrganizationBase.generate_id_field() + name: str = Field(create_random_username(), description="The name of the organization.") created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.") diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index e538f3fa..34733a18 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -33,21 +33,21 @@ class Tool(BaseTool): """ - id: str = Field(..., description="The id of the tool.") + id: str = BaseTool.generate_id_field() description: Optional[str] = Field(None, description="The description of the tool.") source_type: Optional[str] = Field(None, description="The type of the source code.") module: Optional[str] = Field(None, description="The module of the function.") - organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.") - name: str = Field(..., description="The name of the function.") - tags: List[str] = Field(..., description="Metadata tags.") + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.") + name: Optional[str] = Field(None, description="The name of the function.") + tags: List[str] = Field([], description="Metadata tags.") # code source_code: str = Field(..., description="The source code of the function.") - json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.") + json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.") # metadata fields - created_by_id: str = Field(..., description="The id of the user that made this Tool.") - last_updated_by_id: str = Field(..., description="The id of the user that made this Tool.") + created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") + last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") def to_dict(self): """ diff --git a/letta/schemas/user.py b/letta/schemas/user.py index 674eb599..59a4594e 100644 --- a/letta/schemas/user.py +++ b/letta/schemas/user.py @@ -21,7 +21,7 @@ class User(UserBase): created_at (datetime): The creation date of the user. """ - id: str = Field(..., description="The id of the user.") + id: str = UserBase.generate_id_field() organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user") name: str = Field(..., description="The name of the user.") created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the user.") diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index a52d81c7..2f4cdb1b 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -38,7 +38,8 @@ def create_org( """ Create a new org in the database """ - org = server.organization_manager.create_organization(name=request.name) + org = Organization(**request.model_dump()) + org = server.organization_manager.create_organization(pydantic_org=org) return org diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index d1b442f7..22f3dc03 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -89,7 +89,8 @@ def create_tool( actor = server.get_user_or_default(user_id=user_id) # Send request to create the tool - return server.tool_manager.create_or_update_tool(tool_create=request, actor=actor) + tool = Tool(**request.model_dump()) + return server.tool_manager.create_or_update_tool(pydantic_tool=tool, actor=actor) @router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool") diff --git a/letta/server/rest_api/routers/v1/users.py b/letta/server/rest_api/routers/v1/users.py index 80e9f24f..d0e0f787 100644 --- a/letta/server/rest_api/routers/v1/users.py +++ b/letta/server/rest_api/routers/v1/users.py @@ -51,8 +51,8 @@ def create_user( """ Create a new user in the database """ - - user = server.user_manager.create_user(request) + user = User(**request.model_dump()) + user = server.user_manager.create_user(user) return user diff --git a/letta/server/server.py b/letta/server/server.py index 2fbd7806..55b46135 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -824,7 +824,7 @@ class SyncServer(Server): source_type = "python" tags = ["memory", "memgpt-base"] tool = self.tool_manager.create_or_update_tool( - ToolCreate( + Tool( source_code=source_code, source_type=source_type, tags=tags, @@ -1766,7 +1766,7 @@ class SyncServer(Server): tool_creates += ToolCreate.load_default_composio_tools() for tool_create in tool_creates: try: - self.tool_manager.create_or_update_tool(tool_create, actor=actor) + self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor) except Exception as e: warnings.warn(f"An error occurred while creating tool {tool_create}: {e}") warnings.warn(traceback.format_exc()) diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 8c13c037..1832c580 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -3,13 +3,13 @@ from typing import List, Optional from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.schemas.organization import Organization as PydanticOrganization -from letta.utils import create_random_username, enforce_types +from letta.utils import enforce_types class OrganizationManager: """Manager class to handle business logic related to Organizations.""" - DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000" + DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000" DEFAULT_ORG_NAME = "default_org" def __init__(self): @@ -37,10 +37,10 @@ class OrganizationManager: raise ValueError(f"Organization with id {org_id} not found.") @enforce_types - def create_organization(self, name: Optional[str] = None) -> PydanticOrganization: + def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated.""" with self.session_maker() as session: - org = OrganizationModel(name=name if name else create_random_username()) + org = OrganizationModel(**pydantic_org.model_dump()) org.create(session) return org.to_pydantic() diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index cee451df..1b85e316 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -7,10 +7,9 @@ from letta.functions.functions import derive_openai_json_schema, load_function_s # TODO: Remove this once we translate all of these to the ORM from letta.orm.errors import NoResultFound -from letta.orm.organization import Organization as OrganizationModel from letta.orm.tool import Tool as ToolModel from letta.schemas.tool import Tool as PydanticTool -from letta.schemas.tool import ToolCreate, ToolUpdate +from letta.schemas.tool import ToolUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types, printd @@ -33,20 +32,20 @@ class ToolManager: self.session_maker = db_context @enforce_types - def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: + def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" # Derive json_schema - derived_json_schema = tool_create.json_schema or derive_openai_json_schema( - source_code=tool_create.source_code, name=tool_create.name + derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema( + source_code=pydantic_tool.source_code, name=pydantic_tool.name ) - derived_name = tool_create.name or derived_json_schema["name"] + derived_name = pydantic_tool.name or derived_json_schema["name"] try: # NOTE: We use the organization id here # This is important, because even if it's a different user, adding the same tool to the org should not happen tool = self.get_tool_by_name(tool_name=derived_name, actor=actor) # Put to dict and remove fields that should not be reset - update_data = tool_create.model_dump(exclude={"module"}, exclude_unset=True) + update_data = pydantic_tool.model_dump(exclude={"module"}, exclude_unset=True, exclude_none=True) # Remove redundant update fields update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value} @@ -55,22 +54,24 @@ class ToolManager: self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor) else: printd( - f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={tool_create.name}, but found existing tool with nothing to update." + f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update." ) except NoResultFound: - tool_create.json_schema = derived_json_schema - tool_create.name = derived_name - tool = self.create_tool(tool_create, actor=actor) + pydantic_tool.json_schema = derived_json_schema + pydantic_tool.name = derived_name + tool = self.create_tool(pydantic_tool, actor=actor) return tool @enforce_types - def create_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: + def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" # Create the tool with self.session_maker() as session: - create_data = tool_create.model_dump() - tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel + # Set the organization id at the ORM layer + pydantic_tool.organization_id = actor.organization_id + tool_data = pydantic_tool.model_dump() + tool = ToolModel(**tool_data) tool.create(session, actor=actor) return tool.to_pydantic() @@ -99,7 +100,7 @@ class ToolManager: db_session=session, cursor=cursor, limit=limit, - _organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id), + organization_id=actor.organization_id, ) return [tool.to_pydantic() for tool in tools] @@ -176,7 +177,7 @@ class ToolManager: # create to tool tools.append( self.create_or_update_tool( - ToolCreate( + PydanticTool( name=name, tags=tags, source_type="python", diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index c9f6b166..42df72fa 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -4,7 +4,7 @@ from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.orm.user import User as UserModel from letta.schemas.user import User as PydanticUser -from letta.schemas.user import UserCreate, UserUpdate +from letta.schemas.user import UserUpdate from letta.services.organization_manager import OrganizationManager from letta.utils import enforce_types @@ -42,10 +42,10 @@ class UserManager: return user.to_pydantic() @enforce_types - def create_user(self, user_create: UserCreate) -> PydanticUser: + def create_user(self, pydantic_user: PydanticUser) -> PydanticUser: """Create a new user if it doesn't already exist.""" with self.session_maker() as session: - new_user = UserModel(**user_create.model_dump()) + new_user = UserModel(**pydantic_user.model_dump()) new_user.create(session) return new_user.to_pydantic() diff --git a/scripts/migrate_tools.py b/scripts/migrate_tools.py new file mode 100644 index 00000000..6ab9ed9e --- /dev/null +++ b/scripts/migrate_tools.py @@ -0,0 +1,52 @@ +from letta.functions.functions import parse_source_code +from letta.schemas.tool import Tool +from letta.schemas.user import User +from letta.services.organization_manager import OrganizationManager +from letta.services.tool_manager import ToolManager + + +def deprecated_tool(): + return "this is a deprecated tool, please remove it from your tools list" + + +orgs = OrganizationManager().list_organizations(cursor=None, limit=5000) +for org in orgs: + if org.name != "default": + fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id) + + ToolManager().add_base_tools(actor=fake_user) + + source_code = parse_source_code(deprecated_tool) + source_type = "python" + description = "deprecated" + tags = ["deprecated"] + + ToolManager().create_or_update_tool( + Tool( + name="core_memory_append", + source_code=source_code, + source_type=source_type, + description=description, + ), + actor=fake_user, + ) + + ToolManager().create_or_update_tool( + Tool( + name="core_memory_replace", + source_code=source_code, + source_type=source_type, + description=description, + ), + actor=fake_user, + ) + + ToolManager().create_or_update_tool( + Tool( + name="pause_heartbeats", + source_code=source_code, + source_type=source_type, + description=description, + ), + actor=fake_user, + ) diff --git a/tests/test_local_client.py b/tests/test_local_client.py index 246abd71..2ffd26f7 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -420,7 +420,7 @@ def test_tools_from_langchain(client: LocalClient): exec(source_code, {}, local_scope) func = local_scope[tool.name] - expected_content = "Albert Einstein ( EYEN-styne; German:" + expected_content = "Albert Einstein" assert expected_content in func(query="Albert Einstein") diff --git a/tests/test_managers.py b/tests/test_managers.py index c8232963..557c00c9 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -6,12 +6,15 @@ from letta.functions.functions import derive_openai_json_schema, parse_source_co from letta.orm.organization import Organization from letta.orm.tool import Tool from letta.orm.user import User -from letta.schemas.tool import ToolCreate, ToolUpdate +from letta.schemas.organization import Organization as PydanticOrganization +from letta.schemas.tool import Tool as PydanticTool +from letta.schemas.tool import ToolUpdate from letta.services.organization_manager import OrganizationManager utils.DEBUG = True from letta.config import LettaConfig -from letta.schemas.user import UserCreate, UserUpdate +from letta.schemas.user import User as PydanticUser +from letta.schemas.user import UserUpdate from letta.server.server import SyncServer @@ -47,17 +50,17 @@ def tool_fixture(server: SyncServer): org = server.organization_manager.create_default_organization() user = server.user_manager.create_default_user() - other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id)) - tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type) - derived_json_schema = derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name) + other_user = server.user_manager.create_user(PydanticUser(name="other", organization_id=org.id)) + tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) derived_name = derived_json_schema["name"] - tool_create.json_schema = derived_json_schema - tool_create.name = derived_name + tool.json_schema = derived_json_schema + tool.name = derived_name - tool = server.tool_manager.create_tool(tool_create, actor=user) + tool = server.tool_manager.create_tool(tool, actor=user) # Yield the created tool, organization, and user for use in tests - yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool_create} + yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool} @pytest.fixture(scope="module") @@ -76,7 +79,7 @@ def server(): def test_list_organizations(server: SyncServer): # Create a new org and confirm that it is created correctly org_name = "test" - org = server.organization_manager.create_organization(name=org_name) + org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name)) orgs = server.organization_manager.list_organizations() assert len(orgs) == 1 @@ -96,15 +99,15 @@ def test_create_default_organization(server: SyncServer): def test_update_organization_name(server: SyncServer): org_name_a = "a" org_name_b = "b" - org = server.organization_manager.create_organization(name=org_name_a) + org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name_a)) assert org.name == org_name_a org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b) assert org.name == org_name_b def test_list_organizations_pagination(server: SyncServer): - server.organization_manager.create_organization(name="a") - server.organization_manager.create_organization(name="b") + server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a")) + server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b")) orgs_x = server.organization_manager.list_organizations(limit=1) assert len(orgs_x) == 1 @@ -125,7 +128,7 @@ def test_list_users(server: SyncServer): org = server.organization_manager.create_default_organization() user_name = "user" - user = server.user_manager.create_user(UserCreate(name=user_name, organization_id=org.id)) + user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id)) users = server.user_manager.list_users() assert len(users) == 1 @@ -146,13 +149,13 @@ def test_create_default_user(server: SyncServer): def test_update_user(server: SyncServer): # Create default organization default_org = server.organization_manager.create_default_organization() - test_org = server.organization_manager.create_organization(name="test_org") + test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org")) user_name_a = "a" user_name_b = "b" # Assert it's been created - user = server.user_manager.create_user(UserCreate(name=user_name_a, organization_id=default_org.id)) + user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id)) assert user.name == user_name_a # Adjust name @@ -340,7 +343,6 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture): server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=other_user) # Check that the created_by and last_updated_by fields are correct - # Fetch the updated tool to verify the changes updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) diff --git a/tests/test_tools.py b/tests/test_tools.py index 4195b220..f7e9464c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -14,6 +14,7 @@ from letta.constants import DEFAULT_PRESET from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory +from letta.services.tool_manager import ToolManager test_agent_name = f"test_client_{str(uuid.uuid4())}" # test_preset_name = "test_preset" @@ -93,15 +94,9 @@ def test_create_tool(client: Union[LocalClient, RESTClient]): return message tools = client.list_tools() - assert sorted([t.name for t in tools]) == sorted( - [ - "archival_memory_search", - "send_message", - "conversation_search", - "conversation_search_date", - "archival_memory_insert", - ] - ) + tool_names = [t.name for t in tools] + for tool in ToolManager.BASE_TOOL_NAMES: + assert tool in tool_names tool = client.create_tool(print_tool, name="my_name", tags=["extras"])