diff --git a/.github/workflows/migration-test.yml b/.github/workflows/migration-test.yml index 9f04b6d2..142c4068 100644 --- a/.github/workflows/migration-test.yml +++ b/.github/workflows/migration-test.yml @@ -8,13 +8,24 @@ jobs: test: runs-on: ubuntu-latest timeout-minutes: 15 + services: + postgres: + image: pgvector/pgvector:pg17 + ports: + - 5432:5432 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: postgres + POSTGRES_USER: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - name: Checkout uses: actions/checkout@v4 - - - name: Build and run container - run: bash db/run_postgres.sh - + - run: psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector' - name: "Setup Python, Poetry and Dependencies" uses: packetcoders/action-setup-cache-python-poetry@main with: @@ -23,12 +34,11 @@ jobs: install-args: "--all-extras" - name: Test alembic migration env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres LETTA_PG_HOST: localhost - LETTA_SERVER_PASS: test_server_token run: | poetry run alembic upgrade head poetry run alembic check diff --git a/.github/workflows/test_cli.yml b/.github/workflows/test_cli.yml index 6c3a658b..c7cd5240 100644 --- a/.github/workflows/test_cli.yml +++ b/.github/workflows/test_cli.yml @@ -19,14 +19,24 @@ jobs: image: qdrant/qdrant ports: - 6333:6333 + postgres: + image: pgvector/pgvector:pg17 + ports: + - 5432:5432 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: postgres + POSTGRES_USER: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - name: Checkout uses: actions/checkout@v4 - - name: Build and run container - run: bash db/run_postgres.sh - - name: "Setup Python, Poetry and Dependencies" uses: packetcoders/action-setup-cache-python-poetry@main with: @@ -34,12 +44,23 @@ jobs: poetry-version: "1.8.2" install-args: "-E dev -E postgres -E tests" + - name: Migrate database + env: + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres + LETTA_PG_HOST: localhost + run: | + psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector' + poetry run alembic upgrade head + - name: Test `letta run` up until first message env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres LETTA_PG_HOST: localhost LETTA_SERVER_PASS: test_server_token run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 13f83686..639d0b4d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,27 +34,46 @@ jobs: image: qdrant/qdrant ports: - 6333:6333 + postgres: + image: pgvector/pgvector:pg17 + ports: + - 5432:5432 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: postgres + POSTGRES_USER: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - name: Checkout uses: actions/checkout@v4 - - name: Build and run container - run: bash db/run_postgres.sh - - name: Setup Python, Poetry, and Dependencies uses: packetcoders/action-setup-cache-python-poetry@main with: python-version: "3.12" poetry-version: "1.8.2" install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" - + - name: Migrate database + env: + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres + LETTA_PG_HOST: localhost + run: | + psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector' + poetry run alembic upgrade head - name: Run core unit tests env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres LETTA_PG_HOST: localhost LETTA_SERVER_PASS: test_server_token run: | @@ -68,27 +87,46 @@ jobs: image: qdrant/qdrant ports: - 6333:6333 + postgres: + image: pgvector/pgvector:pg17 + ports: + - 5432:5432 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: postgres + POSTGRES_USER: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - name: Checkout uses: actions/checkout@v4 - - name: Build and run container - run: bash db/run_postgres.sh - - name: Setup Python, Poetry, and Dependencies uses: packetcoders/action-setup-cache-python-poetry@main with: python-version: "3.12" poetry-version: "1.8.2" install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" - + - name: Migrate database + env: + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres + LETTA_PG_HOST: localhost + run: | + psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector' + poetry run alembic upgrade head - name: Run misc unit tests env: - LETTA_PG_PORT: 8888 - LETTA_PG_USER: letta - LETTA_PG_PASSWORD: letta - LETTA_PG_DB: letta + LETTA_PG_PORT: 5432 + LETTA_PG_USER: postgres + LETTA_PG_PASSWORD: postgres + LETTA_PG_DB: postgres LETTA_PG_HOST: localhost LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} diff --git a/Dockerfile b/Dockerfile index 1c97e993..36cf9d15 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,6 +31,8 @@ ENV VIRTUAL_ENV=/app/.venv \ COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} COPY ./letta /letta +COPY ./alembic.ini /alembic.ini +COPY ./alembic /alembic EXPOSE 8283 @@ -46,6 +48,8 @@ ENV PYTHONPATH=/ WORKDIR / COPY ./tests /tests COPY ./letta /letta +COPY ./alembic.ini /alembic.ini +COPY ./alembic /alembic #COPY ./configs/server_config.yaml /root/.letta/config EXPOSE 8083 diff --git a/alembic/env.py b/alembic/env.py index 3c084a82..0f70d0ef 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,8 +4,9 @@ from logging.config import fileConfig from sqlalchemy import engine_from_config, pool from alembic import context +from letta.agent_store.db import attach_base from letta.config import LettaConfig -from letta.orm.base import Base +from letta.orm import Base from letta.settings import settings letta_config = LettaConfig.load() @@ -14,7 +15,6 @@ letta_config = LettaConfig.load() # access to the values within the .ini file in use. config = context.config -print(settings.letta_pg_uri_no_default) if settings.letta_pg_uri_no_default: config.set_main_option("sqlalchemy.url", settings.letta_pg_uri) else: @@ -29,6 +29,8 @@ if config.config_file_name is not None: # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata +attach_base() + target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, diff --git a/alembic/versions/0c315956709d_.py b/alembic/versions/0c315956709d_.py new file mode 100644 index 00000000..98c66125 --- /dev/null +++ b/alembic/versions/0c315956709d_.py @@ -0,0 +1,125 @@ +"""empty message + +Revision ID: 0c315956709d +Revises: 9a505cc7eca9 +Create Date: 2024-10-30 17:08:09.638235 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +import letta.metadata +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0c315956709d" +down_revision: Union[str, None] = "9a505cc7eca9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("users") + op.drop_table("tools") + op.drop_table("organizations") + op.create_table( + "organization", + sa.Column("name", sa.String(), nullable=False), + sa.Column("_id", sa.String(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("_id"), + ) + op.create_table( + "tool", + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("tags", sa.JSON(), nullable=False), + sa.Column("source_type", sa.String(), nullable=False), + sa.Column("source_code", sa.String(), nullable=True), + sa.Column("json_schema", sa.JSON(), nullable=False), + sa.Column("module", sa.String(), nullable=True), + sa.Column("_id", sa.String(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("_organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["_organization_id"], + ["organization._id"], + ), + sa.PrimaryKeyConstraint("_id"), + sa.UniqueConstraint("name", "_organization_id", name="uix_name_organization"), + ) + op.create_table( + "user", + sa.Column("name", sa.String(), nullable=False), + sa.Column("_id", sa.String(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("_organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["_organization_id"], + ["organization._id"], + ), + sa.PrimaryKeyConstraint("_id"), + ) + op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True)) + op.alter_column("block", "name", existing_type=sa.VARCHAR(), nullable=True) + op.alter_column("block", "label", existing_type=sa.VARCHAR(), nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("block", "label", existing_type=sa.VARCHAR(), nullable=True) + op.alter_column("block", "name", existing_type=sa.VARCHAR(), nullable=False) + op.drop_column("agents", "tool_rules") + op.drop_table("organization") + op.drop_table("tool") + op.drop_table("user") + op.create_table( + "organizations", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="organizations_pkey"), + ) + op.create_table( + "tools", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("description", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("source_type", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("source_code", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("json_schema", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.Column("module", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="tools_pkey"), + ) + op.create_table( + "users", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint("id", name="users_pkey"), + ) + # ### end Alembic commands ### diff --git a/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py b/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py index d1ee25e1..765e6d73 100644 --- a/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py +++ b/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py @@ -8,6 +8,13 @@ Create Date: 2024-10-11 14:19:19.875656 from typing import Sequence, Union +import pgvector +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +import letta.metadata +from alembic import op + # revision identifiers, used by Alembic. revision: str = "9a505cc7eca9" down_revision: Union[str, None] = None @@ -16,12 +23,173 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### + op.create_table( + "agent_source_mapping", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("source_id", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("agent_source_mapping_idx_user", "agent_source_mapping", ["user_id", "agent_id", "source_id"], unique=False) + op.create_table( + "agents", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("message_ids", sa.JSON(), nullable=True), + sa.Column("memory", sa.JSON(), nullable=True), + sa.Column("system", sa.String(), nullable=True), + sa.Column("agent_type", sa.String(), nullable=True), + sa.Column("llm_config", letta.metadata.LLMConfigColumn(), nullable=True), + sa.Column("embedding_config", letta.metadata.EmbeddingConfigColumn(), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("tools", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("agents_idx_user", "agents", ["user_id"], unique=False) + op.create_table( + "block", + sa.Column("id", sa.String(), nullable=False), + sa.Column("value", sa.String(), nullable=False), + sa.Column("limit", sa.BIGINT(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("template", sa.Boolean(), nullable=True), + sa.Column("label", sa.String(), nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("block_idx_user", "block", ["user_id"], unique=False) + op.create_table( + "files", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("source_id", sa.String(), nullable=False), + sa.Column("file_name", sa.String(), nullable=True), + sa.Column("file_path", sa.String(), nullable=True), + sa.Column("file_type", sa.String(), nullable=True), + sa.Column("file_size", sa.Integer(), nullable=True), + sa.Column("file_creation_date", sa.String(), nullable=True), + sa.Column("file_last_modified_date", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "jobs", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("status", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "messages", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.Column("text", sa.String(), nullable=True), + sa.Column("model", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("tool_calls", letta.metadata.ToolCallColumn(), nullable=True), + sa.Column("tool_call_id", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("message_idx_user", "messages", ["user_id", "agent_id"], unique=False) + op.create_table( + "organizations", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="organizations_pkey"), + ) + op.create_table( + "passages", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("text", sa.String(), nullable=True), + sa.Column("file_id", sa.String(), nullable=True), + sa.Column("agent_id", sa.String(), nullable=True), + sa.Column("source_id", sa.String(), nullable=True), + sa.Column("embedding", pgvector.sqlalchemy.Vector(dim=4096), nullable=True), + sa.Column("embedding_config", letta.metadata.EmbeddingConfigColumn(), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("passage_idx_user", "passages", ["user_id", "agent_id", "file_id"], unique=False) + op.create_table( + "sources", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("embedding_config", letta.metadata.EmbeddingConfigColumn(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("sources_idx_user", "sources", ["user_id"], unique=False) + op.create_table( + "tokens", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("key", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("tokens_idx_key", "tokens", ["key"], unique=False) + op.create_index("tokens_idx_user", "tokens", ["user_id"], unique=False) + + op.create_table( + "users", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("org_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column("policies_accepted", sa.BOOLEAN(), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint("id", name="users_pkey"), + ) + op.create_table( + "tools", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("description", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("source_type", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("source_code", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("json_schema", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.Column("module", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("tags", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="tools_pkey"), + ) def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### + op.drop_table("users") + op.drop_table("tools") + op.drop_index("tokens_idx_user", table_name="tokens") + op.drop_index("tokens_idx_key", table_name="tokens") + op.drop_table("tokens") + op.drop_index("sources_idx_user", table_name="sources") + op.drop_table("sources") + op.drop_index("passage_idx_user", table_name="passages") + op.drop_table("passages") + op.drop_table("organizations") + op.drop_index("message_idx_user", table_name="messages") + op.drop_table("messages") + op.drop_table("jobs") + op.drop_table("files") + op.drop_index("block_idx_user", table_name="block") + op.drop_table("block") + op.drop_index("agents_idx_user", table_name="agents") + op.drop_table("agents") + op.drop_index("agent_source_mapping_idx_user", table_name="agent_source_mapping") + op.drop_table("agent_source_mapping") diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index e69de29b..c95da85b 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -0,0 +1,4 @@ +from letta.orm.base import Base +from letta.orm.organization import Organization +from letta.orm.tool import Tool +from letta.orm.user import User diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 2fa93939..f0099dc1 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -63,7 +63,7 @@ class AgentState(BaseAgent, validate_assignment=True): tools: List[str] = Field(..., description="The tools used by the agent.") # tool rules - tool_rules: List[BaseToolRule] = Field(..., description="The list of tool rules.") + tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.") # system prompt system: str = Field(..., description="The system prompt used by the agent.") diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index dfcd567d..dd9a51a0 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -54,6 +54,12 @@ password = None # password = secrets.token_urlsafe(16) # #typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN) +import logging + +from fastapi import FastAPI + +log = logging.getLogger("uvicorn") + def create_application() -> "FastAPI": """the application start routine""" diff --git a/letta/server/server.py b/letta/server/server.py index 52660064..a05c87cb 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -35,8 +35,9 @@ from letta.interface import AgentInterface # abstract from letta.interface import CLIInterface # for printing to terminal from letta.log import get_logger from letta.memory import get_memory_functions -from letta.metadata import Base, MetadataStore +from letta.metadata import MetadataStore from letta.o1_agent import O1Agent +from letta.orm import Base from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.providers import ( @@ -169,6 +170,8 @@ from letta.settings import model_settings, settings, tool_settings config = LettaConfig.load() +attach_base() + if settings.letta_pg_uri_no_default: config.recall_storage_type = "postgres" config.recall_storage_uri = settings.letta_pg_uri_no_default @@ -181,13 +184,10 @@ else: # TODO: don't rely on config storage engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")) + Base.metadata.create_all(bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -attach_base() - -Base.metadata.create_all(bind=engine) - # Dependency def get_db(): diff --git a/letta/server/startup.sh b/letta/server/startup.sh index 975be72b..870bc0e4 100755 --- a/letta/server/startup.sh +++ b/letta/server/startup.sh @@ -1,5 +1,8 @@ #!/bin/sh echo "Starting MEMGPT server..." + +alembic upgrade head + if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then echo "Starting in development mode!" uvicorn letta.server.rest_api.app:app --reload --reload-dir /letta --host 0.0.0.0 --port 8283