diff --git a/alembic/env.py b/alembic/env.py index 767b7bbd..e7dfe71c 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,12 +1,11 @@ import os from logging.config import fileConfig -from sqlalchemy import engine_from_config, pool - from alembic import context from letta.config import LettaConfig from letta.orm import Base from letta.settings import settings +from sqlalchemy import engine_from_config, pool letta_config = LettaConfig.load() diff --git a/alembic/versions/08b2f8225812_adding_toolsagents_orm.py b/alembic/versions/08b2f8225812_adding_toolsagents_orm.py index 902225ab..d0e2cac8 100644 --- a/alembic/versions/08b2f8225812_adding_toolsagents_orm.py +++ b/alembic/versions/08b2f8225812_adding_toolsagents_orm.py @@ -5,40 +5,44 @@ Revises: 3c683a662c82 Create Date: 2024-12-05 16:46:51.258831 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = '08b2f8225812' -down_revision: Union[str, None] = '3c683a662c82' +revision: str = "08b2f8225812" +down_revision: Union[str, None] = "3c683a662c82" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('tools_agents', - sa.Column('agent_id', sa.String(), nullable=False), - sa.Column('tool_id', sa.String(), nullable=False), - sa.Column('tool_name', sa.String(), nullable=False), - sa.Column('id', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), - sa.Column('_created_by_id', sa.String(), nullable=True), - sa.Column('_last_updated_by_id', sa.String(), nullable=True), - sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ), - sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'), - sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'), - sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent') + op.create_table( + "tools_agents", + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("tool_id", sa.String(), nullable=False), + sa.Column("tool_name", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["agent_id"], + ["agents.id"], + ), + sa.ForeignKeyConstraint(["tool_id"], ["tools.id"], name="fk_tool_id"), + sa.PrimaryKeyConstraint("agent_id", "tool_id", "tool_name", "id"), + sa.UniqueConstraint("agent_id", "tool_name", name="unique_tool_per_agent"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('tools_agents') + op.drop_table("tools_agents") # ### end Alembic commands ### diff --git a/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py b/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py index ffcb0b67..6890e0ed 100644 --- a/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py +++ b/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py @@ -9,7 +9,6 @@ Create Date: 2024-11-22 15:42:47.209229 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/3c683a662c82_migrate_jobs_to_the_orm.py b/alembic/versions/3c683a662c82_migrate_jobs_to_the_orm.py index 4f9b746d..62b97e9d 100644 --- a/alembic/versions/3c683a662c82_migrate_jobs_to_the_orm.py +++ b/alembic/versions/3c683a662c82_migrate_jobs_to_the_orm.py @@ -9,9 +9,8 @@ Create Date: 2024-12-04 15:59:41.708396 from typing import Sequence, Union import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "3c683a662c82" diff --git a/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py b/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py index 75a90445..b692e855 100644 --- a/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py +++ b/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py @@ -9,7 +9,6 @@ Create Date: 2024-12-13 17:19:55.796210 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/54dec07619c4_divide_passage_table_into_.py b/alembic/versions/54dec07619c4_divide_passage_table_into_.py index afe9d418..ead70ec6 100644 --- a/alembic/versions/54dec07619c4_divide_passage_table_into_.py +++ b/alembic/versions/54dec07619c4_divide_passage_table_into_.py @@ -5,18 +5,18 @@ Revises: 4e88e702f85e Create Date: 2024-12-14 17:23:08.772554 """ + from typing import Sequence, Union -from alembic import op -from pgvector.sqlalchemy import Vector import sqlalchemy as sa +from alembic import op +from letta.orm.custom_columns import EmbeddingConfigColumn +from pgvector.sqlalchemy import Vector from sqlalchemy.dialects import postgresql -from letta.orm.custom_columns import EmbeddingConfigColumn - # revision identifiers, used by Alembic. -revision: str = '54dec07619c4' -down_revision: Union[str, None] = '4e88e702f85e' +revision: str = "54dec07619c4" +down_revision: Union[str, None] = "4e88e702f85e" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -24,82 +24,88 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( - 'agent_passages', - sa.Column('id', sa.String(), nullable=False), - sa.Column('text', sa.String(), nullable=False), - sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False), - sa.Column('metadata_', sa.JSON(), nullable=False), - sa.Column('embedding', Vector(dim=4096), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), - sa.Column('_created_by_id', sa.String(), nullable=True), - sa.Column('_last_updated_by_id', sa.String(), nullable=True), - sa.Column('organization_id', sa.String(), nullable=False), - sa.Column('agent_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), - sa.PrimaryKeyConstraint('id') + "agent_passages", + sa.Column("id", sa.String(), nullable=False), + sa.Column("text", sa.String(), nullable=False), + sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=False), + sa.Column("embedding", Vector(dim=4096), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False) + op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"], unique=False) op.create_table( - 'source_passages', - sa.Column('id', sa.String(), nullable=False), - sa.Column('text', sa.String(), nullable=False), - sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False), - sa.Column('metadata_', sa.JSON(), nullable=False), - sa.Column('embedding', Vector(dim=4096), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), - sa.Column('_created_by_id', sa.String(), nullable=True), - sa.Column('_last_updated_by_id', sa.String(), nullable=True), - sa.Column('organization_id', sa.String(), nullable=False), - sa.Column('file_id', sa.String(), nullable=True), - sa.Column('source_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), - sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "source_passages", + sa.Column("id", sa.String(), nullable=False), + sa.Column("text", sa.String(), nullable=False), + sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=False), + sa.Column("embedding", Vector(dim=4096), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("file_id", sa.String(), nullable=True), + sa.Column("source_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False) - op.drop_table('passages') - op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey') - op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE') - op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey') - op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE') + op.create_index("source_passages_org_idx", "source_passages", ["organization_id"], unique=False) + op.drop_table("passages") + op.drop_constraint("files_source_id_fkey", "files", type_="foreignkey") + op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"], ondelete="CASCADE") + op.drop_constraint("messages_agent_id_fkey", "messages", type_="foreignkey") + op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"], ondelete="CASCADE") # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'messages', type_='foreignkey') - op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id']) - op.drop_constraint(None, 'files', type_='foreignkey') - op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id']) + op.drop_constraint(None, "messages", type_="foreignkey") + op.create_foreign_key("messages_agent_id_fkey", "messages", "agents", ["agent_id"], ["id"]) + op.drop_constraint(None, "files", type_="foreignkey") + op.create_foreign_key("files_source_id_fkey", "files", "sources", ["source_id"], ["id"]) op.create_table( - 'passages', - sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True), - sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), - sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), - sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False), - sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), - sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), - sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'), - sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'), - sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'), - sa.PrimaryKeyConstraint('id', name='passages_pkey') + "passages", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("text", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("file_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("embedding", Vector(dim=4096), autoincrement=False, nullable=True), + sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), + sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False), + sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True), + sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False), + sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], name="passages_agent_id_fkey"), + sa.ForeignKeyConstraint(["file_id"], ["files.id"], name="passages_file_id_fkey", ondelete="CASCADE"), + sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], name="passages_organization_id_fkey"), + sa.PrimaryKeyConstraint("id", name="passages_pkey"), ) - op.drop_index('source_passages_org_idx', table_name='source_passages') - op.drop_table('source_passages') - op.drop_index('agent_passages_org_idx', table_name='agent_passages') - op.drop_table('agent_passages') + op.drop_index("source_passages_org_idx", table_name="source_passages") + op.drop_table("source_passages") + op.drop_index("agent_passages_org_idx", table_name="agent_passages") + op.drop_table("agent_passages") # ### end Alembic commands ### diff --git a/alembic/versions/5987401b40ae_refactor_agent_memory.py b/alembic/versions/5987401b40ae_refactor_agent_memory.py index 889e9425..84e4ebe2 100644 --- a/alembic/versions/5987401b40ae_refactor_agent_memory.py +++ b/alembic/versions/5987401b40ae_refactor_agent_memory.py @@ -9,9 +9,8 @@ Create Date: 2024-11-25 14:35:00.896507 from typing import Sequence, Union import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "5987401b40ae" diff --git a/alembic/versions/95badb46fdf9_migrate_messages_to_the_orm.py b/alembic/versions/95badb46fdf9_migrate_messages_to_the_orm.py index 73254e39..f200e65e 100644 --- a/alembic/versions/95badb46fdf9_migrate_messages_to_the_orm.py +++ b/alembic/versions/95badb46fdf9_migrate_messages_to_the_orm.py @@ -9,9 +9,8 @@ Create Date: 2024-12-05 14:02:04.163150 from typing import Sequence, Union import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "95badb46fdf9" diff --git a/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py b/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py index 21f6a396..6ef44d47 100644 --- a/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py +++ b/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py @@ -8,12 +8,11 @@ Create Date: 2024-10-11 14:19:19.875656 from typing import Sequence, Union +import letta.orm import pgvector import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -import letta.orm from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "9a505cc7eca9" diff --git a/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py b/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py index f8da3856..157e87ef 100644 --- a/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py +++ b/alembic/versions/a91994b9752f_add_column_to_tools_table_to_contain_.py @@ -9,7 +9,6 @@ Create Date: 2024-12-09 18:27:25.650079 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op from letta.constants import FUNCTION_RETURN_CHAR_LIMIT diff --git a/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py b/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py index 2aec8a09..638bc5b5 100644 --- a/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py +++ b/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py @@ -9,7 +9,6 @@ Create Date: 2024-11-06 10:48:08.424108 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py b/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py index a16fdae4..fb0eafab 100644 --- a/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py +++ b/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py @@ -5,25 +5,26 @@ Revises: a91994b9752f Create Date: 2024-12-10 15:05:32.335519 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision: str = 'c5d964280dff' -down_revision: Union[str, None] = 'a91994b9752f' +revision: str = "c5d964280dff" +down_revision: Union[str, None] = "a91994b9752f" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True)) - op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False)) - op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True)) - op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True)) + op.add_column("passages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("passages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("passages", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("passages", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) # Data migration step: op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True)) @@ -41,48 +42,32 @@ def upgrade() -> None: # Set `organization_id` as non-nullable after population op.alter_column("passages", "organization_id", nullable=False) - op.alter_column('passages', 'text', - existing_type=sa.VARCHAR(), - nullable=False) - op.alter_column('passages', 'embedding_config', - existing_type=postgresql.JSON(astext_type=sa.Text()), - nullable=False) - op.alter_column('passages', 'metadata_', - existing_type=postgresql.JSON(astext_type=sa.Text()), - nullable=False) - op.alter_column('passages', 'created_at', - existing_type=postgresql.TIMESTAMP(timezone=True), - nullable=False) - op.drop_index('passage_idx_user', table_name='passages') - op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id']) - op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id']) - op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE') - op.drop_column('passages', 'user_id') + op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=False) + op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) + op.drop_index("passage_idx_user", table_name="passages") + op.create_foreign_key(None, "passages", "organizations", ["organization_id"], ["id"]) + op.create_foreign_key(None, "passages", "agents", ["agent_id"], ["id"]) + op.create_foreign_key(None, "passages", "files", ["file_id"], ["id"], ondelete="CASCADE") + op.drop_column("passages", "user_id") # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False)) - op.drop_constraint(None, 'passages', type_='foreignkey') - op.drop_constraint(None, 'passages', type_='foreignkey') - op.drop_constraint(None, 'passages', type_='foreignkey') - op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False) - op.alter_column('passages', 'created_at', - existing_type=postgresql.TIMESTAMP(timezone=True), - nullable=True) - op.alter_column('passages', 'metadata_', - existing_type=postgresql.JSON(astext_type=sa.Text()), - nullable=True) - op.alter_column('passages', 'embedding_config', - existing_type=postgresql.JSON(astext_type=sa.Text()), - nullable=True) - op.alter_column('passages', 'text', - existing_type=sa.VARCHAR(), - nullable=True) - op.drop_column('passages', 'organization_id') - op.drop_column('passages', '_last_updated_by_id') - op.drop_column('passages', '_created_by_id') - op.drop_column('passages', 'is_deleted') - op.drop_column('passages', 'updated_at') + op.add_column("passages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_constraint(None, "passages", type_="foreignkey") + op.drop_constraint(None, "passages", type_="foreignkey") + op.drop_constraint(None, "passages", type_="foreignkey") + op.create_index("passage_idx_user", "passages", ["user_id", "agent_id", "file_id"], unique=False) + op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True) + op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=True) + op.drop_column("passages", "organization_id") + op.drop_column("passages", "_last_updated_by_id") + op.drop_column("passages", "_created_by_id") + op.drop_column("passages", "is_deleted") + op.drop_column("passages", "updated_at") # ### end Alembic commands ### diff --git a/alembic/versions/c85a3d07c028_move_files_to_orm.py b/alembic/versions/c85a3d07c028_move_files_to_orm.py index b05d7930..7d853e6a 100644 --- a/alembic/versions/c85a3d07c028_move_files_to_orm.py +++ b/alembic/versions/c85a3d07c028_move_files_to_orm.py @@ -9,7 +9,6 @@ Create Date: 2024-11-12 13:58:57.221081 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py index f46bef6b..ae10edbd 100644 --- a/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py +++ b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py @@ -9,9 +9,8 @@ Create Date: 2024-11-07 13:29:57.186107 from typing import Sequence, Union import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "cda66b6cb0d6" diff --git a/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py b/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py index d03652c8..61b24f5e 100644 --- a/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py +++ b/alembic/versions/d05669b60ebe_migrate_agents_to_orm.py @@ -9,9 +9,8 @@ Create Date: 2024-12-12 10:25:31.825635 from typing import Sequence, Union import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "d05669b60ebe" diff --git a/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py b/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py index e8733313..5c0dab19 100644 --- a/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py +++ b/alembic/versions/d14ae606614c_move_organizations_users_tools_to_orm.py @@ -8,11 +8,10 @@ Create Date: 2024-11-05 15:03:12.350096 from typing import Sequence, Union -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - import letta +import sqlalchemy as sa from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "d14ae606614c" diff --git a/alembic/versions/e1a625072dbf_tweak_created_at_field_for_messages.py b/alembic/versions/e1a625072dbf_tweak_created_at_field_for_messages.py index fb425db3..4fd8abd2 100644 --- a/alembic/versions/e1a625072dbf_tweak_created_at_field_for_messages.py +++ b/alembic/versions/e1a625072dbf_tweak_created_at_field_for_messages.py @@ -8,9 +8,8 @@ Create Date: 2024-12-07 14:28:27.643583 from typing import Sequence, Union -from sqlalchemy.dialects import postgresql - from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "e1a625072dbf" diff --git a/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py b/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py index 9e7fa270..37a67d88 100644 --- a/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py +++ b/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py @@ -9,7 +9,6 @@ Create Date: 2024-11-18 15:40:13.149438 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py b/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py index 55332bfc..32a1b677 100644 --- a/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py +++ b/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py @@ -9,7 +9,6 @@ Create Date: 2024-11-14 17:51:27.263561 from typing import Sequence, Union import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. diff --git a/examples/swarm/simple.py b/examples/swarm/simple.py index 8e10c486..7bedb032 100644 --- a/examples/swarm/simple.py +++ b/examples/swarm/simple.py @@ -1,7 +1,6 @@ import typer -from swarm import Swarm - from letta import EmbeddingConfig, LLMConfig +from swarm import Swarm """ This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents: diff --git a/examples/swarm/swarm.py b/examples/swarm/swarm.py index ef080806..40552997 100644 --- a/examples/swarm/swarm.py +++ b/examples/swarm/swarm.py @@ -2,7 +2,6 @@ import json from typing import List, Optional import typer - from letta import AgentState, EmbeddingConfig, LLMConfig, create_client from letta.schemas.agent import AgentType from letta.schemas.memory import BasicBlockMemory, Block diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py index 7d04df6c..4e0193f8 100644 --- a/examples/tool_rule_usage.py +++ b/examples/tool_rule_usage.py @@ -5,8 +5,8 @@ from letta import create_client from letta.schemas.letta_message import ToolCallMessage from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from tests.helpers.endpoints_helper import ( - assert_invoked_send_message_with_keyword, - setup_agent, + assert_invoked_send_message_with_keyword, + setup_agent, ) from tests.helpers.utils import cleanup from tests.test_model_letta_perfomance import llm_config_dir diff --git a/letta/benchmark/benchmark.py b/letta/benchmark/benchmark.py index 7109210e..b7c01ab7 100644 --- a/letta/benchmark/benchmark.py +++ b/letta/benchmark/benchmark.py @@ -5,7 +5,6 @@ import uuid from typing import Annotated, Union import typer - from letta import LocalClient, RESTClient, create_client from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES from letta.config import LettaConfig diff --git a/letta/cli/cli.py b/letta/cli/cli.py index e5a649f7..efd58df7 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -3,10 +3,9 @@ import sys from enum import Enum from typing import Annotated, Optional +import letta.utils as utils import questionary import typer - -import letta.utils as utils from letta import create_client from letta.agent import Agent, save_agent from letta.config import LettaConfig diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py index 8278d553..f70b466b 100644 --- a/letta/cli/cli_config.py +++ b/letta/cli/cli_config.py @@ -5,11 +5,10 @@ from typing import Annotated, List, Optional import questionary import typer +from letta import utils from prettytable.colortable import ColorTable, Themes from tqdm import tqdm -from letta import utils - app = typer.Typer() diff --git a/letta/cli/cli_load.py b/letta/cli/cli_load.py index b27da4d8..e1d7b1b8 100644 --- a/letta/cli/cli_load.py +++ b/letta/cli/cli_load.py @@ -13,7 +13,6 @@ from typing import Annotated, List, Optional import questionary import typer - from letta import create_client from letta.data_sources.connectors import DirectoryConnector diff --git a/letta/client/client.py b/letta/client/client.py index bb6d2f0f..af81a562 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2,9 +2,8 @@ import logging import time from typing import Callable, Dict, Generator, List, Optional, Union -import requests - import letta.utils +import requests from letta.constants import ( ADMIN_PREFIX, BASE_MEMORY_TOOLS, diff --git a/letta/client/streaming.py b/letta/client/streaming.py index a364ada6..7d7e1129 100644 --- a/letta/client/streaming.py +++ b/letta/client/streaming.py @@ -3,14 +3,13 @@ from typing import Generator import httpx from httpx_sse import SSEError, connect_sse - from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.errors import LLMError from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( + ReasoningMessage, ToolCallMessage, ToolReturnMessage, - ReasoningMessage, ) from letta.schemas.letta_response import LettaStreamingResponse from letta.schemas.usage import LettaUsageStatistics diff --git a/letta/client/utils.py b/letta/client/utils.py index 1ff28f8c..871269c7 100644 --- a/letta/client/utils.py +++ b/letta/client/utils.py @@ -3,12 +3,11 @@ from datetime import datetime from typing import Optional from IPython.display import HTML, display -from sqlalchemy.testing.plugin.plugin_base import warnings - from letta.local_llm.constants import ( ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL, ) +from sqlalchemy.testing.plugin.plugin_base import warnings def pprint(messages): diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index f9fdd261..3d577810 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -1,7 +1,6 @@ from typing import Dict, Iterator, List, Tuple import typer - from letta.data_sources.connectors_helper import ( assert_all_files_exist_locally, extract_metadata_from_files, @@ -14,6 +13,7 @@ from letta.schemas.source import Source from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager + class DataConnector: """ Base class for data connectors that can be extended to generate files and passages from a custom data source. diff --git a/letta/embeddings.py b/letta/embeddings.py index 0d82d158..e27ee1ad 100644 --- a/letta/embeddings.py +++ b/letta/embeddings.py @@ -3,7 +3,6 @@ from typing import Any, List, Optional import numpy as np import tiktoken - from letta.constants import ( EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP, diff --git a/letta/errors.py b/letta/errors.py index 4957139b..2c4703c0 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -52,12 +52,10 @@ class LettaConfigurationError(LettaError): class LettaAgentNotFoundError(LettaError): """Error raised when an agent is not found.""" - pass class LettaUserNotFoundError(LettaError): """Error raised when a user is not found.""" - pass class LLMError(LettaError): diff --git a/letta/functions/function_sets/extras.py b/letta/functions/function_sets/extras.py index f29f85ba..c911ffca 100644 --- a/letta/functions/function_sets/extras.py +++ b/letta/functions/function_sets/extras.py @@ -3,7 +3,6 @@ import uuid from typing import Optional import requests - from letta.constants import ( MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 02919b2e..24eff307 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,8 +1,6 @@ import json from typing import List, Optional, Union -from pydantic import BaseModel, Field - from letta.schemas.enums import ToolRuleType from letta.schemas.tool_rule import ( BaseToolRule, @@ -11,6 +9,7 @@ from letta.schemas.tool_rule import ( InitToolRule, TerminalToolRule, ) +from pydantic import BaseModel, Field class ToolRuleValidationError(Exception): @@ -50,7 +49,6 @@ class ToolRulesSolver(BaseModel): assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) - def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" self.last_tool_name = tool_name @@ -88,7 +86,7 @@ class ToolRulesSolver(BaseModel): return any(rule.tool_name == tool_name for rule in self.tool_rules) def validate_conditional_tool(self, rule: ConditionalToolRule): - ''' + """ Validate a conditional tool rule Args: @@ -96,13 +94,13 @@ class ToolRulesSolver(BaseModel): Raises: ToolRuleValidationError: If the rule is invalid - ''' + """ if len(rule.child_output_mapping) == 0: raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") return True def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: - ''' + """ Parse function response to determine which child tool to use based on the mapping Args: @@ -111,7 +109,7 @@ class ToolRulesSolver(BaseModel): Returns: str: The name of the child tool to use next - ''' + """ json_response = json.loads(last_function_response) function_output = json_response["message"] diff --git a/letta/interface.py b/letta/interface.py index aac10453..41276b17 100644 --- a/letta/interface.py +++ b/letta/interface.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from typing import List, Optional from colorama import Fore, Style, init - from letta.constants import CLI_WARNING_PREFIX from letta.local_llm.constants import ( ASSISTANT_MESSAGE_CLI_SYMBOL, diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 4cca920a..78980d52 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -102,13 +102,9 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]: formatted_tools = [] for tool in tools: formatted_tool = { - "name" : tool.function.name, - "description" : tool.function.description, - "input_schema" : tool.function.parameters or { - "type": "object", - "properties": {}, - "required": [] - } + "name": tool.function.name, + "description": tool.function.description, + "input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []}, } formatted_tools.append(formatted_tool) @@ -346,7 +342,7 @@ def anthropic_chat_completions_request( data["tool_choice"] = { "type": "tool", # Changed from "function" to "tool" "name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object - "disable_parallel_tool_use": True # Force single tool use + "disable_parallel_tool_use": True, # Force single tool use } # Move 'system' to the top level diff --git a/letta/llm_api/azure_openai.py b/letta/llm_api/azure_openai.py index e60b547b..047f2f86 100644 --- a/letta/llm_api/azure_openai.py +++ b/letta/llm_api/azure_openai.py @@ -1,7 +1,6 @@ from collections import defaultdict import requests - from letta.llm_api.helpers import make_post_request from letta.schemas.llm_config import LLMConfig from letta.schemas.openai.chat_completion_response import ChatCompletionResponse diff --git a/letta/llm_api/cohere.py b/letta/llm_api/cohere.py index 1e8b5fd6..844e618f 100644 --- a/letta/llm_api/cohere.py +++ b/letta/llm_api/cohere.py @@ -3,7 +3,6 @@ import uuid from typing import List, Optional, Union import requests - from letta.local_llm.utils import count_tokens from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool diff --git a/letta/llm_api/google_ai.py b/letta/llm_api/google_ai.py index 57071a23..f16fc447 100644 --- a/letta/llm_api/google_ai.py +++ b/letta/llm_api/google_ai.py @@ -2,7 +2,6 @@ import uuid from typing import List, Optional, Tuple import requests - from letta.constants import NON_USER_MSG_PREFIX from letta.llm_api.helpers import make_post_request from letta.local_llm.json_parser import clean_json_string_extra_backslash diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 1244b6ff..25f77e24 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -5,7 +5,6 @@ from collections import OrderedDict from typing import Any, List, Union import requests - from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice from letta.utils import json_dumps, printd diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 578779d7..7b39b31a 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -3,7 +3,6 @@ import time from typing import List, Optional, Union import requests - from letta.constants import CLI_WARNING_PREFIX from letta.errors import LettaConfigurationError, RateLimitExceededError from letta.llm_api.anthropic import anthropic_chat_completions_request @@ -255,12 +254,7 @@ def create( tool_call = None if force_tool_call is not None: - tool_call = { - "type": "function", - "function": { - "name": force_tool_call - } - } + tool_call = {"type": "function", "function": {"name": force_tool_call}} assert functions is not None return anthropic_chat_completions_request( diff --git a/letta/llm_api/mistral.py b/letta/llm_api/mistral.py index 932cf874..b53d76ed 100644 --- a/letta/llm_api/mistral.py +++ b/letta/llm_api/mistral.py @@ -1,5 +1,4 @@ import requests - from letta.utils import printd, smart_urljoin diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 813ae68d..407ac6ae 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -6,7 +6,6 @@ import httpx import requests from httpx_sse import connect_sse from httpx_sse._exceptions import SSEError - from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.errors import LLMError from letta.llm_api.helpers import ( diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index c6dbd4a1..9eba5ca8 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -3,7 +3,6 @@ import uuid import requests - from letta.constants import CLI_WARNING_PREFIX from letta.errors import LocalLLMConnectionError, LocalLLMError from letta.local_llm.constants import DEFAULT_WRAPPER diff --git a/letta/local_llm/grammars/gbnf_grammar_generator.py b/letta/local_llm/grammars/gbnf_grammar_generator.py index ddd62817..1cb793ef 100644 --- a/letta/local_llm/grammars/gbnf_grammar_generator.py +++ b/letta/local_llm/grammars/gbnf_grammar_generator.py @@ -19,9 +19,8 @@ from typing import ( ) from docstring_parser import parse -from pydantic import BaseModel, create_model - from letta.utils import json_dumps +from pydantic import BaseModel, create_model class PydanticDataType(Enum): diff --git a/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py b/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py index 19f25668..4b41d7a5 100644 --- a/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +++ b/letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py @@ -1,5 +1,4 @@ import yaml - from letta.utils import json_dumps, json_loads from ...errors import LLMJSONParsingError diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py index b0529c35..796b3c57 100644 --- a/letta/local_llm/utils.py +++ b/letta/local_llm/utils.py @@ -2,15 +2,14 @@ import os import warnings from typing import List, Union -import requests -import tiktoken - import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros import letta.local_llm.llm_chat_completion_wrappers.chatml as chatml import letta.local_llm.llm_chat_completion_wrappers.configurable_wrapper as configurable_wrapper import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3 import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr +import requests +import tiktoken from letta.schemas.openai.chat_completion_request import Tool, ToolCall diff --git a/letta/main.py b/letta/main.py index de1b4028..bbcd56e5 100644 --- a/letta/main.py +++ b/letta/main.py @@ -2,14 +2,12 @@ import os import sys import traceback -import questionary -import requests -import typer -from rich.console import Console - import letta.agent as agent import letta.errors as errors import letta.system as system +import questionary +import requests +import typer # import benchmark from letta import create_client @@ -22,6 +20,7 @@ from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE # from letta.interface import CLIInterface as interface # for printing to terminal from letta.streaming_interface import AgentRefreshStreamingInterface +from rich.console import Console # interface = interface() diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 8a0f0c77..e083efce 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata from letta.orm.job import Job from letta.orm.message import Message from letta.orm.organization import Organization -from letta.orm.passage import BasePassage, AgentPassage, SourcePassage +from letta.orm.passage import AgentPassage, BasePassage, SourcePassage from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents diff --git a/letta/orm/agent.py b/letta/orm/agent.py index c4645c3e..de9dd5ce 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -1,9 +1,6 @@ import uuid from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm.block import Block from letta.orm.custom_columns import ( EmbeddingConfigColumn, @@ -20,6 +17,8 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.tool_rule import ToolRule +from sqlalchemy import JSON, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags diff --git a/letta/orm/agents_tags.py b/letta/orm/agents_tags.py index 76ff9011..5041f629 100644 --- a/letta/orm/agents_tags.py +++ b/letta/orm/agents_tags.py @@ -1,8 +1,7 @@ +from letta.orm.base import Base from sqlalchemy import ForeignKey, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.base import Base - class AgentsTags(Base): __tablename__ = "agents_tags" diff --git a/letta/orm/block.py b/letta/orm/block.py index 99cfa29b..6759596d 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -1,14 +1,13 @@ from typing import TYPE_CHECKING, Optional, Type -from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event -from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship - from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT from letta.orm.blocks_agents import BlocksAgents from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import Human, Persona +from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event +from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship if TYPE_CHECKING: from letta.orm import Organization diff --git a/letta/orm/blocks_agents.py b/letta/orm/blocks_agents.py index 4774783b..86a75e9a 100644 --- a/letta/orm/blocks_agents.py +++ b/letta/orm/blocks_agents.py @@ -1,8 +1,7 @@ +from letta.orm.base import Base from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column -from letta.orm.base import Base - class BlocksAgents(Base): """Agents must have one or many blocks to make up their core memory.""" diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index f53169d9..b20760f9 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -2,14 +2,18 @@ import base64 from typing import List, Union import numpy as np -from sqlalchemy import JSON -from sqlalchemy.types import BINARY, TypeDecorator - from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ToolRuleType from letta.schemas.llm_config import LLMConfig from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ( + ChildToolRule, + ConditionalToolRule, + InitToolRule, + TerminalToolRule, +) +from sqlalchemy import JSON +from sqlalchemy.types import BINARY, TypeDecorator class EmbeddingConfigColumn(TypeDecorator): diff --git a/letta/orm/file.py b/letta/orm/file.py index 45470c6c..4216a01a 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -1,16 +1,16 @@ -from typing import TYPE_CHECKING, Optional, List - -from sqlalchemy import Integer, String -from sqlalchemy.orm import Mapped, mapped_column, relationship +from typing import TYPE_CHECKING, List, Optional from letta.orm.mixins import OrganizationMixin, SourceMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.file import FileMetadata as PydanticFileMetadata +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm.organization import Organization - from letta.orm.source import Source from letta.orm.passage import SourcePassage + from letta.orm.source import Source + class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): """Represents metadata for an uploaded file.""" @@ -28,4 +28,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin") - source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan") + source_passages: Mapped[List["SourcePassage"]] = relationship( + "SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan" + ) diff --git a/letta/orm/job.py b/letta/orm/job.py index d95abe44..fabdb918 100644 --- a/letta/orm/job.py +++ b/letta/orm/job.py @@ -1,13 +1,12 @@ from datetime import datetime from typing import TYPE_CHECKING, Optional -from sqlalchemy import JSON, String -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm.mixins import UserMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.enums import JobStatus from letta.schemas.job import Job as PydanticJob +from sqlalchemy import JSON, String +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm.user import User diff --git a/letta/orm/message.py b/letta/orm/message.py index a8bbb900..6e2194f5 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,13 +1,12 @@ from typing import Optional -from sqlalchemy import Index -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm.custom_columns import ToolCallColumn from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completions import ToolCall +from sqlalchemy import Index +from sqlalchemy.orm import Mapped, mapped_column, relationship class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 328772d7..3e82f4d6 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -1,11 +1,10 @@ from typing import Optional from uuid import UUID +from letta.orm.base import Base from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column -from letta.orm.base import Base - def is_valid_uuid4(uuid_string: str) -> bool: """Check if a string is a valid UUID4.""" @@ -31,6 +30,7 @@ class UserMixin(Base): user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id")) + class AgentMixin(Base): """Mixin for models that belong to an agent.""" @@ -38,6 +38,7 @@ class AgentMixin(Base): agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE")) + class FileMixin(Base): """Mixin for models that belong to a file.""" diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 9a71a09b..9bcc09b2 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING, List, Union -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.organization import Organization as PydanticOrganization +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: @@ -38,19 +37,11 @@ class Organization(SqlalchemyBase): agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan") source_passages: Mapped[List["SourcePassage"]] = relationship( - "SourcePassage", - back_populates="organization", - cascade="all, delete-orphan" - ) - agent_passages: Mapped[List["AgentPassage"]] = relationship( - "AgentPassage", - back_populates="organization", - cascade="all, delete-orphan" + "SourcePassage", back_populates="organization", cascade="all, delete-orphan" ) + agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan") @property def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: """Convenience property to get all passages""" return self.source_passages + self.agent_passages - - diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 492c6021..8237ed61 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -1,8 +1,5 @@ from typing import TYPE_CHECKING -from sqlalchemy import JSON, Column, Index -from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship - from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn @@ -10,6 +7,8 @@ from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMix from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.passage import Passage as PydanticPassage from letta.settings import settings +from sqlalchemy import JSON, Column, Index +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship config = LettaConfig() diff --git a/letta/orm/sandbox_config.py b/letta/orm/sandbox_config.py index aa8e07dc..0b0cad17 100644 --- a/letta/orm/sandbox_config.py +++ b/letta/orm/sandbox_config.py @@ -1,10 +1,5 @@ from typing import TYPE_CHECKING, Dict, List, Optional -from sqlalchemy import JSON -from sqlalchemy import Enum as SqlEnum -from sqlalchemy import String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig @@ -12,6 +7,10 @@ from letta.schemas.sandbox_config import ( SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable, ) from letta.schemas.sandbox_config import SandboxType +from sqlalchemy import JSON +from sqlalchemy import Enum as SqlEnum +from sqlalchemy import String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm.organization import Organization diff --git a/letta/orm/source.py b/letta/orm/source.py index e7443ea6..838f4c78 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,14 +1,13 @@ from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm import FileMetadata from letta.orm.custom_columns import EmbeddingConfigColumn from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.source import Source as PydanticSource +from sqlalchemy import JSON +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm.agent import Agent diff --git a/letta/orm/sources_agents.py b/letta/orm/sources_agents.py index ffe8a9d0..2da5003f 100644 --- a/letta/orm/sources_agents.py +++ b/letta/orm/sources_agents.py @@ -1,8 +1,7 @@ +from letta.orm.base import Base from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column -from letta.orm.base import Base - class SourcesAgents(Base): """Agents can have zero to many sources""" diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 6879c74b..857d0990 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -3,10 +3,6 @@ from enum import Enum from functools import wraps from typing import TYPE_CHECKING, List, Literal, Optional -from sqlalchemy import String, desc, func, or_, select -from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError -from sqlalchemy.orm import Mapped, Session, mapped_column - from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins from letta.orm.errors import ( @@ -16,6 +12,9 @@ from letta.orm.errors import ( UniqueConstraintViolationError, ) from letta.orm.sqlite_functions import adapt_array +from sqlalchemy import String, desc, func, or_, select +from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError +from sqlalchemy.orm import Mapped, Session, mapped_column if TYPE_CHECKING: from pydantic import BaseModel diff --git a/letta/orm/sqlite_functions.py b/letta/orm/sqlite_functions.py index a5b741aa..50c257a0 100644 --- a/letta/orm/sqlite_functions.py +++ b/letta/orm/sqlite_functions.py @@ -1,12 +1,12 @@ +import base64 +import sqlite3 from typing import Optional, Union -import base64 import numpy as np +from letta.constants import MAX_EMBEDDING_DIM from sqlalchemy import event from sqlalchemy.engine import Engine -import sqlite3 -from letta.constants import MAX_EMBEDDING_DIM def adapt_array(arr): """ @@ -19,12 +19,13 @@ def adapt_array(arr): arr = np.array(arr, dtype=np.float32) elif not isinstance(arr, np.ndarray): raise ValueError(f"Unsupported type: {type(arr)}") - + # Convert to bytes and then base64 encode bytes_data = arr.tobytes() base64_data = base64.b64encode(bytes_data) return sqlite3.Binary(base64_data) + def convert_array(text): """ Converts binary back to numpy array @@ -38,23 +39,24 @@ def convert_array(text): # Handle both bytes and sqlite3.Binary binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text - + try: # First decode base64 decoded_data = base64.b64decode(binary_data) # Then convert to numpy array return np.frombuffer(decoded_data, dtype=np.float32) - except Exception as e: + except Exception: return None + def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool: """ Verifies that an embedding has the expected dimension - + Args: embedding: Input embedding array expected_dim: Expected embedding dimension (default: 4096) - + Returns: bool: True if dimension matches, False otherwise """ @@ -62,28 +64,27 @@ def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EM return False return embedding.shape[0] == expected_dim + def validate_and_transform_embedding( - embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], - expected_dim: int = MAX_EMBEDDING_DIM, - dtype: np.dtype = np.float32 + embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], expected_dim: int = MAX_EMBEDDING_DIM, dtype: np.dtype = np.float32 ) -> Optional[np.ndarray]: """ Validates and transforms embeddings to ensure correct dimensionality. - + Args: embedding: Input embedding in various possible formats expected_dim: Expected embedding dimension (default 4096) dtype: NumPy dtype for the embedding (default float32) - + Returns: np.ndarray: Validated and transformed embedding - + Raises: ValueError: If embedding dimension doesn't match expected dimension """ if embedding is None: return None - + # Convert to numpy array based on input type if isinstance(embedding, (bytes, sqlite3.Binary)): vec = convert_array(embedding) @@ -93,48 +94,49 @@ def validate_and_transform_embedding( vec = embedding.astype(dtype) else: raise ValueError(f"Unsupported embedding type: {type(embedding)}") - + # Validate dimension if vec.shape[0] != expected_dim: - raise ValueError( - f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}" - ) - + raise ValueError(f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}") + return vec + def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM): """ Calculate cosine distance between two embeddings - + Args: embedding1: First embedding embedding2: Second embedding expected_dim: Expected embedding dimension (default 4096) - + Returns: float: Cosine distance """ - + if embedding1 is None or embedding2 is None: return 0.0 # Maximum distance if either embedding is None - + try: vec1 = validate_and_transform_embedding(embedding1, expected_dim) vec2 = validate_and_transform_embedding(embedding2, expected_dim) - except ValueError as e: + except ValueError: return 0.0 - + similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) distance = float(1.0 - similarity) - + return distance + @event.listens_for(Engine, "connect") def register_functions(dbapi_connection, connection_record): """Register SQLite functions""" if isinstance(dbapi_connection, sqlite3.Connection): dbapi_connection.create_function("cosine_distance", 2, cosine_distance) - + + # Register adapters and converters for numpy arrays sqlite3.register_adapter(np.ndarray, adapt_array) sqlite3.register_converter("ARRAY", convert_array) diff --git a/letta/orm/tool.py b/letta/orm/tool.py index a25c7ebb..50d47506 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -1,13 +1,12 @@ from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, String, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship - # TODO everything in functions should live in this model from letta.orm.enums import ToolSourceType from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.tool import Tool as PydanticTool +from sqlalchemy import JSON, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm.organization import Organization diff --git a/letta/orm/tools_agents.py b/letta/orm/tools_agents.py index 52c1e0a1..50668d94 100644 --- a/letta/orm/tools_agents.py +++ b/letta/orm/tools_agents.py @@ -1,8 +1,7 @@ +from letta.orm import Base from sqlalchemy import ForeignKey, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column -from letta.orm import Base - class ToolsAgents(Base): """Agents can have one or many tools associated with them.""" diff --git a/letta/orm/user.py b/letta/orm/user.py index 9f626b10..23c2b268 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -1,10 +1,9 @@ from typing import TYPE_CHECKING, List -from sqlalchemy.orm import Mapped, mapped_column, relationship - from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.user import User as PydanticUser +from sqlalchemy.orm import Mapped, mapped_column, relationship if TYPE_CHECKING: from letta.orm import Job, Organization diff --git a/letta/providers.py b/letta/providers.py index e8ebadfa..cb7bf72b 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -1,7 +1,5 @@ from typing import List, Optional -from pydantic import BaseModel, Field, model_validator - from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.llm_api.azure_openai import ( get_azure_chat_completions_endpoint, @@ -10,6 +8,7 @@ from letta.llm_api.azure_openai import ( from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from pydantic import BaseModel, Field, model_validator class Provider(BaseModel): @@ -27,12 +26,11 @@ class Provider(BaseModel): def provider_tag(self) -> str: """String representation of the provider for display purposes""" raise NotImplementedError - + def get_handle(self, model_name: str) -> str: return f"{self.name}/{model_name}" - class LettaProvider(Provider): name: str = "letta" @@ -44,7 +42,7 @@ class LettaProvider(Provider): model_endpoint_type="openai", model_endpoint="https://inference.memgpt.ai", context_window=16384, - handle=self.get_handle("letta-free") + handle=self.get_handle("letta-free"), ) ] @@ -56,7 +54,7 @@ class LettaProvider(Provider): embedding_endpoint="https://embeddings.memgpt.ai", embedding_dim=1024, embedding_chunk_size=300, - handle=self.get_handle("letta-free") + handle=self.get_handle("letta-free"), ) ] @@ -121,7 +119,13 @@ class OpenAIProvider(Provider): # continue configs.append( - LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name)) + LLMConfig( + model=model_name, + model_endpoint_type="openai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + ) ) # for OpenAI, sort in reverse order @@ -141,7 +145,7 @@ class OpenAIProvider(Provider): embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, embedding_chunk_size=300, - handle=self.get_handle("text-embedding-ada-002") + handle=self.get_handle("text-embedding-ada-002"), ) ] @@ -170,7 +174,7 @@ class AnthropicProvider(Provider): model_endpoint_type="anthropic", model_endpoint=self.base_url, context_window=model["context_window"], - handle=self.get_handle(model["name"]) + handle=self.get_handle(model["name"]), ) ) return configs @@ -203,7 +207,7 @@ class MistralProvider(Provider): model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["max_context_length"], - handle=self.get_handle(model["id"]) + handle=self.get_handle(model["id"]), ) ) @@ -259,7 +263,7 @@ class OllamaProvider(OpenAIProvider): model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=context_window, - handle=self.get_handle(model["name"]) + handle=self.get_handle(model["name"]), ) ) return configs @@ -335,7 +339,7 @@ class OllamaProvider(OpenAIProvider): embedding_endpoint=self.base_url, embedding_dim=embedding_dim, embedding_chunk_size=300, - handle=self.get_handle(model["name"]) + handle=self.get_handle(model["name"]), ) ) return configs @@ -356,7 +360,11 @@ class GroqProvider(OpenAIProvider): continue configs.append( LLMConfig( - model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"]) + model=model["id"], + model_endpoint_type="groq", + model_endpoint=self.base_url, + context_window=model["context_window"], + handle=self.get_handle(model["id"]), ) ) return configs @@ -424,7 +432,7 @@ class TogetherProvider(OpenAIProvider): model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=context_window_size, - handle=self.get_handle(model_name) + handle=self.get_handle(model_name), ) ) @@ -505,7 +513,7 @@ class GoogleAIProvider(Provider): model_endpoint_type="google_ai", model_endpoint=self.base_url, context_window=self.get_model_context_window(model), - handle=self.get_handle(model) + handle=self.get_handle(model), ) ) return configs @@ -529,7 +537,7 @@ class GoogleAIProvider(Provider): embedding_endpoint=self.base_url, embedding_dim=768, embedding_chunk_size=300, # NOTE: max is 2048 - handle=self.get_handle(model) + handle=self.get_handle(model), ) ) return configs @@ -570,7 +578,8 @@ class AzureProvider(Provider): context_window_size = self.get_model_context_window(model_name) model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version) configs.append( - LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name) + LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), + handle=self.get_handle(model_name), ) return configs @@ -591,7 +600,7 @@ class AzureProvider(Provider): embedding_endpoint=model_endpoint, embedding_dim=768, embedding_chunk_size=300, # NOTE: max is 2048 - handle=self.get_handle(model_name) + handle=self.get_handle(model_name), ) ) return configs @@ -625,7 +634,7 @@ class VLLMChatCompletionsProvider(Provider): model_endpoint_type="openai", model_endpoint=self.base_url, context_window=model["max_model_len"], - handle=self.get_handle(model["id"]) + handle=self.get_handle(model["id"]), ) ) return configs @@ -658,7 +667,7 @@ class VLLMCompletionsProvider(Provider): model_endpoint=self.base_url, model_wrapper=self.default_prompt_formatter, context_window=model["max_model_len"], - handle=self.get_handle(model["id"]) + handle=self.get_handle(model["id"]), ) ) return configs diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 03d40350..06fe5860 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -1,8 +1,6 @@ from enum import Enum from typing import Dict, List, Optional -from pydantic import BaseModel, Field, field_validator - from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig @@ -15,6 +13,7 @@ from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.tool_rule import ToolRule from letta.utils import create_random_username +from pydantic import BaseModel, Field, field_validator class AgentType(str, Enum): diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 25e84b7d..bd905bd2 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -1,10 +1,9 @@ from typing import Optional -from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self - from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT from letta.schemas.letta_base import LettaBase +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self # block of the LLM context diff --git a/letta/schemas/file.py b/letta/schemas/file.py index b43eb64c..0671bc3e 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -1,9 +1,8 @@ from datetime import datetime from typing import Optional -from pydantic import Field - from letta.schemas.letta_base import LettaBase +from pydantic import Field class FileMetadataBase(LettaBase): diff --git a/letta/schemas/job.py b/letta/schemas/job.py index 17c2b98d..b767b567 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import Optional -from pydantic import Field - from letta.schemas.enums import JobStatus from letta.schemas.letta_base import OrmMetadataBase +from pydantic import Field class JobBase(OrmMetadataBase): diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index 123d817c..73f956ff 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -1,9 +1,8 @@ from typing import List -from pydantic import BaseModel, Field - from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.message import MessageCreate +from pydantic import BaseModel, Field class LettaRequest(BaseModel): diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index c6a1e8be..7084c3ba 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -3,12 +3,11 @@ import json import re from typing import List, Union -from pydantic import BaseModel, Field - from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import LettaMessage, LettaMessageUnion from letta.schemas.usage import LettaUsageStatistics from letta.utils import json_dumps +from pydantic import BaseModel, Field # TODO: consider moving into own file diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 74bb8135..24022ece 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -4,8 +4,6 @@ import warnings from datetime import datetime, timezone from typing import List, Literal, Optional -from pydantic import BaseModel, Field, field_validator - from letta.constants import ( DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, @@ -16,16 +14,15 @@ from letta.schemas.enums import MessageRole from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( AssistantMessage, - ToolCall as LettaToolCall, - ToolCallMessage, - ToolReturnMessage, - ReasoningMessage, LettaMessage, + ReasoningMessage, SystemMessage, - UserMessage, ) +from letta.schemas.letta_message import ToolCall as LettaToolCall +from letta.schemas.letta_message import ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.utils import get_utc_time, is_utc_datetime, json_dumps +from pydantic import BaseModel, Field, field_validator def add_inner_thoughts_to_tool_call( diff --git a/letta/schemas/organization.py b/letta/schemas/organization.py index 35784ad0..6220b987 100644 --- a/letta/schemas/organization.py +++ b/letta/schemas/organization.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import Optional -from pydantic import Field - from letta.schemas.letta_base import LettaBase from letta.utils import create_random_username, get_utc_time +from pydantic import Field class OrganizationBase(LettaBase): diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index c1ec13be..c613c2e1 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -1,12 +1,11 @@ from datetime import datetime from typing import Dict, List, Optional -from pydantic import Field, field_validator - from letta.constants import MAX_EMBEDDING_DIM from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import OrmMetadataBase from letta.utils import get_utc_time +from pydantic import Field, field_validator class PassageBase(OrmMetadataBase): diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index f86233fa..993711cc 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -3,11 +3,10 @@ import json from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field, model_validator - from letta.schemas.agent import AgentState from letta.schemas.letta_base import LettaBase, OrmMetadataBase from letta.settings import tool_settings +from pydantic import BaseModel, Field, model_validator # Sandbox Config diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 0a458dfd..d6a14d79 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import Optional -from pydantic import Field - from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import LettaBase +from pydantic import Field class BaseSource(LettaBase): diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 997965ab..2837ee7f 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -1,7 +1,5 @@ from typing import Dict, List, Optional -from pydantic import Field, model_validator - from letta.constants import FUNCTION_RETURN_CHAR_LIMIT from letta.functions.functions import derive_openai_json_schema from letta.functions.helpers import ( @@ -11,6 +9,7 @@ from letta.functions.helpers import ( from letta.functions.schema_generator import generate_schema_from_args_schema_v2 from letta.schemas.letta_base import LettaBase from letta.schemas.openai.chat_completions import ToolCall +from pydantic import Field, model_validator class BaseTool(LettaBase): diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 259e5452..e88ff290 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,9 +1,8 @@ from typing import Any, Dict, List, Optional, Union -from pydantic import Field - from letta.schemas.enums import ToolRuleType from letta.schemas.letta_base import LettaBase +from pydantic import Field class BaseToolRule(LettaBase): @@ -25,6 +24,7 @@ class ConditionalToolRule(BaseToolRule): """ A ToolRule that conditionally maps to different child tools based on the output. """ + type: ToolRuleType = ToolRuleType.conditional default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.") child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping") diff --git a/letta/schemas/usage.py b/letta/schemas/usage.py index 53cda8b2..d317cc5b 100644 --- a/letta/schemas/usage.py +++ b/letta/schemas/usage.py @@ -1,4 +1,5 @@ from typing import Literal + from pydantic import BaseModel, Field @@ -12,6 +13,7 @@ class LettaUsageStatistics(BaseModel): total_tokens (int): The total number of tokens processed by the agent. step_count (int): The number of steps taken by the agent. """ + message_type: Literal["usage_statistics"] = "usage_statistics" completion_tokens: int = Field(0, description="The number of tokens generated by the agent.") prompt_tokens: int = Field(0, description="The number of tokens in the prompt.") diff --git a/letta/schemas/user.py b/letta/schemas/user.py index 59a4594e..b2bcb933 100644 --- a/letta/schemas/user.py +++ b/letta/schemas/user.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import Optional -from pydantic import Field - from letta.schemas.letta_base import LettaBase from letta.services.organization_manager import OrganizationManager +from pydantic import Field class UserBase(LettaBase): diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 8cb9b27e..425d9bad 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -8,9 +8,6 @@ from typing import Optional import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.cors import CORSMiddleware - from letta.__init__ import __version__ from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError @@ -47,6 +44,8 @@ from letta.server.rest_api.routers.v1.users import ( from letta.server.rest_api.static_files import mount_static_files from letta.server.server import SyncServer from letta.settings import settings +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware # TODO(ethan) # NOTE(charles): @ethan I had to add this to get the global as the bottom to work diff --git a/letta/server/rest_api/auth/index.py b/letta/server/rest_api/auth/index.py index 28d22435..7a29f6b4 100644 --- a/letta/server/rest_api/auth/index.py +++ b/letta/server/rest_api/auth/index.py @@ -2,11 +2,10 @@ from typing import Optional from uuid import UUID from fastapi import APIRouter -from pydantic import BaseModel, Field - from letta.log import get_logger from letta.server.rest_api.interface import QueuingInterface from letta.server.server import SyncServer +from pydantic import BaseModel, Field logger = get_logger(__name__) router = APIRouter() diff --git a/letta/server/rest_api/auth_token.py b/letta/server/rest_api/auth_token.py index 40e26d80..72387bbd 100644 --- a/letta/server/rest_api/auth_token.py +++ b/letta/server/rest_api/auth_token.py @@ -2,7 +2,6 @@ import uuid from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer - from letta.server.server import SyncServer security = HTTPBearer() diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 1e68ce6e..ba06e1d6 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -12,14 +12,14 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, + LegacyFunctionCallMessage, + LegacyLettaMessage, + LettaMessage, + ReasoningMessage, ToolCall, ToolCallDelta, ToolCallMessage, ToolReturnMessage, - ReasoningMessage, - LegacyFunctionCallMessage, - LegacyLettaMessage, - LettaMessage, ) from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse diff --git a/letta/server/rest_api/routers/openai/assistants/assistants.py b/letta/server/rest_api/routers/openai/assistants/assistants.py index 2b646f93..59e57cdf 100644 --- a/letta/server/rest_api/routers/openai/assistants/assistants.py +++ b/letta/server/rest_api/routers/openai/assistants/assistants.py @@ -1,7 +1,6 @@ from typing import List from fastapi import APIRouter, Body, HTTPException, Path, Query - from letta.constants import DEFAULT_PRESET from letta.schemas.openai.openai import AssistantFile, OpenAIAssistant from letta.server.rest_api.routers.openai.assistants.schemas import ( diff --git a/letta/server/rest_api/routers/openai/assistants/schemas.py b/letta/server/rest_api/routers/openai/assistants/schemas.py index b3cbf389..5a6649eb 100644 --- a/letta/server/rest_api/routers/openai/assistants/schemas.py +++ b/letta/server/rest_api/routers/openai/assistants/schemas.py @@ -1,7 +1,5 @@ from typing import List, Optional -from pydantic import BaseModel, Field - from letta.schemas.openai.openai import ( MessageRoleType, OpenAIMessage, @@ -9,6 +7,7 @@ from letta.schemas.openai.openai import ( ToolCall, ToolCallOutput, ) +from pydantic import BaseModel, Field class CreateAssistantRequest(BaseModel): diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index deabcaf5..72b647db 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -2,9 +2,8 @@ import json from typing import TYPE_CHECKING, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException - from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import ToolCall, LettaMessage +from letta.schemas.letta_message import LettaMessage, ToolCall from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_response import ( ChatCompletionResponse, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 640a17d3..2b7f7804 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -14,8 +14,6 @@ from fastapi import ( status, ) from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import Field - from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger from letta.orm.errors import NoResultFound @@ -49,6 +47,7 @@ from letta.schemas.user import User from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.utils import get_letta_server, sse_async_generator from letta.server.server import SyncServer +from pydantic import Field # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index d9213233..9d2a937d 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response - from letta.orm.errors import NoResultFound from letta.schemas.block import Block, BlockUpdate, CreateBlock from letta.server.rest_api.utils import get_letta_server diff --git a/letta/server/rest_api/routers/v1/health.py b/letta/server/rest_api/routers/v1/health.py index 99fce66d..81bb41e0 100644 --- a/letta/server/rest_api/routers/v1/health.py +++ b/letta/server/rest_api/routers/v1/health.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING from fastapi import APIRouter - from letta.cli.cli import version from letta.schemas.health import Health diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 4245d2f9..bec22625 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -1,7 +1,6 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Header, HTTPException, Query - from letta.orm.errors import NoResultFound from letta.schemas.enums import JobStatus from letta.schemas.job import Job diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 4536ae49..f0bf9f42 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, List from fastapi import APIRouter, Depends - from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.server.rest_api.utils import get_letta_server diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index 2f4cdb1b..8fbead6d 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, HTTPException, Query - from letta.schemas.organization import Organization, OrganizationCreate from letta.server.rest_api.utils import get_letta_server diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index bf06bae7..17dcbb4f 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -1,7 +1,6 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query - from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index fb48d125..c51fee68 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -11,7 +11,6 @@ from fastapi import ( Query, UploadFile, ) - from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.passage import Passage diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index ffc2b212..bb297a36 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -6,7 +6,6 @@ from composio.client.enums.base import EnumStringNotFound from composio.exceptions import ApiKeyNotProvidedError, ComposioSDKError from composio.tools.base.abs import InvalidClassDefinition from fastapi import APIRouter, Body, Depends, Header, HTTPException - from letta.errors import LettaToolCreateError from letta.orm.errors import UniqueConstraintViolationError from letta.schemas.letta_message import ToolReturnMessage diff --git a/letta/server/rest_api/routers/v1/users.py b/letta/server/rest_api/routers/v1/users.py index 27a2feeb..2dfeb06b 100644 --- a/letta/server/rest_api/routers/v1/users.py +++ b/letta/server/rest_api/routers/v1/users.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, HTTPException, Query - from letta.schemas.user import User, UserCreate, UserUpdate from letta.server.rest_api.utils import get_letta_server diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 86a88990..650f0e3e 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -6,12 +6,11 @@ from enum import Enum from typing import AsyncGenerator, Optional, Union from fastapi import Header -from pydantic import BaseModel - from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.schemas.usage import LettaUsageStatistics from letta.server.rest_api.interface import StreamingServerInterface from letta.server.server import SyncServer +from pydantic import BaseModel # from letta.orm.user import User # from letta.orm.utilities import get_db_session @@ -102,6 +101,7 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio def get_current_interface() -> StreamingServerInterface: return StreamingServerInterface + def log_error_to_sentry(e): import traceback diff --git a/letta/server/server.py b/letta/server/server.py index 85aee52b..17bc5771 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -7,13 +7,12 @@ from abc import abstractmethod from datetime import datetime from typing import Callable, List, Optional, Tuple, Union -from composio.client import Composio -from composio.client.collections import ActionModel, AppModel -from fastapi import HTTPException - import letta.constants as constants import letta.server.utils as server_utils import letta.system as system +from composio.client import Composio +from composio.client.collections import ActionModel, AppModel +from fastapi import HTTPException from letta.agent import Agent, save_agent from letta.chat_only_agent import ChatOnlyAgent from letta.credentials import LettaCredentials @@ -139,17 +138,16 @@ class Server(object): from contextlib import contextmanager +from letta.config import LettaConfig + +# NOTE: hack to see if single session management works +from letta.settings import model_settings, settings, tool_settings from rich.console import Console from rich.panel import Panel from rich.text import Text from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from letta.config import LettaConfig - -# NOTE: hack to see if single session management works -from letta.settings import model_settings, settings, tool_settings - config = LettaConfig.load() diff --git a/letta/server/ws_api/example_client.py b/letta/server/ws_api/example_client.py index a7fc57b5..fdfdd660 100644 --- a/letta/server/ws_api/example_client.py +++ b/letta/server/ws_api/example_client.py @@ -1,8 +1,7 @@ import asyncio -import websockets - import letta.server.ws_api.protocol as protocol +import websockets from letta.server.constants import WS_CLIENT_TIMEOUT, WS_DEFAULT_PORT from letta.server.utils import condition_to_stop_receiving, print_server_response diff --git a/letta/server/ws_api/server.py b/letta/server/ws_api/server.py index e2408dda..b3dcc2ba 100644 --- a/letta/server/ws_api/server.py +++ b/letta/server/ws_api/server.py @@ -3,9 +3,8 @@ import signal import sys import traceback -import websockets - import letta.server.ws_api.protocol as protocol +import websockets from letta.server.constants import WS_DEFAULT_PORT from letta.server.server import SyncServer from letta.server.ws_api.interface import SyncWebSocketInterface diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 4e6b80ec..b1040f31 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -2,8 +2,6 @@ from datetime import datetime from typing import Dict, List, Optional import numpy as np -from sqlalchemy import Select, func, literal, select, union_all - from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM from letta.embeddings import embedding_model from letta.log import get_logger @@ -40,6 +38,7 @@ from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings from letta.utils import enforce_types, get_utc_time, united_diff +from sqlalchemy import Select, func, literal, select, union_all logger = get_logger(__name__) diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index d8554063..f80e0160 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -1,21 +1,15 @@ -from typing import List, Optional from datetime import datetime -import numpy as np +from typing import List, Optional -from sqlalchemy import select, union_all, literal - -from letta.constants import MAX_EMBEDDING_DIM from letta.embeddings import embedding_model, parse_and_chunk_text from letta.orm.errors import NoResultFound from letta.orm.passage import AgentPassage, SourcePassage from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types - class PassageManager: """Manager class to handle business logic related to Passages.""" diff --git a/letta/settings.py b/letta/settings.py index 1b6ba44b..3b9af801 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -1,11 +1,10 @@ from pathlib import Path from typing import Optional +from letta.local_llm.constants import DEFAULT_WRAPPER_NAME from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict -from letta.local_llm.constants import DEFAULT_WRAPPER_NAME - class ToolSettings(BaseSettings): composio_api_key: Optional[str] = None diff --git a/letta/streaming_interface.py b/letta/streaming_interface.py index e21e5e73..f7986530 100644 --- a/letta/streaming_interface.py +++ b/letta/streaming_interface.py @@ -3,11 +3,6 @@ from abc import ABC, abstractmethod from datetime import datetime from typing import List, Optional -# from colorama import Fore, Style, init -from rich.console import Console -from rich.live import Live -from rich.markup import escape - from letta.interface import CLIInterface from letta.local_llm.constants import ( ASSISTANT_MESSAGE_CLI_SYMBOL, @@ -19,6 +14,11 @@ from letta.schemas.openai.chat_completion_response import ( ChatCompletionResponse, ) +# from colorama import Fore, Style, init +from rich.console import Console +from rich.live import Live +from rich.markup import escape + # init(autoreset=True) # DEBUG = True # puts full message outputs in the terminal diff --git a/letta/utils.py b/letta/utils.py index 4be8a543..f8015964 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -19,11 +19,9 @@ from typing import List, Union, _GenericAlias, get_args, get_origin, get_type_hi from urllib.parse import urljoin, urlparse import demjson3 as demjson +import letta import pytz import tiktoken -from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename - -import letta from letta.constants import ( CLI_WARNING_PREFIX, CORE_MEMORY_HUMAN_CHAR_LIMIT, @@ -34,6 +32,7 @@ from letta.constants import ( TOOL_CALL_ID_MAX_LEN, ) from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename DEBUG = False if "LOG_LEVEL" in os.environ: @@ -1120,6 +1119,7 @@ def sanitize_filename(filename: str) -> str: # Return the sanitized filename return sanitized_filename + def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str): from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT diff --git a/locust_test.py b/locust_test.py index 570e6eef..bf32fb48 100644 --- a/locust_test.py +++ b/locust_test.py @@ -1,8 +1,6 @@ import random import string -from locust import HttpUser, between, task - from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.letta_request import LettaRequest @@ -10,6 +8,7 @@ from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ChatMemory from letta.schemas.message import MessageCreate, MessageRole from letta.utils import get_human_text, get_persona_text +from locust import HttpUser, between, task class LettaUser(HttpUser): diff --git a/paper_experiments/doc_qa_task/doc_qa.py b/paper_experiments/doc_qa_task/doc_qa.py index e07060d1..205fa0d0 100644 --- a/paper_experiments/doc_qa_task/doc_qa.py +++ b/paper_experiments/doc_qa_task/doc_qa.py @@ -23,9 +23,6 @@ import uuid from typing import List from icml_experiments.utils import get_experiment_config, load_gzipped_file -from openai import OpenAI -from tqdm import tqdm - from letta import utils from letta.agent_store.storage import StorageConnector, TableType from letta.cli.cli_config import delete @@ -33,6 +30,8 @@ from letta.config import LettaConfig from letta.credentials import LettaCredentials from letta.embeddings import embedding_model from letta.utils import count_tokens +from openai import OpenAI +from tqdm import tqdm DATA_SOURCE_NAME = "wikipedia" DOC_QA_PERSONA = "You are Letta DOC-QA bot. Your job is to answer questions about documents that are stored in your archival memory. The answer to the users question will ALWAYS be in your archival memory, so remember to keep searching if you can't find the answer. Answer the questions as if though the year is 2018." # TODO decide on a good persona/human diff --git a/paper_experiments/doc_qa_task/llm_judge_doc_qa.py b/paper_experiments/doc_qa_task/llm_judge_doc_qa.py index c6ff6cfe..18d907a3 100644 --- a/paper_experiments/doc_qa_task/llm_judge_doc_qa.py +++ b/paper_experiments/doc_qa_task/llm_judge_doc_qa.py @@ -2,11 +2,10 @@ import argparse import json import re +from letta.credentials import LettaCredentials from openai import OpenAI from tqdm import tqdm -from letta.credentials import LettaCredentials - # Note: did not end up using since no cases of cheating were observed # CHEATING_PROMPT = \ # """ diff --git a/paper_experiments/doc_qa_task/load_wikipedia_embeddings.py b/paper_experiments/doc_qa_task/load_wikipedia_embeddings.py index 94b98143..e80e7e9b 100644 --- a/paper_experiments/doc_qa_task/load_wikipedia_embeddings.py +++ b/paper_experiments/doc_qa_task/load_wikipedia_embeddings.py @@ -8,11 +8,10 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from absl import app, flags from icml_experiments.utils import get_experiment_config -from tqdm import tqdm - from letta.agent_store.storage import StorageConnector, TableType from letta.cli.cli_config import delete from letta.data_types import Passage +from tqdm import tqdm # Create an empty list to store the JSON objects source_name = "wikipedia" diff --git a/paper_experiments/nested_kv_task/nested_kv.py b/paper_experiments/nested_kv_task/nested_kv.py index 04c95ac5..64fa06a2 100644 --- a/paper_experiments/nested_kv_task/nested_kv.py +++ b/paper_experiments/nested_kv_task/nested_kv.py @@ -29,11 +29,10 @@ from typing import Optional import openai from icml_experiments.utils import get_experiment_config, load_gzipped_file -from tqdm import tqdm - from letta import utils from letta.cli.cli_config import delete from letta.config import LettaConfig +from tqdm import tqdm # TODO: update personas NESTED_PERSONA = "You are Letta DOC-QA bot. Your job is to answer questions about documents that are stored in your archival memory. The answer to the users question will ALWAYS be in your archival memory, so remember to keep searching if you can't find the answer. DO NOT STOP SEARCHING UNTIL YOU VERIFY THAT THE VALUE IS NOT A KEY. Do not stop making nested lookups until this condition is met." # TODO decide on a good persona/human diff --git a/project.json b/project.json index 9fbcd3a9..18b70617 100644 --- a/project.json +++ b/project.json @@ -47,6 +47,13 @@ "cwd": "apps/core" } }, + "lint": { + "executor": "@nxlv/python:run-commands", + "options": { + "command": "poetry run isort --profile black . && poetry run black . && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .", + "cwd": "apps/core" + } + }, "database:migrate": { "executor": "@nxlv/python:run-commands", "options": { diff --git a/scripts/migrate_tools.py b/scripts/migrate_tools.py index 53578c69..91e16071 100644 --- a/scripts/migrate_tools.py +++ b/scripts/migrate_tools.py @@ -1,8 +1,7 @@ -from tqdm import tqdm - from letta.schemas.user import User from letta.services.organization_manager import OrganizationManager from letta.services.tool_manager import ToolManager +from tqdm import tqdm def deprecated_tool(): diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 44aad0d0..289b023a 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -194,16 +194,13 @@ def test_check_tool_rules_with_different_models(mock_e2b_api_key_none): "tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json", "tests/configs/llm_model_configs/openai-gpt-4o.json", ] - + # Create two test tools t1_name = "first_secret_word" t2_name = "second_secret_word" t1 = client.create_or_update_tool(first_secret_word, name=t1_name) t2 = client.create_or_update_tool(second_secret_word, name=t2_name) - tool_rules = [ - InitToolRule(tool_name=t1_name), - InitToolRule(tool_name=t2_name) - ] + tool_rules = [InitToolRule(tool_name=t1_name), InitToolRule(tool_name=t2_name)] tools = [t1, t2] for config_file in config_files: @@ -212,34 +209,26 @@ def test_check_tool_rules_with_different_models(mock_e2b_api_key_none): if "gpt-4o" in config_file: # Structured output model (should work with multiple init tools) - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, - tool_ids=[t.id for t in tools], - tool_rules=tool_rules) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) assert agent_state is not None else: # Non-structured output model (should raise error with multiple init tools) with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"): - setup_agent(client, config_file, agent_uuid=agent_uuid, - tool_ids=[t.id for t in tools], - tool_rules=tool_rules) - + setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + # Cleanup cleanup(client=client, agent_uuid=agent_uuid) # Create tool rule with single initial tool t3_name = "third_secret_word" t3 = client.create_or_update_tool(third_secret_word, name=t3_name) - tool_rules = [ - InitToolRule(tool_name=t3_name) - ] + tool_rules = [InitToolRule(tool_name=t3_name)] tools = [t3] for config_file in config_files: agent_uuid = str(uuid.uuid4()) # Structured output model (should work with single init tool) - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, - tool_ids=[t.id for t in tools], - tool_rules=tool_rules) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) assert agent_state is not None cleanup(client=client, agent_uuid=agent_uuid) @@ -257,7 +246,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): tool_rules = [ InitToolRule(tool_name=t1_name), ChildToolRule(tool_name=t1_name, children=[t2_name]), - TerminalToolRule(tool_name=t2_name) + TerminalToolRule(tool_name=t2_name), ] tools = [t1, t2] @@ -265,7 +254,9 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" for i in range(3): agent_uuid = str(uuid.uuid4()) - agent_state = setup_agent(client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + agent_state = setup_agent( + client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules + ) response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?") assert_sanity_checks(response) @@ -289,9 +280,10 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): # Implement exponential backoff with initial time of 10 seconds if i < 2: - backoff_time = 10 * (2 ** i) + backoff_time = 10 * (2**i) time.sleep(backoff_time) + @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): client = create_client() @@ -389,7 +381,7 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none): default_child=coin_flip_name, child_output_mapping={ "hj2hwibbqm": secret_word_tool, - } + }, ), TerminalToolRule(tool_name=secret_word_tool), ] @@ -425,7 +417,6 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none): cleanup(client=client, agent_uuid=agent_uuid) - @pytest.mark.timeout(90) # Longer timeout since this test has more steps def test_agent_conditional_tool_hard(mock_e2b_api_key_none): """ @@ -450,7 +441,7 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): final_tool = "fourth_secret_word" play_game_tool = client.create_or_update_tool(can_play_game, name=play_game) flip_coin_tool = client.create_or_update_tool(flip_coin_hard, name=coin_flip_name) - reveal_secret = client.create_or_update_tool(fourth_secret_word, name=final_tool) + reveal_secret = client.create_or_update_tool(fourth_secret_word, name=final_tool) # Make tool rules - chain them together with conditional rules tool_rules = [ @@ -458,16 +449,10 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): ConditionalToolRule( tool_name=play_game, default_child=play_game, # Keep trying if we can't play - child_output_mapping={ - True: coin_flip_name # Only allow access when can_play_game returns True - } + child_output_mapping={True: coin_flip_name}, # Only allow access when can_play_game returns True ), ConditionalToolRule( - tool_name=coin_flip_name, - default_child=coin_flip_name, - child_output_mapping={ - "hj2hwibbqm": final_tool, "START_OVER": play_game - } + tool_name=coin_flip_name, default_child=coin_flip_name, child_output_mapping={"hj2hwibbqm": final_tool, "START_OVER": play_game} ), TerminalToolRule(tool_name=final_tool), ] @@ -475,13 +460,7 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): # Setup agent with all tools tools = [play_game_tool, flip_coin_tool, reveal_secret] config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" - agent_state = setup_agent( - client, - config_file, - agent_uuid=agent_uuid, - tool_ids=[t.id for t in tools], - tool_rules=tool_rules - ) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) # Ask agent to try to get all secret words response = client.user_message(agent_id=agent_state.id, message="hi") @@ -520,7 +499,7 @@ def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): Test the agent with a conditional tool that allows any child tool to be called if a function returns None. Tool Flow: - + return_none | v @@ -541,27 +520,16 @@ def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): ConditionalToolRule( tool_name=tool_name, default_child=None, # Allow any tool to be called if output doesn't match - child_output_mapping={ - "anything but none": "first_secret_word" - } - ) + child_output_mapping={"anything but none": "first_secret_word"}, + ), ] tools = [tool, secret_word] # Setup agent with all tools - agent_state = setup_agent( - client, - config_file, - agent_uuid=agent_uuid, - tool_ids=[t.id for t in tools], - tool_rules=tool_rules - ) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) # Ask agent to try different tools based on the game output - response = client.user_message( - agent_id=agent_state.id, - message="call a function, any function. then call send_message" - ) + response = client.user_message(agent_id=agent_state.id, message="call a function, any function. then call send_message") # Make checks assert_sanity_checks(response) @@ -613,18 +581,14 @@ def test_agent_reload_remembers_function_response(mock_e2b_api_key_none): ConditionalToolRule( tool_name=flip_coin_name, default_child=flip_coin_name, # Allow any tool to be called if output doesn't match - child_output_mapping={ - "hj2hwibbqm": secret_word - } + child_output_mapping={"hj2hwibbqm": secret_word}, ), - TerminalToolRule(tool_name=secret_word) + TerminalToolRule(tool_name=secret_word), ] tools = [flip_coin_tool, secret_word_tool] # Setup initial agent - agent_state = setup_agent( - client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules - ) + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) # Call flip_coin first response = client.user_message(agent_id=agent_state.id, message="flip a coin") @@ -643,4 +607,4 @@ def test_agent_reload_remembers_function_response(mock_e2b_api_key_none): assert reloaded_agent.last_function_response is not None print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) \ No newline at end of file + cleanup(client=client, agent_uuid=agent_uuid) diff --git a/tests/integration_test_offline_memory_agent.py b/tests/integration_test_offline_memory_agent.py index 15d4161d..5803e820 100644 --- a/tests/integration_test_offline_memory_agent.py +++ b/tests/integration_test_offline_memory_agent.py @@ -1,5 +1,4 @@ import pytest - from letta import BasicBlockMemory from letta.client.client import Block, create_client from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index b4de0043..ce6c41b2 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -4,7 +4,6 @@ import uuid from typing import List import pytest - from letta import create_client from letta.agent import Agent from letta.client.client import LocalClient diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 299e1e96..0938fdc8 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -5,8 +5,6 @@ from pathlib import Path from unittest.mock import patch import pytest -from sqlalchemy import delete - from letta import create_client from letta.functions.function_sets.base import core_memory_append, core_memory_replace from letta.orm import SandboxConfig, SandboxEnvironmentVariable @@ -31,6 +29,7 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import tool_settings +from sqlalchemy import delete from tests.helpers.utils import create_tool_from_func # Constants diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 5b5bec6f..007d2bf9 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,6 +1,5 @@ -import pytest - import letta.functions.function_sets.base as base_functions +import pytest from letta import LocalClient, create_client from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig diff --git a/tests/test_cli.py b/tests/test_cli.py index 7b2ffae1..54faa804 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,7 +4,6 @@ import sys import pexpect import pytest - from letta.local_llm.constants import ( ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL, diff --git a/tests/test_client.py b/tests/test_client.py index ac0f4f18..71f7216c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,8 +8,6 @@ from typing import List, Union import pytest from dotenv import load_dotenv -from sqlalchemy import delete - from letta import LocalClient, RESTClient, create_client from letta.orm import SandboxConfig, SandboxEnvironmentVariable from letta.schemas.agent import AgentState @@ -20,6 +18,7 @@ from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType from letta.utils import create_random_username +from sqlalchemy import delete # Constants SERVER_PORT = 8283 @@ -341,7 +340,9 @@ def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState): """Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)""" - send_system_message_response = client.send_message(agent_id=agent.id, message="Event occurred: The user just logged off.", role="system") + send_system_message_response = client.send_message( + agent_id=agent.id, message="Event occurred: The user just logged off.", role="system" + ) assert send_system_message_response, "Sending message failed" @@ -390,7 +391,7 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]): """ Always throw an error. """ - return 5/0 + return 5 / 0 tool = client.create_or_update_tool(func=always_error) agent = client.create_agent(tool_ids=[tool.id]) @@ -410,8 +411,8 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]): assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero" else: response_json = json.loads(response_message.tool_return) - assert response_json['status'] == "Failed" - assert response_json['message'] == "Error executing function always_error: ZeroDivisionError: division by zero" + assert response_json["status"] == "Failed" + assert response_json["message"] == "Error executing function always_error: ZeroDivisionError: division by zero" client.delete_agent(agent_id=agent.id) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 3d907fa3..a635a28e 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -7,8 +7,6 @@ from typing import List, Union import pytest from dotenv import load_dotenv -from sqlalchemy import delete - from letta import create_client from letta.client.client import LocalClient, RESTClient from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET @@ -34,6 +32,7 @@ from letta.services.organization_manager import OrganizationManager from letta.services.user_manager import UserManager from letta.settings import model_settings from letta.utils import get_utc_time +from sqlalchemy import delete from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config diff --git a/tests/test_local_client.py b/tests/test_local_client.py index da5e533c..8b0c7617 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -1,7 +1,6 @@ import uuid import pytest - from letta import create_client from letta.client.client import LocalClient from letta.schemas.agent import AgentState diff --git a/tests/test_managers.py b/tests/test_managers.py index 388d477c..5dac4c89 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3,9 +3,6 @@ import time from datetime import datetime, timedelta import pytest -from sqlalchemy import delete -from sqlalchemy.exc import IntegrityError - from letta.config import LettaConfig from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.embeddings import embedding_model @@ -63,6 +60,8 @@ from letta.server.server import SyncServer from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager from letta.settings import tool_settings +from sqlalchemy import delete +from sqlalchemy.exc import IntegrityError from tests.helpers.utils import comprehensive_agent_checks DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( diff --git a/tests/test_server.py b/tests/test_server.py index 4775ed91..631ec1f3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,9 +3,8 @@ import uuid import warnings from typing import List, Tuple -import pytest - import letta.utils as utils +import pytest from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.schemas.block import CreateBlock from letta.schemas.enums import MessageRole diff --git a/tests/test_stream_buffer_readers.py b/tests/test_stream_buffer_readers.py index 9a0bb5e8..92335cc7 100644 --- a/tests/test_stream_buffer_readers.py +++ b/tests/test_stream_buffer_readers.py @@ -1,7 +1,6 @@ import json import pytest - from letta.streaming_utils import JSONInnerThoughtsExtractor diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index c524d53a..3504ed98 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -1,12 +1,11 @@ import pytest - from letta.helpers import ToolRulesSolver from letta.helpers.tool_rule_solver import ToolRuleValidationError from letta.schemas.tool_rule import ( ChildToolRule, ConditionalToolRule, InitToolRule, - TerminalToolRule + TerminalToolRule, ) # Constants for tool names used in the tests @@ -113,11 +112,7 @@ def test_conditional_tool_rule(): # Setup: Define a conditional tool rule init_rule = InitToolRule(tool_name=START_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - rule = ConditionalToolRule( - tool_name=START_TOOL, - default_child=None, - child_output_mapping={True: END_TOOL, False: START_TOOL} - ) + rule = ConditionalToolRule(tool_name=START_TOOL, default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL}) solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) # Action & Assert: Verify the rule properties @@ -126,8 +121,12 @@ def test_conditional_tool_rule(): # Step 2: After using 'start_tool' solver.update_tool_usage(START_TOOL) - assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'" - assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'" + assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [ + END_TOOL + ], "After 'start_tool' returns true, should allow 'end_tool'" + assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [ + START_TOOL + ], "After 'start_tool' returns false, should allow 'start_tool'" # Step 3: After using 'end_tool' assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal" @@ -137,11 +136,7 @@ def test_invalid_conditional_tool_rule(): # Setup: Define an invalid conditional tool rule init_rule = InitToolRule(tool_name=START_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - invalid_rule_1 = ConditionalToolRule( - tool_name=START_TOOL, - default_child=END_TOOL, - child_output_mapping={} - ) + invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={}) # Test 1: Missing child output mapping with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."): diff --git a/tests/test_tool_schema_parsing.py b/tests/test_tool_schema_parsing.py index f6738a06..6eed18f1 100644 --- a/tests/test_tool_schema_parsing.py +++ b/tests/test_tool_schema_parsing.py @@ -2,7 +2,6 @@ import json import os import pytest - from letta.functions.functions import derive_openai_json_schema from letta.llm_api.helpers import convert_to_structured_output, make_post_request diff --git a/tests/test_utils.py b/tests/test_utils.py index 904e903e..f8fd42dd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,4 @@ import pytest - from letta.constants import MAX_FILENAME_LENGTH from letta.utils import sanitize_filename diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py index 2865bb2e..266d8dcb 100644 --- a/tests/test_v1_routes.py +++ b/tests/test_v1_routes.py @@ -8,7 +8,6 @@ from composio.client.collections import ( AppModel, ) from fastapi.testclient import TestClient - from letta.schemas.tool import ToolCreate, ToolUpdate from letta.server.rest_api.app import app from letta.server.rest_api.utils import get_letta_server diff --git a/tests/test_vector_embeddings.py b/tests/test_vector_embeddings.py index e65e6b9b..8640f628 100644 --- a/tests/test_vector_embeddings.py +++ b/tests/test_vector_embeddings.py @@ -1,5 +1,4 @@ import numpy as np - from letta.orm.sqlalchemy_base import adapt_array from letta.orm.sqlite_functions import convert_array, verify_embedding_dimension diff --git a/tests/utils.py b/tests/utils.py index 19a05a09..3ce1cb62 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,6 @@ from importlib import util from typing import Dict, Iterator, List, Tuple import requests - from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector from letta.schemas.file import FileMetadata