chore: Move ID generation logic out of the ORM layer and into the Pydantic model layer (#1981)
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -131,4 +131,4 @@ jobs:
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not test_single_path_agent_tool_call_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
|
||||
poetry run pytest -s -vv -k "not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# Revision identifiers, used by Alembic.
|
||||
revision: str = "0c315956709d"
|
||||
down_revision: Union[str, None] = "9a505cc7eca9"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Rename tables first
|
||||
op.rename_table("organizations", "organization")
|
||||
op.rename_table("tools", "tool")
|
||||
op.rename_table("users", "user")
|
||||
|
||||
# Step 2: Rename `id` to `_id` in each table, keeping it as the primary key
|
||||
op.alter_column("organization", "id", new_column_name="_id", existing_type=sa.String, nullable=False)
|
||||
op.alter_column("tool", "id", new_column_name="_id", existing_type=sa.String, nullable=False)
|
||||
op.alter_column("user", "id", new_column_name="_id", existing_type=sa.String, nullable=False)
|
||||
|
||||
# Step 3: Add `_organization_id` to `tool` table
|
||||
# This is required for the unique constraint below
|
||||
op.add_column("tool", sa.Column("_organization_id", sa.String, nullable=True))
|
||||
|
||||
# Step 4: Modify nullable constraints on `tool` columns
|
||||
op.alter_column("tool", "tags", existing_type=sa.JSON, nullable=True)
|
||||
op.alter_column("tool", "source_type", existing_type=sa.String, nullable=True)
|
||||
op.alter_column("tool", "json_schema", existing_type=sa.JSON, nullable=True)
|
||||
|
||||
# Step 5: Add unique constraint on `name` and `_organization_id` in `tool` table
|
||||
op.create_unique_constraint("uq_tool_name_organization", "tool", ["name", "_organization_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverse unique constraint first
|
||||
op.drop_constraint("uq_tool_name_organization", "tool", type_="unique")
|
||||
|
||||
# Reverse nullable constraints on `tool` columns
|
||||
op.alter_column("tool", "tags", existing_type=sa.JSON, nullable=False)
|
||||
op.alter_column("tool", "source_type", existing_type=sa.String, nullable=False)
|
||||
op.alter_column("tool", "json_schema", existing_type=sa.JSON, nullable=False)
|
||||
|
||||
# Remove `_organization_id` column from `tool` table
|
||||
op.drop_column("tool", "_organization_id")
|
||||
|
||||
# Reverse the column renaming from `_id` back to `id`
|
||||
op.alter_column("organization", "_id", new_column_name="id", existing_type=sa.String, nullable=False)
|
||||
op.alter_column("tool", "_id", new_column_name="id", existing_type=sa.String, nullable=False)
|
||||
op.alter_column("user", "_id", new_column_name="id", existing_type=sa.String, nullable=False)
|
||||
|
||||
# Reverse table renaming last
|
||||
op.rename_table("organization", "organizations")
|
||||
op.rename_table("tool", "tools")
|
||||
op.rename_table("user", "users")
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Move organizations users tools to orm
|
||||
|
||||
Revision ID: d14ae606614c
|
||||
Revises: 9a505cc7eca9
|
||||
Create Date: 2024-11-05 15:03:12.350096
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import letta
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d14ae606614c"
|
||||
down_revision: Union[str, None] = "9a505cc7eca9"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def deprecated_tool():
|
||||
return "this is a deprecated tool, please remove it from your tools list"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete all tools
|
||||
op.execute("DELETE FROM tools")
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True))
|
||||
op.alter_column("block", "name", new_column_name="template_name", nullable=True)
|
||||
op.add_column("organizations", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("organizations", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("organizations", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("organizations", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("tools", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("tools", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("tools", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("tools", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("tools", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("tools", sa.Column("organization_id", sa.String(), nullable=False))
|
||||
op.alter_column("tools", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("tools", "source_type", existing_type=sa.VARCHAR(), nullable=False)
|
||||
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.create_unique_constraint("uix_name_organization", "tools", ["name", "organization_id"])
|
||||
op.create_foreign_key(None, "tools", "organizations", ["organization_id"], ["id"])
|
||||
op.drop_column("tools", "user_id")
|
||||
op.add_column("users", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("users", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("users", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("users", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("users", sa.Column("organization_id", sa.String(), nullable=True))
|
||||
# loop through all rows in the user table and set the _organization_id column from organization_id
|
||||
op.execute('UPDATE "users" SET organization_id = org_id')
|
||||
# set the _organization_id column to not nullable
|
||||
op.alter_column("users", "organization_id", existing_type=sa.String(), nullable=False)
|
||||
op.create_foreign_key(None, "users", "organizations", ["organization_id"], ["id"])
|
||||
op.drop_column("users", "org_id")
|
||||
op.drop_column("users", "policies_accepted")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("users", sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False))
|
||||
op.add_column("users", sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint(None, "users", type_="foreignkey")
|
||||
op.drop_column("users", "organization_id")
|
||||
op.drop_column("users", "_last_updated_by_id")
|
||||
op.drop_column("users", "_created_by_id")
|
||||
op.drop_column("users", "is_deleted")
|
||||
op.drop_column("users", "updated_at")
|
||||
op.add_column("tools", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint(None, "tools", type_="foreignkey")
|
||||
op.drop_constraint("uix_name_organization", "tools", type_="unique")
|
||||
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.alter_column("tools", "source_type", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.alter_column("tools", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.drop_column("tools", "organization_id")
|
||||
op.drop_column("tools", "_last_updated_by_id")
|
||||
op.drop_column("tools", "_created_by_id")
|
||||
op.drop_column("tools", "is_deleted")
|
||||
op.drop_column("tools", "updated_at")
|
||||
op.drop_column("tools", "created_at")
|
||||
op.drop_column("organizations", "_last_updated_by_id")
|
||||
op.drop_column("organizations", "_created_by_id")
|
||||
op.drop_column("organizations", "is_deleted")
|
||||
op.drop_column("organizations", "updated_at")
|
||||
op.add_column("block", sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.drop_column("block", "template_name")
|
||||
op.drop_column("agents", "tool_rules")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Sync migration with model changes
|
||||
|
||||
Revision ID: ee50a967e090
|
||||
Revises: 0c315956709d
|
||||
Create Date: 2024-11-01 19:18:39.399950
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import letta
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ee50a967e090"
|
||||
down_revision: Union[str, None] = "0c315956709d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True))
|
||||
op.add_column("organization", sa.Column("deleted", sa.Boolean(), nullable=False))
|
||||
op.add_column("organization", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("organization", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("organization", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("organization", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("tool", sa.Column("deleted", sa.Boolean(), nullable=False))
|
||||
op.add_column("tool", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("tool", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("tool", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("tool", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("tool", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.alter_column("tool", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("tool", "source_type", existing_type=sa.VARCHAR(), nullable=False)
|
||||
op.alter_column("tool", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("tool", "_organization_id", existing_type=sa.VARCHAR(), nullable=False)
|
||||
op.drop_constraint("uq_tool_name_organization", "tool", type_="unique")
|
||||
op.create_unique_constraint("uix_name_organization", "tool", ["name", "_organization_id"])
|
||||
op.create_foreign_key(None, "tool", "organization", ["_organization_id"], ["_id"])
|
||||
op.drop_column("tool", "user_id")
|
||||
op.add_column("user", sa.Column("deleted", sa.Boolean(), nullable=False))
|
||||
op.add_column("user", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("user", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("user", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("user", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("user", sa.Column("_organization_id", sa.String(), nullable=False))
|
||||
op.create_foreign_key(None, "user", "organization", ["_organization_id"], ["_id"])
|
||||
op.drop_column("user", "policies_accepted")
|
||||
op.drop_column("user", "org_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("user", sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.add_column("user", sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False))
|
||||
op.drop_constraint(None, "user", type_="foreignkey")
|
||||
op.drop_column("user", "_organization_id")
|
||||
op.drop_column("user", "_last_updated_by_id")
|
||||
op.drop_column("user", "_created_by_id")
|
||||
op.drop_column("user", "is_deleted")
|
||||
op.drop_column("user", "updated_at")
|
||||
op.drop_column("user", "deleted")
|
||||
op.add_column("tool", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint(None, "tool", type_="foreignkey")
|
||||
op.drop_constraint("uix_name_organization", "tool", type_="unique")
|
||||
op.create_unique_constraint("uq_tool_name_organization", "tool", ["name", "_organization_id"])
|
||||
op.alter_column("tool", "_organization_id", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.alter_column("tool", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.alter_column("tool", "source_type", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.alter_column("tool", "tags", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.drop_column("tool", "_last_updated_by_id")
|
||||
op.drop_column("tool", "_created_by_id")
|
||||
op.drop_column("tool", "is_deleted")
|
||||
op.drop_column("tool", "updated_at")
|
||||
op.drop_column("tool", "created_at")
|
||||
op.drop_column("tool", "deleted")
|
||||
op.drop_column("organization", "_last_updated_by_id")
|
||||
op.drop_column("organization", "_created_by_id")
|
||||
op.drop_column("organization", "is_deleted")
|
||||
op.drop_column("organization", "updated_at")
|
||||
op.drop_column("organization", "deleted")
|
||||
op.drop_column("agents", "tool_rules")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Rename block.name to block.template_name
|
||||
|
||||
Revision ID: eff245f340f9
|
||||
Revises: 0c315956709d
|
||||
Create Date: 2024-10-31 18:09:08.819371
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "eff245f340f9"
|
||||
down_revision: Union[str, None] = "ee50a967e090"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column("block", "name", new_column_name="template_name", existing_type=sa.String(), nullable=True)
|
||||
|
||||
# op.add_column('block', sa.Column('template_name', sa.String(), nullable=True))
|
||||
# op.drop_column('block', 'name')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column("block", "template_name", new_column_name="name", existing_type=sa.String(), nullable=True)
|
||||
# op.add_column('block', sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
# op.drop_column('block', 'template_name')
|
||||
# ### end Alembic commands ###
|
||||
@@ -3,6 +3,7 @@ import inspect
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from lib2to3.fixer_util import is_list
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -250,6 +251,9 @@ class Agent(BaseAgent):
|
||||
# if there are tool rules, print out a warning
|
||||
warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.")
|
||||
# add default rule for having send_message be a terminal tool
|
||||
|
||||
if not is_list(agent_state.tool_rules):
|
||||
agent_state.tool_rules = []
|
||||
agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message"))
|
||||
self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)
|
||||
|
||||
|
||||
@@ -358,26 +358,26 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
# construct URI from enviornment variables
|
||||
if settings.pg_uri:
|
||||
self.uri = settings.pg_uri
|
||||
|
||||
# use config URI
|
||||
# TODO: remove this eventually (config should NOT contain URI)
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
self.uri = self.config.archival_storage_uri
|
||||
self.db_model = PassageModel
|
||||
if self.config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.uri = self.config.recall_storage_uri
|
||||
self.db_model = MessageModel
|
||||
if self.config.recall_storage_uri is None:
|
||||
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.FILES:
|
||||
self.uri = self.config.metadata_storage_uri
|
||||
self.db_model = FileMetadataModel
|
||||
if self.config.metadata_storage_uri is None:
|
||||
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
|
||||
else:
|
||||
# use config URI
|
||||
# TODO: remove this eventually (config should NOT contain URI)
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
self.uri = self.config.archival_storage_uri
|
||||
self.db_model = PassageModel
|
||||
if self.config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.uri = self.config.recall_storage_uri
|
||||
self.db_model = MessageModel
|
||||
if self.config.recall_storage_uri is None:
|
||||
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.FILES:
|
||||
self.uri = self.config.metadata_storage_uri
|
||||
self.db_model = FileMetadataModel
|
||||
if self.config.metadata_storage_uri is None:
|
||||
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
for c in self.db_model.__table__.columns:
|
||||
if c.name == "embedding":
|
||||
|
||||
@@ -2267,18 +2267,18 @@ class LocalClient(AbstractClient):
|
||||
langchain_tool=langchain_tool,
|
||||
additional_imports_module_attr_map=additional_imports_module_attr_map,
|
||||
)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
|
||||
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
|
||||
|
||||
def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
|
||||
tool_create = ToolCreate.from_crewai(
|
||||
crewai_tool=crewai_tool,
|
||||
additional_imports_module_attr_map=additional_imports_module_attr_map,
|
||||
)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
|
||||
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
|
||||
|
||||
def load_composio_tool(self, action: "ActionType") -> Tool:
|
||||
tool_create = ToolCreate.from_composio(action=action)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
|
||||
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
|
||||
|
||||
# TODO: Use the above function `add_tool` here as there is duplicate logic
|
||||
def create_tool(
|
||||
@@ -2310,7 +2310,7 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# call server function
|
||||
return self.server.tool_manager.create_or_update_tool(
|
||||
ToolCreate(
|
||||
Tool(
|
||||
source_type=source_type,
|
||||
source_code=source_code,
|
||||
name=name,
|
||||
@@ -2738,7 +2738,7 @@ class LocalClient(AbstractClient):
|
||||
return self.server.list_embedding_models()
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
return self.server.organization_manager.create_organization(name=name)
|
||||
return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name))
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
|
||||
|
||||
@@ -67,7 +67,7 @@ class CommonSqlalchemyMetaMixins(Base):
|
||||
prop_value = getattr(self, full_prop, None)
|
||||
if not prop_value:
|
||||
return
|
||||
return f"user-{prop_value}"
|
||||
return prop_value
|
||||
|
||||
def _user_id_setter(self, prop: str, value: str) -> None:
|
||||
"""returns the user id for the specified property"""
|
||||
@@ -75,6 +75,9 @@ class CommonSqlalchemyMetaMixins(Base):
|
||||
if not value:
|
||||
setattr(self, full_prop, None)
|
||||
return
|
||||
# Safety check
|
||||
prefix, id_ = value.split("-", 1)
|
||||
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
|
||||
setattr(self, full_prop, id_)
|
||||
|
||||
# Set the full value
|
||||
setattr(self, full_prop, value)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
from letta.orm.errors import MalformedIdError
|
||||
|
||||
|
||||
def is_valid_uuid4(uuid_string: str) -> bool:
|
||||
@@ -17,53 +15,12 @@ def is_valid_uuid4(uuid_string: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _relation_getter(instance: "Base", prop: str) -> Optional[str]:
|
||||
"""Get relation and return id with prefix as a string."""
|
||||
prefix = prop.replace("_", "")
|
||||
formatted_prop = f"_{prop}_id"
|
||||
try:
|
||||
id_ = getattr(instance, formatted_prop) # Get the string id directly
|
||||
return f"{prefix}-{id_}"
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
def _relation_setter(instance: "Base", prop: str, value: str) -> None:
|
||||
"""Set relation using the id with prefix, ensuring the id is a valid UUIDv4."""
|
||||
formatted_prop = f"_{prop}_id"
|
||||
prefix = prop.replace("_", "")
|
||||
if not value:
|
||||
setattr(instance, formatted_prop, None)
|
||||
return
|
||||
try:
|
||||
found_prefix, id_ = value.split("-", 1)
|
||||
except ValueError as e:
|
||||
raise MalformedIdError(f"{value} is not a valid ID.") from e
|
||||
|
||||
# Ensure prefix matches
|
||||
assert found_prefix == prefix, f"{found_prefix} is not a valid id prefix, expecting {prefix}"
|
||||
|
||||
# Validate that the id is a valid UUID4 string
|
||||
if not is_valid_uuid4(id_):
|
||||
raise MalformedIdError(f"Hash segment of {value} is not a valid UUID4")
|
||||
|
||||
setattr(instance, formatted_prop, id_) # Store id as a string
|
||||
|
||||
|
||||
class OrganizationMixin(Base):
|
||||
"""Mixin for models that belong to an organization."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
_organization_id: Mapped[str] = mapped_column(String, ForeignKey("organization._id"))
|
||||
|
||||
@property
|
||||
def organization_id(self) -> str:
|
||||
return _relation_getter(self, "organization")
|
||||
|
||||
@organization_id.setter
|
||||
def organization_id(self, value: str) -> None:
|
||||
_relation_setter(self, "organization", value)
|
||||
organization_id: Mapped[str] = mapped_column(String, ForeignKey("organizations.id"))
|
||||
|
||||
|
||||
class UserMixin(Base):
|
||||
@@ -71,12 +28,4 @@ class UserMixin(Base):
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
_user_id: Mapped[str] = mapped_column(String, ForeignKey("user._id"))
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
return _relation_getter(self, "user")
|
||||
|
||||
@user_id.setter
|
||||
def user_id(self, value: str) -> None:
|
||||
_relation_setter(self, "user", value)
|
||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
@@ -14,9 +15,10 @@ if TYPE_CHECKING:
|
||||
class Organization(SqlalchemyBase):
|
||||
"""The highest level of the object tree. All Entities belong to one and only one Organization."""
|
||||
|
||||
__tablename__ = "organization"
|
||||
__tablename__ = "organizations"
|
||||
__pydantic_model__ = PydanticOrganization
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
||||
|
||||
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from humps import depascalize
|
||||
from sqlalchemy import Boolean, String, select
|
||||
from sqlalchemy import String, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mixins import is_valid_uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
@@ -24,27 +21,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
__order_by_default__ = "created_at"
|
||||
|
||||
_id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid4()}")
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.")
|
||||
|
||||
@classmethod
|
||||
def __prefix__(cls) -> str:
|
||||
return depascalize(cls.__name__)
|
||||
|
||||
@property
|
||||
def id(self) -> Optional[str]:
|
||||
if self._id:
|
||||
return f"{self.__prefix__()}-{self._id}"
|
||||
|
||||
@id.setter
|
||||
def id(self, value: str) -> None:
|
||||
if not value:
|
||||
return
|
||||
prefix, id_ = value.split("-", 1)
|
||||
assert prefix == self.__prefix__(), f"{prefix} is not a valid id prefix for {self.__class__.__name__}"
|
||||
assert is_valid_uuid4(id_), f"{id_} is not a valid uuid4"
|
||||
self._id = id_
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
@@ -57,11 +34,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
# Add a cursor condition if provided
|
||||
if cursor:
|
||||
cursor_uuid = cls.get_uid_from_identifier(cursor) # Assuming the cursor is an _id value
|
||||
query = query.where(cls._id > cursor_uuid)
|
||||
query = query.where(cls.id > cursor)
|
||||
|
||||
# Add a limit to the query if provided
|
||||
query = query.order_by(cls._id).limit(limit)
|
||||
query = query.order_by(cls.id).limit(limit)
|
||||
|
||||
# Handle soft deletes if the class has the 'is_deleted' attribute
|
||||
if hasattr(cls, "is_deleted"):
|
||||
@@ -70,20 +46,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
# Execute the query and return the results as a list of model instances
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@classmethod
|
||||
def get_uid_from_identifier(cls, identifier: str, indifferent: Optional[bool] = False) -> str:
|
||||
"""converts the id into a uuid object
|
||||
Args:
|
||||
identifier: the string identifier, such as `organization-xxxx-xx...`
|
||||
indifferent: if True, will not enforce the prefix check
|
||||
"""
|
||||
try:
|
||||
uuid_string = identifier.split("-", 1)[1] if indifferent else identifier.replace(f"{cls.__prefix__()}-", "")
|
||||
assert is_valid_uuid4(uuid_string)
|
||||
return uuid_string
|
||||
except ValueError as e:
|
||||
raise ValueError(f"{identifier} is not a valid identifier for class {cls.__name__}") from e
|
||||
|
||||
@classmethod
|
||||
def read(
|
||||
cls,
|
||||
@@ -112,8 +74,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
# If an identifier is provided, add it to the query conditions
|
||||
if identifier is not None:
|
||||
identifier = cls.get_uid_from_identifier(identifier)
|
||||
query = query.where(cls._id == identifier)
|
||||
query = query.where(cls.id == identifier)
|
||||
query_conditions.append(f"id='{identifier}'")
|
||||
|
||||
if kwargs:
|
||||
@@ -183,7 +144,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
org_id = getattr(actor, "organization_id", None)
|
||||
if not org_id:
|
||||
raise ValueError(f"object {actor} has no organization accessor")
|
||||
return query.where(cls._organization_id == cls.get_uid_from_identifier(org_id, indifferent=True), cls.is_deleted == False)
|
||||
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
|
||||
|
||||
@property
|
||||
def __pydantic_model__(self) -> Type["BaseModel"]:
|
||||
|
||||
@@ -21,13 +21,14 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
more granular permissions.
|
||||
"""
|
||||
|
||||
__tablename__ = "tool"
|
||||
__tablename__ = "tools"
|
||||
__pydantic_model__ = PydanticTool
|
||||
|
||||
# Add unique constraint on (name, _organization_id)
|
||||
# An organization should not have multiple tools with the same name
|
||||
__table_args__ = (UniqueConstraint("name", "_organization_id", name="uix_name_organization"),)
|
||||
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
|
||||
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
|
||||
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
@@ -13,9 +14,10 @@ if TYPE_CHECKING:
|
||||
class User(SqlalchemyBase, OrganizationMixin):
|
||||
"""User ORM class"""
|
||||
|
||||
__tablename__ = "user"
|
||||
__tablename__ = "users"
|
||||
__pydantic_model__ = PydanticUser
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.")
|
||||
|
||||
# relationships
|
||||
|
||||
@@ -106,6 +106,10 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
# New format
|
||||
obj.prompt_template = state["prompt_template"]
|
||||
for key, value in state["memory"].items():
|
||||
# TODO: This is migration code, please take a look at a later time to get rid of this
|
||||
if "name" in value:
|
||||
value["template_name"] = value["name"]
|
||||
value.pop("name")
|
||||
obj.memory[key] = Block(**value)
|
||||
else:
|
||||
# Old format (pre-template)
|
||||
|
||||
@@ -4,16 +4,16 @@ from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.utils import get_utc_time
|
||||
from letta.utils import create_random_username, get_utc_time
|
||||
|
||||
|
||||
class OrganizationBase(LettaBase):
|
||||
__id_prefix__ = "organization"
|
||||
__id_prefix__ = "org"
|
||||
|
||||
|
||||
class Organization(OrganizationBase):
|
||||
id: str = Field(..., description="The id of the organization.")
|
||||
name: str = Field(..., description="The name of the organization.")
|
||||
id: str = OrganizationBase.generate_id_field()
|
||||
name: str = Field(create_random_username(), description="The name of the organization.")
|
||||
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
|
||||
|
||||
|
||||
|
||||
@@ -33,21 +33,21 @@ class Tool(BaseTool):
|
||||
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the tool.")
|
||||
id: str = BaseTool.generate_id_field()
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
module: Optional[str] = Field(None, description="The module of the function.")
|
||||
organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.")
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
tags: List[str] = Field(..., description="Metadata tags.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.")
|
||||
name: Optional[str] = Field(None, description="The name of the function.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
|
||||
# code
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")
|
||||
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
|
||||
|
||||
# metadata fields
|
||||
created_by_id: str = Field(..., description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: str = Field(..., description="The id of the user that made this Tool.")
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
|
||||
@@ -21,7 +21,7 @@ class User(UserBase):
|
||||
created_at (datetime): The creation date of the user.
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the user.")
|
||||
id: str = UserBase.generate_id_field()
|
||||
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
|
||||
name: str = Field(..., description="The name of the user.")
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||
|
||||
@@ -38,7 +38,8 @@ def create_org(
|
||||
"""
|
||||
Create a new org in the database
|
||||
"""
|
||||
org = server.organization_manager.create_organization(name=request.name)
|
||||
org = Organization(**request.model_dump())
|
||||
org = server.organization_manager.create_organization(pydantic_org=org)
|
||||
return org
|
||||
|
||||
|
||||
|
||||
@@ -89,7 +89,8 @@ def create_tool(
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Send request to create the tool
|
||||
return server.tool_manager.create_or_update_tool(tool_create=request, actor=actor)
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_or_update_tool(pydantic_tool=tool, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool")
|
||||
|
||||
@@ -51,8 +51,8 @@ def create_user(
|
||||
"""
|
||||
Create a new user in the database
|
||||
"""
|
||||
|
||||
user = server.user_manager.create_user(request)
|
||||
user = User(**request.model_dump())
|
||||
user = server.user_manager.create_user(user)
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -824,7 +824,7 @@ class SyncServer(Server):
|
||||
source_type = "python"
|
||||
tags = ["memory", "memgpt-base"]
|
||||
tool = self.tool_manager.create_or_update_tool(
|
||||
ToolCreate(
|
||||
Tool(
|
||||
source_code=source_code,
|
||||
source_type=source_type,
|
||||
tags=tags,
|
||||
@@ -1766,7 +1766,7 @@ class SyncServer(Server):
|
||||
tool_creates += ToolCreate.load_default_composio_tools()
|
||||
for tool_create in tool_creates:
|
||||
try:
|
||||
self.tool_manager.create_or_update_tool(tool_create, actor=actor)
|
||||
self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor)
|
||||
except Exception as e:
|
||||
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
|
||||
warnings.warn(traceback.format_exc())
|
||||
|
||||
@@ -3,13 +3,13 @@ from typing import List, Optional
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.utils import create_random_username, enforce_types
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class OrganizationManager:
|
||||
"""Manager class to handle business logic related to Organizations."""
|
||||
|
||||
DEFAULT_ORG_ID = "organization-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_ORG_NAME = "default_org"
|
||||
|
||||
def __init__(self):
|
||||
@@ -37,10 +37,10 @@ class OrganizationManager:
|
||||
raise ValueError(f"Organization with id {org_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def create_organization(self, name: Optional[str] = None) -> PydanticOrganization:
|
||||
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
|
||||
with self.session_maker() as session:
|
||||
org = OrganizationModel(name=name if name else create_random_username())
|
||||
org = OrganizationModel(**pydantic_org.model_dump())
|
||||
org.create(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
|
||||
@@ -7,10 +7,9 @@ from letta.functions.functions import derive_openai_json_schema, load_function_s
|
||||
|
||||
# TODO: Remove this once we translate all of these to the ORM
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
@@ -33,20 +32,20 @@ class ToolManager:
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Derive json_schema
|
||||
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(
|
||||
source_code=tool_create.source_code, name=tool_create.name
|
||||
derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(
|
||||
source_code=pydantic_tool.source_code, name=pydantic_tool.name
|
||||
)
|
||||
derived_name = tool_create.name or derived_json_schema["name"]
|
||||
derived_name = pydantic_tool.name or derived_json_schema["name"]
|
||||
|
||||
try:
|
||||
# NOTE: We use the organization id here
|
||||
# This is important, because even if it's a different user, adding the same tool to the org should not happen
|
||||
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = tool_create.model_dump(exclude={"module"}, exclude_unset=True)
|
||||
update_data = pydantic_tool.model_dump(exclude={"module"}, exclude_unset=True, exclude_none=True)
|
||||
# Remove redundant update fields
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
|
||||
|
||||
@@ -55,22 +54,24 @@ class ToolManager:
|
||||
self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={tool_create.name}, but found existing tool with nothing to update."
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
)
|
||||
except NoResultFound:
|
||||
tool_create.json_schema = derived_json_schema
|
||||
tool_create.name = derived_name
|
||||
tool = self.create_tool(tool_create, actor=actor)
|
||||
pydantic_tool.json_schema = derived_json_schema
|
||||
pydantic_tool.name = derived_name
|
||||
tool = self.create_tool(pydantic_tool, actor=actor)
|
||||
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
def create_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Create the tool
|
||||
with self.session_maker() as session:
|
||||
create_data = tool_create.model_dump()
|
||||
tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_tool.organization_id = actor.organization_id
|
||||
tool_data = pydantic_tool.model_dump()
|
||||
tool = ToolModel(**tool_data)
|
||||
tool.create(session, actor=actor)
|
||||
|
||||
return tool.to_pydantic()
|
||||
@@ -99,7 +100,7 @@ class ToolManager:
|
||||
db_session=session,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
_organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id),
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
return [tool.to_pydantic() for tool in tools]
|
||||
|
||||
@@ -176,7 +177,7 @@ class ToolManager:
|
||||
# create to tool
|
||||
tools.append(
|
||||
self.create_or_update_tool(
|
||||
ToolCreate(
|
||||
PydanticTool(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
|
||||
@@ -4,7 +4,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.user import User as UserModel
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserCreate, UserUpdate
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@@ -42,10 +42,10 @@ class UserManager:
|
||||
return user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_user(self, user_create: UserCreate) -> PydanticUser:
|
||||
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
|
||||
"""Create a new user if it doesn't already exist."""
|
||||
with self.session_maker() as session:
|
||||
new_user = UserModel(**user_create.model_dump())
|
||||
new_user = UserModel(**pydantic_user.model_dump())
|
||||
new_user.create(session)
|
||||
return new_user.to_pydantic()
|
||||
|
||||
|
||||
52
scripts/migrate_tools.py
Normal file
52
scripts/migrate_tools.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
|
||||
def deprecated_tool():
|
||||
return "this is a deprecated tool, please remove it from your tools list"
|
||||
|
||||
|
||||
orgs = OrganizationManager().list_organizations(cursor=None, limit=5000)
|
||||
for org in orgs:
|
||||
if org.name != "default":
|
||||
fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id)
|
||||
|
||||
ToolManager().add_base_tools(actor=fake_user)
|
||||
|
||||
source_code = parse_source_code(deprecated_tool)
|
||||
source_type = "python"
|
||||
description = "deprecated"
|
||||
tags = ["deprecated"]
|
||||
|
||||
ToolManager().create_or_update_tool(
|
||||
Tool(
|
||||
name="core_memory_append",
|
||||
source_code=source_code,
|
||||
source_type=source_type,
|
||||
description=description,
|
||||
),
|
||||
actor=fake_user,
|
||||
)
|
||||
|
||||
ToolManager().create_or_update_tool(
|
||||
Tool(
|
||||
name="core_memory_replace",
|
||||
source_code=source_code,
|
||||
source_type=source_type,
|
||||
description=description,
|
||||
),
|
||||
actor=fake_user,
|
||||
)
|
||||
|
||||
ToolManager().create_or_update_tool(
|
||||
Tool(
|
||||
name="pause_heartbeats",
|
||||
source_code=source_code,
|
||||
source_type=source_type,
|
||||
description=description,
|
||||
),
|
||||
actor=fake_user,
|
||||
)
|
||||
@@ -420,7 +420,7 @@ def test_tools_from_langchain(client: LocalClient):
|
||||
exec(source_code, {}, local_scope)
|
||||
func = local_scope[tool.name]
|
||||
|
||||
expected_content = "Albert Einstein ( EYEN-styne; German:"
|
||||
expected_content = "Albert Einstein"
|
||||
assert expected_content in func(query="Albert Einstein")
|
||||
|
||||
|
||||
|
||||
@@ -6,12 +6,15 @@ from letta.functions.functions import derive_openai_json_schema, parse_source_co
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
utils.DEBUG = True
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.user import UserCreate, UserUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -47,17 +50,17 @@ def tool_fixture(server: SyncServer):
|
||||
|
||||
org = server.organization_manager.create_default_organization()
|
||||
user = server.user_manager.create_default_user()
|
||||
other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id))
|
||||
tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
|
||||
other_user = server.user_manager.create_user(PydanticUser(name="other", organization_id=org.id))
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool_create.json_schema = derived_json_schema
|
||||
tool_create.name = derived_name
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = server.tool_manager.create_tool(tool_create, actor=user)
|
||||
tool = server.tool_manager.create_tool(tool, actor=user)
|
||||
|
||||
# Yield the created tool, organization, and user for use in tests
|
||||
yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool_create}
|
||||
yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -76,7 +79,7 @@ def server():
|
||||
def test_list_organizations(server: SyncServer):
|
||||
# Create a new org and confirm that it is created correctly
|
||||
org_name = "test"
|
||||
org = server.organization_manager.create_organization(name=org_name)
|
||||
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
|
||||
|
||||
orgs = server.organization_manager.list_organizations()
|
||||
assert len(orgs) == 1
|
||||
@@ -96,15 +99,15 @@ def test_create_default_organization(server: SyncServer):
|
||||
def test_update_organization_name(server: SyncServer):
|
||||
org_name_a = "a"
|
||||
org_name_b = "b"
|
||||
org = server.organization_manager.create_organization(name=org_name_a)
|
||||
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name_a))
|
||||
assert org.name == org_name_a
|
||||
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
|
||||
assert org.name == org_name_b
|
||||
|
||||
|
||||
def test_list_organizations_pagination(server: SyncServer):
|
||||
server.organization_manager.create_organization(name="a")
|
||||
server.organization_manager.create_organization(name="b")
|
||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
|
||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
|
||||
|
||||
orgs_x = server.organization_manager.list_organizations(limit=1)
|
||||
assert len(orgs_x) == 1
|
||||
@@ -125,7 +128,7 @@ def test_list_users(server: SyncServer):
|
||||
org = server.organization_manager.create_default_organization()
|
||||
|
||||
user_name = "user"
|
||||
user = server.user_manager.create_user(UserCreate(name=user_name, organization_id=org.id))
|
||||
user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id))
|
||||
|
||||
users = server.user_manager.list_users()
|
||||
assert len(users) == 1
|
||||
@@ -146,13 +149,13 @@ def test_create_default_user(server: SyncServer):
|
||||
def test_update_user(server: SyncServer):
|
||||
# Create default organization
|
||||
default_org = server.organization_manager.create_default_organization()
|
||||
test_org = server.organization_manager.create_organization(name="test_org")
|
||||
test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org"))
|
||||
|
||||
user_name_a = "a"
|
||||
user_name_b = "b"
|
||||
|
||||
# Assert it's been created
|
||||
user = server.user_manager.create_user(UserCreate(name=user_name_a, organization_id=default_org.id))
|
||||
user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id))
|
||||
assert user.name == user_name_a
|
||||
|
||||
# Adjust name
|
||||
@@ -340,7 +343,6 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture):
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=other_user)
|
||||
|
||||
# Check that the created_by and last_updated_by fields are correct
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
# test_preset_name = "test_preset"
|
||||
@@ -93,15 +94,9 @@ def test_create_tool(client: Union[LocalClient, RESTClient]):
|
||||
return message
|
||||
|
||||
tools = client.list_tools()
|
||||
assert sorted([t.name for t in tools]) == sorted(
|
||||
[
|
||||
"archival_memory_search",
|
||||
"send_message",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
]
|
||||
)
|
||||
tool_names = [t.name for t in tools]
|
||||
for tool in ToolManager.BASE_TOOL_NAMES:
|
||||
assert tool in tool_names
|
||||
|
||||
tool = client.create_tool(print_tool, name="my_name", tags=["extras"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user