Merge branch 'main' into matt/let-649-fix-updating-agent-refresh-blocks
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Base
|
||||
from letta.settings import settings
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
letta_config = LettaConfig.load()
|
||||
|
||||
|
||||
@@ -5,40 +5,44 @@ Revises: 3c683a662c82
|
||||
Create Date: 2024-12-05 16:46:51.258831
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '08b2f8225812'
|
||||
down_revision: Union[str, None] = '3c683a662c82'
|
||||
revision: str = "08b2f8225812"
|
||||
down_revision: Union[str, None] = "3c683a662c82"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tools_agents',
|
||||
sa.Column('agent_id', sa.String(), nullable=False),
|
||||
sa.Column('tool_id', sa.String(), nullable=False),
|
||||
sa.Column('tool_name', sa.String(), nullable=False),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ),
|
||||
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'),
|
||||
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'),
|
||||
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent')
|
||||
op.create_table(
|
||||
"tools_agents",
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("tool_id", sa.String(), nullable=False),
|
||||
sa.Column("tool_name", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_id"],
|
||||
["agents.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["tool_id"], ["tools.id"], name="fk_tool_id"),
|
||||
sa.PrimaryKeyConstraint("agent_id", "tool_id", "tool_name", "id"),
|
||||
sa.UniqueConstraint("agent_id", "tool_name", name="unique_tool_per_agent"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tools_agents')
|
||||
op.drop_table("tools_agents")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-11-22 15:42:47.209229
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -9,9 +9,8 @@ Create Date: 2024-12-04 15:59:41.708396
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3c683a662c82"
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-12-13 17:19:55.796210
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -5,18 +5,18 @@ Revises: 4e88e702f85e
|
||||
Create Date: 2024-12-14 17:23:08.772554
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from pgvector.sqlalchemy import Vector
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '54dec07619c4'
|
||||
down_revision: Union[str, None] = '4e88e702f85e'
|
||||
revision: str = "54dec07619c4"
|
||||
down_revision: Union[str, None] = "4e88e702f85e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -24,82 +24,88 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
'agent_passages',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('text', sa.String(), nullable=False),
|
||||
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
||||
sa.Column('metadata_', sa.JSON(), nullable=False),
|
||||
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||
sa.Column('organization_id', sa.String(), nullable=False),
|
||||
sa.Column('agent_id', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
"agent_passages",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("text", sa.String(), nullable=False),
|
||||
sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
|
||||
sa.Column("metadata_", sa.JSON(), nullable=False),
|
||||
sa.Column("embedding", Vector(dim=4096), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False)
|
||||
op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"], unique=False)
|
||||
op.create_table(
|
||||
'source_passages',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('text', sa.String(), nullable=False),
|
||||
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
||||
sa.Column('metadata_', sa.JSON(), nullable=False),
|
||||
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||
sa.Column('organization_id', sa.String(), nullable=False),
|
||||
sa.Column('file_id', sa.String(), nullable=True),
|
||||
sa.Column('source_id', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
"source_passages",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("text", sa.String(), nullable=False),
|
||||
sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
|
||||
sa.Column("metadata_", sa.JSON(), nullable=False),
|
||||
sa.Column("embedding", Vector(dim=4096), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("file_id", sa.String(), nullable=True),
|
||||
sa.Column("source_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False)
|
||||
op.drop_table('passages')
|
||||
op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE')
|
||||
op.create_index("source_passages_org_idx", "source_passages", ["organization_id"], unique=False)
|
||||
op.drop_table("passages")
|
||||
op.drop_constraint("files_source_id_fkey", "files", type_="foreignkey")
|
||||
op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"], ondelete="CASCADE")
|
||||
op.drop_constraint("messages_agent_id_fkey", "messages", type_="foreignkey")
|
||||
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, 'messages', type_='foreignkey')
|
||||
op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'])
|
||||
op.drop_constraint(None, 'files', type_='foreignkey')
|
||||
op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'])
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.create_foreign_key("messages_agent_id_fkey", "messages", "agents", ["agent_id"], ["id"])
|
||||
op.drop_constraint(None, "files", type_="foreignkey")
|
||||
op.create_foreign_key("files_source_id_fkey", "files", "sources", ["source_id"], ["id"])
|
||||
op.create_table(
|
||||
'passages',
|
||||
sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True),
|
||||
sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
|
||||
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
|
||||
sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
|
||||
sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'),
|
||||
sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'),
|
||||
sa.PrimaryKeyConstraint('id', name='passages_pkey')
|
||||
"passages",
|
||||
sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column("text", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column("file_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("embedding", Vector(dim=4096), autoincrement=False, nullable=True),
|
||||
sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
|
||||
sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
|
||||
sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], name="passages_agent_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["file_id"], ["files.id"], name="passages_file_id_fkey", ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], name="passages_organization_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="passages_pkey"),
|
||||
)
|
||||
op.drop_index('source_passages_org_idx', table_name='source_passages')
|
||||
op.drop_table('source_passages')
|
||||
op.drop_index('agent_passages_org_idx', table_name='agent_passages')
|
||||
op.drop_table('agent_passages')
|
||||
op.drop_index("source_passages_org_idx", table_name="source_passages")
|
||||
op.drop_table("source_passages")
|
||||
op.drop_index("agent_passages_org_idx", table_name="agent_passages")
|
||||
op.drop_table("agent_passages")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -9,9 +9,8 @@ Create Date: 2024-11-25 14:35:00.896507
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5987401b40ae"
|
||||
|
||||
@@ -9,9 +9,8 @@ Create Date: 2024-12-05 14:02:04.163150
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "95badb46fdf9"
|
||||
|
||||
@@ -8,12 +8,11 @@ Create Date: 2024-10-11 14:19:19.875656
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import letta.orm
|
||||
import pgvector
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import letta.orm
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9a505cc7eca9"
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-12-09 18:27:25.650079
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-11-06 10:48:08.424108
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -5,25 +5,26 @@ Revises: a91994b9752f
|
||||
Create Date: 2024-12-10 15:05:32.335519
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c5d964280dff'
|
||||
down_revision: Union[str, None] = 'a91994b9752f'
|
||||
revision: str = "c5d964280dff"
|
||||
down_revision: Union[str, None] = "a91994b9752f"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True))
|
||||
op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False))
|
||||
op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True))
|
||||
op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True))
|
||||
op.add_column("passages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("passages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("passages", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("passages", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
|
||||
# Data migration step:
|
||||
op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
|
||||
@@ -41,48 +42,32 @@ def upgrade() -> None:
|
||||
# Set `organization_id` as non-nullable after population
|
||||
op.alter_column("passages", "organization_id", nullable=False)
|
||||
|
||||
op.alter_column('passages', 'text',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=False)
|
||||
op.alter_column('passages', 'embedding_config',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=False)
|
||||
op.alter_column('passages', 'metadata_',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=False)
|
||||
op.alter_column('passages', 'created_at',
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False)
|
||||
op.drop_index('passage_idx_user', table_name='passages')
|
||||
op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id'])
|
||||
op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id'])
|
||||
op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_column('passages', 'user_id')
|
||||
op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=False)
|
||||
op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
||||
op.drop_index("passage_idx_user", table_name="passages")
|
||||
op.create_foreign_key(None, "passages", "organizations", ["organization_id"], ["id"])
|
||||
op.create_foreign_key(None, "passages", "agents", ["agent_id"], ["id"])
|
||||
op.create_foreign_key(None, "passages", "files", ["file_id"], ["id"], ondelete="CASCADE")
|
||||
op.drop_column("passages", "user_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
||||
op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False)
|
||||
op.alter_column('passages', 'created_at',
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True)
|
||||
op.alter_column('passages', 'metadata_',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=True)
|
||||
op.alter_column('passages', 'embedding_config',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=True)
|
||||
op.alter_column('passages', 'text',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=True)
|
||||
op.drop_column('passages', 'organization_id')
|
||||
op.drop_column('passages', '_last_updated_by_id')
|
||||
op.drop_column('passages', '_created_by_id')
|
||||
op.drop_column('passages', 'is_deleted')
|
||||
op.drop_column('passages', 'updated_at')
|
||||
op.add_column("passages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.drop_constraint(None, "passages", type_="foreignkey")
|
||||
op.drop_constraint(None, "passages", type_="foreignkey")
|
||||
op.drop_constraint(None, "passages", type_="foreignkey")
|
||||
op.create_index("passage_idx_user", "passages", ["user_id", "agent_id", "file_id"], unique=False)
|
||||
op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
|
||||
op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.drop_column("passages", "organization_id")
|
||||
op.drop_column("passages", "_last_updated_by_id")
|
||||
op.drop_column("passages", "_created_by_id")
|
||||
op.drop_column("passages", "is_deleted")
|
||||
op.drop_column("passages", "updated_at")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-11-12 13:58:57.221081
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -9,9 +9,8 @@ Create Date: 2024-11-07 13:29:57.186107
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "cda66b6cb0d6"
|
||||
|
||||
@@ -9,9 +9,8 @@ Create Date: 2024-12-12 10:25:31.825635
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d05669b60ebe"
|
||||
|
||||
@@ -8,11 +8,10 @@ Create Date: 2024-11-05 15:03:12.350096
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import letta
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d14ae606614c"
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2024-12-07 14:28:27.643583
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e1a625072dbf"
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-11-18 15:40:13.149438
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2024-11-14 17:51:27.263561
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import typer
|
||||
from swarm import Swarm
|
||||
|
||||
from letta import EmbeddingConfig, LLMConfig
|
||||
from swarm import Swarm
|
||||
|
||||
"""
|
||||
This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents:
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from typing import List, Optional
|
||||
|
||||
import typer
|
||||
|
||||
from letta import AgentState, EmbeddingConfig, LLMConfig, create_client
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.memory import BasicBlockMemory, Block
|
||||
|
||||
@@ -5,8 +5,8 @@ from letta import create_client
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_send_message_with_keyword,
|
||||
setup_agent,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
setup_agent,
|
||||
)
|
||||
from tests.helpers.utils import cleanup
|
||||
from tests.test_model_letta_perfomance import llm_config_dir
|
||||
|
||||
@@ -5,7 +5,6 @@ import uuid
|
||||
from typing import Annotated, Union
|
||||
|
||||
import typer
|
||||
|
||||
from letta import LocalClient, RESTClient, create_client
|
||||
from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES
|
||||
from letta.config import LettaConfig
|
||||
|
||||
@@ -3,10 +3,9 @@ import sys
|
||||
from enum import Enum
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import letta.utils as utils
|
||||
import questionary
|
||||
import typer
|
||||
|
||||
import letta.utils as utils
|
||||
from letta import create_client
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.config import LettaConfig
|
||||
|
||||
@@ -5,11 +5,10 @@ from typing import Annotated, List, Optional
|
||||
|
||||
import questionary
|
||||
import typer
|
||||
from letta import utils
|
||||
from prettytable.colortable import ColorTable, Themes
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta import utils
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from typing import Annotated, List, Optional
|
||||
|
||||
import questionary
|
||||
import typer
|
||||
|
||||
from letta import create_client
|
||||
from letta.data_sources.connectors import DirectoryConnector
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ import logging
|
||||
import time
|
||||
from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
import letta.utils
|
||||
import requests
|
||||
from letta.constants import (
|
||||
ADMIN_PREFIX,
|
||||
BASE_MEMORY_TOOLS,
|
||||
|
||||
@@ -3,14 +3,13 @@ from typing import Generator
|
||||
|
||||
import httpx
|
||||
from httpx_sse import SSEError, connect_sse
|
||||
|
||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
from letta.errors import LLMError
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
ReasoningMessage,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
ReasoningMessage,
|
||||
)
|
||||
from letta.schemas.letta_response import LettaStreamingResponse
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
@@ -3,12 +3,11 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from IPython.display import HTML, display
|
||||
from sqlalchemy.testing.plugin.plugin_base import warnings
|
||||
|
||||
from letta.local_llm.constants import (
|
||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
||||
INNER_THOUGHTS_CLI_SYMBOL,
|
||||
)
|
||||
from sqlalchemy.testing.plugin.plugin_base import warnings
|
||||
|
||||
|
||||
def pprint(messages):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
|
||||
import typer
|
||||
|
||||
from letta.data_sources.connectors_helper import (
|
||||
assert_all_files_exist_locally,
|
||||
extract_metadata_from_files,
|
||||
@@ -14,6 +13,7 @@ from letta.schemas.source import Source
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
|
||||
|
||||
class DataConnector:
|
||||
"""
|
||||
Base class for data connectors that can be extended to generate files and passages from a custom data source.
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
|
||||
from letta.constants import (
|
||||
EMBEDDING_TO_TOKENIZER_DEFAULT,
|
||||
EMBEDDING_TO_TOKENIZER_MAP,
|
||||
|
||||
@@ -52,12 +52,10 @@ class LettaConfigurationError(LettaError):
|
||||
|
||||
class LettaAgentNotFoundError(LettaError):
|
||||
"""Error raised when an agent is not found."""
|
||||
pass
|
||||
|
||||
|
||||
class LettaUserNotFoundError(LettaError):
|
||||
"""Error raised when a user is not found."""
|
||||
pass
|
||||
|
||||
|
||||
class LLMError(LettaError):
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import (
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.tool_rule import (
|
||||
BaseToolRule,
|
||||
@@ -11,6 +9,7 @@ from letta.schemas.tool_rule import (
|
||||
InitToolRule,
|
||||
TerminalToolRule,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolRuleValidationError(Exception):
|
||||
@@ -50,7 +49,6 @@ class ToolRulesSolver(BaseModel):
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
self.terminal_tool_rules.append(rule)
|
||||
|
||||
|
||||
def update_tool_usage(self, tool_name: str):
|
||||
"""Update the internal state to track the last tool called."""
|
||||
self.last_tool_name = tool_name
|
||||
@@ -88,7 +86,7 @@ class ToolRulesSolver(BaseModel):
|
||||
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
||||
|
||||
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
||||
'''
|
||||
"""
|
||||
Validate a conditional tool rule
|
||||
|
||||
Args:
|
||||
@@ -96,13 +94,13 @@ class ToolRulesSolver(BaseModel):
|
||||
|
||||
Raises:
|
||||
ToolRuleValidationError: If the rule is invalid
|
||||
'''
|
||||
"""
|
||||
if len(rule.child_output_mapping) == 0:
|
||||
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
||||
return True
|
||||
|
||||
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
|
||||
'''
|
||||
"""
|
||||
Parse function response to determine which child tool to use based on the mapping
|
||||
|
||||
Args:
|
||||
@@ -111,7 +109,7 @@ class ToolRulesSolver(BaseModel):
|
||||
|
||||
Returns:
|
||||
str: The name of the child tool to use next
|
||||
'''
|
||||
"""
|
||||
json_response = json.loads(last_function_response)
|
||||
function_output = json_response["message"]
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.local_llm.constants import (
|
||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
||||
|
||||
@@ -102,13 +102,9 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
formatted_tool = {
|
||||
"name" : tool.function.name,
|
||||
"description" : tool.function.description,
|
||||
"input_schema" : tool.function.parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
"name": tool.function.name,
|
||||
"description": tool.function.description,
|
||||
"input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
@@ -346,7 +342,7 @@ def anthropic_chat_completions_request(
|
||||
data["tool_choice"] = {
|
||||
"type": "tool", # Changed from "function" to "tool"
|
||||
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
|
||||
"disable_parallel_tool_use": True # Force single tool use
|
||||
"disable_parallel_tool_use": True, # Force single tool use
|
||||
}
|
||||
|
||||
# Move 'system' to the top level
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import requests
|
||||
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
|
||||
@@ -2,7 +2,6 @@ import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
|
||||
@@ -5,7 +5,6 @@ from collections import OrderedDict
|
||||
from typing import Any, List, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
|
||||
from letta.utils import json_dumps, printd
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
||||
from letta.llm_api.anthropic import anthropic_chat_completions_request
|
||||
@@ -255,12 +254,7 @@ def create(
|
||||
|
||||
tool_call = None
|
||||
if force_tool_call is not None:
|
||||
tool_call = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": force_tool_call
|
||||
}
|
||||
}
|
||||
tool_call = {"type": "function", "function": {"name": force_tool_call}}
|
||||
assert functions is not None
|
||||
|
||||
return anthropic_chat_completions_request(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import requests
|
||||
|
||||
from letta.utils import printd, smart_urljoin
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import httpx
|
||||
import requests
|
||||
from httpx_sse import connect_sse
|
||||
from httpx_sse._exceptions import SSEError
|
||||
|
||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
from letta.errors import LLMError
|
||||
from letta.llm_api.helpers import (
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LocalLLMConnectionError, LocalLLMError
|
||||
from letta.local_llm.constants import DEFAULT_WRAPPER
|
||||
|
||||
@@ -19,9 +19,8 @@ from typing import (
|
||||
)
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from letta.utils import json_dumps
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
|
||||
class PydanticDataType(Enum):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import yaml
|
||||
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
|
||||
@@ -2,15 +2,14 @@ import os
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
import requests
|
||||
import tiktoken
|
||||
|
||||
import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
|
||||
import letta.local_llm.llm_chat_completion_wrappers.chatml as chatml
|
||||
import letta.local_llm.llm_chat_completion_wrappers.configurable_wrapper as configurable_wrapper
|
||||
import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
|
||||
import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3
|
||||
import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
|
||||
import requests
|
||||
import tiktoken
|
||||
from letta.schemas.openai.chat_completion_request import Tool, ToolCall
|
||||
|
||||
|
||||
|
||||
@@ -2,14 +2,12 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import questionary
|
||||
import requests
|
||||
import typer
|
||||
from rich.console import Console
|
||||
|
||||
import letta.agent as agent
|
||||
import letta.errors as errors
|
||||
import letta.system as system
|
||||
import questionary
|
||||
import requests
|
||||
import typer
|
||||
|
||||
# import benchmark
|
||||
from letta import create_client
|
||||
@@ -22,6 +20,7 @@ from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
||||
|
||||
# from letta.interface import CLIInterface as interface # for printing to terminal
|
||||
from letta.streaming_interface import AgentRefreshStreamingInterface
|
||||
from rich.console import Console
|
||||
|
||||
# interface = interface()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage
|
||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.custom_columns import (
|
||||
EmbeddingConfigColumn,
|
||||
@@ -20,6 +17,8 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from letta.orm.base import Base
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class AgentsTags(Base):
|
||||
__tablename__ = "agents_tags"
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import Human, Persona
|
||||
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import Organization
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from letta.orm.base import Base
|
||||
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class BlocksAgents(Base):
|
||||
"""Agents must have one or many blocks to make up their core memory."""
|
||||
|
||||
@@ -2,14 +2,18 @@ import base64
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy.types import BINARY, TypeDecorator
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
InitToolRule,
|
||||
TerminalToolRule,
|
||||
)
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy.types import BINARY, TypeDecorator
|
||||
|
||||
|
||||
class EmbeddingConfigColumn(TypeDecorator):
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
|
||||
from sqlalchemy import Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin, SourceMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from sqlalchemy import Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.passage import SourcePassage
|
||||
from letta.orm.source import Source
|
||||
|
||||
|
||||
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||
"""Represents metadata for an uploaded file."""
|
||||
@@ -28,4 +28,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import UserMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.user import User
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import ToolCallColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
|
||||
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from letta.orm.base import Base
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
def is_valid_uuid4(uuid_string: str) -> bool:
|
||||
"""Check if a string is a valid UUID4."""
|
||||
@@ -31,6 +30,7 @@ class UserMixin(Base):
|
||||
|
||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||
|
||||
|
||||
class AgentMixin(Base):
|
||||
"""Mixin for models that belong to an agent."""
|
||||
|
||||
@@ -38,6 +38,7 @@ class AgentMixin(Base):
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
|
||||
|
||||
|
||||
class FileMixin(Base):
|
||||
"""Mixin for models that belong to a file."""
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -38,19 +37,11 @@ class Organization(SqlalchemyBase):
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||
"SourcePassage",
|
||||
back_populates="organization",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship(
|
||||
"AgentPassage",
|
||||
back_populates="organization",
|
||||
cascade="all, delete-orphan"
|
||||
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@property
|
||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||
"""Convenience property to get all passages"""
|
||||
return self.source_passages + self.agent_passages
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, Column, Index
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
|
||||
@@ -10,6 +7,8 @@ from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMix
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
from sqlalchemy import JSON, Column, Index
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
|
||||
|
||||
config = LettaConfig()
|
||||
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy import String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
@@ -12,6 +7,10 @@ from letta.schemas.sandbox_config import (
|
||||
SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable,
|
||||
)
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy import String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm import FileMetadata
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from letta.orm.base import Base
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class SourcesAgents(Base):
|
||||
"""Agents can have zero to many sources"""
|
||||
|
||||
@@ -3,10 +3,6 @@ from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.errors import (
|
||||
@@ -16,6 +12,9 @@ from letta.orm.errors import (
|
||||
UniqueConstraintViolationError,
|
||||
)
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import base64
|
||||
import sqlite3
|
||||
from typing import Optional, Union
|
||||
|
||||
import base64
|
||||
import numpy as np
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
import sqlite3
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
|
||||
def adapt_array(arr):
|
||||
"""
|
||||
@@ -19,12 +19,13 @@ def adapt_array(arr):
|
||||
arr = np.array(arr, dtype=np.float32)
|
||||
elif not isinstance(arr, np.ndarray):
|
||||
raise ValueError(f"Unsupported type: {type(arr)}")
|
||||
|
||||
|
||||
# Convert to bytes and then base64 encode
|
||||
bytes_data = arr.tobytes()
|
||||
base64_data = base64.b64encode(bytes_data)
|
||||
return sqlite3.Binary(base64_data)
|
||||
|
||||
|
||||
def convert_array(text):
|
||||
"""
|
||||
Converts binary back to numpy array
|
||||
@@ -38,23 +39,24 @@ def convert_array(text):
|
||||
|
||||
# Handle both bytes and sqlite3.Binary
|
||||
binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text
|
||||
|
||||
|
||||
try:
|
||||
# First decode base64
|
||||
decoded_data = base64.b64decode(binary_data)
|
||||
# Then convert to numpy array
|
||||
return np.frombuffer(decoded_data, dtype=np.float32)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
|
||||
"""
|
||||
Verifies that an embedding has the expected dimension
|
||||
|
||||
|
||||
Args:
|
||||
embedding: Input embedding array
|
||||
expected_dim: Expected embedding dimension (default: 4096)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if dimension matches, False otherwise
|
||||
"""
|
||||
@@ -62,28 +64,27 @@ def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EM
|
||||
return False
|
||||
return embedding.shape[0] == expected_dim
|
||||
|
||||
|
||||
def validate_and_transform_embedding(
|
||||
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
|
||||
expected_dim: int = MAX_EMBEDDING_DIM,
|
||||
dtype: np.dtype = np.float32
|
||||
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], expected_dim: int = MAX_EMBEDDING_DIM, dtype: np.dtype = np.float32
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Validates and transforms embeddings to ensure correct dimensionality.
|
||||
|
||||
|
||||
Args:
|
||||
embedding: Input embedding in various possible formats
|
||||
expected_dim: Expected embedding dimension (default 4096)
|
||||
dtype: NumPy dtype for the embedding (default float32)
|
||||
|
||||
|
||||
Returns:
|
||||
np.ndarray: Validated and transformed embedding
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding dimension doesn't match expected dimension
|
||||
"""
|
||||
if embedding is None:
|
||||
return None
|
||||
|
||||
|
||||
# Convert to numpy array based on input type
|
||||
if isinstance(embedding, (bytes, sqlite3.Binary)):
|
||||
vec = convert_array(embedding)
|
||||
@@ -93,48 +94,49 @@ def validate_and_transform_embedding(
|
||||
vec = embedding.astype(dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding type: {type(embedding)}")
|
||||
|
||||
|
||||
# Validate dimension
|
||||
if vec.shape[0] != expected_dim:
|
||||
raise ValueError(
|
||||
f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
|
||||
)
|
||||
|
||||
raise ValueError(f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}")
|
||||
|
||||
return vec
|
||||
|
||||
|
||||
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
|
||||
"""
|
||||
Calculate cosine distance between two embeddings
|
||||
|
||||
|
||||
Args:
|
||||
embedding1: First embedding
|
||||
embedding2: Second embedding
|
||||
expected_dim: Expected embedding dimension (default 4096)
|
||||
|
||||
|
||||
Returns:
|
||||
float: Cosine distance
|
||||
"""
|
||||
|
||||
|
||||
if embedding1 is None or embedding2 is None:
|
||||
return 0.0 # Maximum distance if either embedding is None
|
||||
|
||||
|
||||
try:
|
||||
vec1 = validate_and_transform_embedding(embedding1, expected_dim)
|
||||
vec2 = validate_and_transform_embedding(embedding2, expected_dim)
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
|
||||
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
distance = float(1.0 - similarity)
|
||||
|
||||
|
||||
return distance
|
||||
|
||||
|
||||
@event.listens_for(Engine, "connect")
|
||||
def register_functions(dbapi_connection, connection_record):
|
||||
"""Register SQLite functions"""
|
||||
if isinstance(dbapi_connection, sqlite3.Connection):
|
||||
dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
|
||||
|
||||
|
||||
|
||||
# Register adapters and converters for numpy arrays
|
||||
sqlite3.register_adapter(np.ndarray, adapt_array)
|
||||
sqlite3.register_converter("ARRAY", convert_array)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
# TODO everything in functions should live in this model
|
||||
from letta.orm.enums import ToolSourceType
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from letta.orm import Base
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm import Base
|
||||
|
||||
|
||||
class ToolsAgents(Base):
|
||||
"""Agents can have one or many tools associated with them."""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import Job, Organization
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.llm_api.azure_openai import (
|
||||
get_azure_chat_completions_endpoint,
|
||||
@@ -10,6 +8,7 @@ from letta.llm_api.azure_openai import (
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
@@ -27,12 +26,11 @@ class Provider(BaseModel):
|
||||
def provider_tag(self) -> str:
|
||||
"""String representation of the provider for display purposes"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_handle(self, model_name: str) -> str:
|
||||
return f"{self.name}/{model_name}"
|
||||
|
||||
|
||||
|
||||
class LettaProvider(Provider):
|
||||
|
||||
name: str = "letta"
|
||||
@@ -44,7 +42,7 @@ class LettaProvider(Provider):
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://inference.memgpt.ai",
|
||||
context_window=16384,
|
||||
handle=self.get_handle("letta-free")
|
||||
handle=self.get_handle("letta-free"),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -56,7 +54,7 @@ class LettaProvider(Provider):
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_dim=1024,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle("letta-free")
|
||||
handle=self.get_handle("letta-free"),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -121,7 +119,13 @@ class OpenAIProvider(Provider):
|
||||
# continue
|
||||
|
||||
configs.append(
|
||||
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
|
||||
LLMConfig(
|
||||
model=model_name,
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
)
|
||||
)
|
||||
|
||||
# for OpenAI, sort in reverse order
|
||||
@@ -141,7 +145,7 @@ class OpenAIProvider(Provider):
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle("text-embedding-ada-002")
|
||||
handle=self.get_handle("text-embedding-ada-002"),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -170,7 +174,7 @@ class AnthropicProvider(Provider):
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["context_window"],
|
||||
handle=self.get_handle(model["name"])
|
||||
handle=self.get_handle(model["name"]),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -203,7 +207,7 @@ class MistralProvider(Provider):
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_context_length"],
|
||||
handle=self.get_handle(model["id"])
|
||||
handle=self.get_handle(model["id"]),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -259,7 +263,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model["name"])
|
||||
handle=self.get_handle(model["name"]),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -335,7 +339,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle(model["name"])
|
||||
handle=self.get_handle(model["name"]),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -356,7 +360,11 @@ class GroqProvider(OpenAIProvider):
|
||||
continue
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
|
||||
model=model["id"],
|
||||
model_endpoint_type="groq",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["context_window"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -424,7 +432,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name)
|
||||
handle=self.get_handle(model_name),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -505,7 +513,7 @@ class GoogleAIProvider(Provider):
|
||||
model_endpoint_type="google_ai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=self.get_model_context_window(model),
|
||||
handle=self.get_handle(model)
|
||||
handle=self.get_handle(model),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -529,7 +537,7 @@ class GoogleAIProvider(Provider):
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=768,
|
||||
embedding_chunk_size=300, # NOTE: max is 2048
|
||||
handle=self.get_handle(model)
|
||||
handle=self.get_handle(model),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -570,7 +578,8 @@ class AzureProvider(Provider):
|
||||
context_window_size = self.get_model_context_window(model_name)
|
||||
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
|
||||
configs.append(
|
||||
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name)
|
||||
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size),
|
||||
handle=self.get_handle(model_name),
|
||||
)
|
||||
return configs
|
||||
|
||||
@@ -591,7 +600,7 @@ class AzureProvider(Provider):
|
||||
embedding_endpoint=model_endpoint,
|
||||
embedding_dim=768,
|
||||
embedding_chunk_size=300, # NOTE: max is 2048
|
||||
handle=self.get_handle(model_name)
|
||||
handle=self.get_handle(model_name),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -625,7 +634,7 @@ class VLLMChatCompletionsProvider(Provider):
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model["id"])
|
||||
handle=self.get_handle(model["id"]),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@@ -658,7 +667,7 @@ class VLLMCompletionsProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model["id"])
|
||||
handle=self.get_handle(model["id"]),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -15,6 +13,7 @@ from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.utils import create_random_username
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
# block of the LLM context
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class FileMetadataBase(LettaBase):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class JobBase(OrmMetadataBase):
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.message import MessageCreate
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LettaRequest(BaseModel):
|
||||
|
||||
@@ -3,12 +3,11 @@ import json
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.utils import json_dumps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# TODO: consider moving into own file
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import (
|
||||
DEFAULT_MESSAGE_TOOL,
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
@@ -16,16 +14,15 @@ from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
ToolCall as LettaToolCall,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
ReasoningMessage,
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_message import ToolCall as LettaToolCall
|
||||
from letta.schemas.letta_message import ToolCallMessage, ToolReturnMessage, UserMessage
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.utils import get_utc_time, is_utc_datetime, json_dumps
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
def add_inner_thoughts_to_tool_call(
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.utils import create_random_username, get_utc_time
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class OrganizationBase(LettaBase):
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.utils import get_utc_time
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
|
||||
class PassageBase(OrmMetadataBase):
|
||||
|
||||
@@ -3,11 +3,10 @@ import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
from letta.settings import tool_settings
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
# Sandbox Config
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class BaseSource(LettaBase):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.helpers import (
|
||||
@@ -11,6 +9,7 @@ from letta.functions.helpers import (
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
|
||||
class BaseTool(LettaBase):
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class BaseToolRule(LettaBase):
|
||||
@@ -25,6 +24,7 @@ class ConditionalToolRule(BaseToolRule):
|
||||
"""
|
||||
A ToolRule that conditionally maps to different child tools based on the output.
|
||||
"""
|
||||
|
||||
type: ToolRuleType = ToolRuleType.conditional
|
||||
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
|
||||
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -12,6 +13,7 @@ class LettaUsageStatistics(BaseModel):
|
||||
total_tokens (int): The total number of tokens processed by the agent.
|
||||
step_count (int): The number of steps taken by the agent.
|
||||
"""
|
||||
|
||||
message_type: Literal["usage_statistics"] = "usage_statistics"
|
||||
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
|
||||
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class UserBase(LettaBase):
|
||||
|
||||
@@ -8,9 +8,6 @@ from typing import Optional
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.__init__ import __version__
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
@@ -47,6 +44,8 @@ from letta.server.rest_api.routers.v1.users import (
|
||||
from letta.server.rest_api.static_files import mount_static_files
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
# TODO(ethan)
|
||||
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
|
||||
|
||||
@@ -2,11 +2,10 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.server.rest_api.interface import QueuingInterface
|
||||
from letta.server.server import SyncServer
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@@ -2,7 +2,6 @@ import uuid
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
@@ -12,14 +12,14 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
LegacyFunctionCallMessage,
|
||||
LegacyLettaMessage,
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
ToolCall,
|
||||
ToolCallDelta,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
ReasoningMessage,
|
||||
LegacyFunctionCallMessage,
|
||||
LegacyLettaMessage,
|
||||
LettaMessage,
|
||||
)
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.openai.openai import AssistantFile, OpenAIAssistant
|
||||
from letta.server.rest_api.routers.openai.assistants.schemas import (
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.openai.openai import (
|
||||
MessageRoleType,
|
||||
OpenAIMessage,
|
||||
@@ -9,6 +7,7 @@ from letta.schemas.openai.openai import (
|
||||
ToolCall,
|
||||
ToolCallOutput,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
|
||||
|
||||
@@ -2,9 +2,8 @@ import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import ToolCall, LettaMessage
|
||||
from letta.schemas.letta_message import LettaMessage, ToolCall
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
|
||||
@@ -14,8 +14,6 @@ from fastapi import (
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -49,6 +47,7 @@ from letta.schemas.user import User
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
|
||||
from letta.server.server import SyncServer
|
||||
from pydantic import Field
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from letta.cli.cli import version
|
||||
from letta.schemas.health import Health
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
|
||||
from letta.schemas.organization import Organization, OrganizationCreate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
|
||||
@@ -11,7 +11,6 @@ from fastapi import (
|
||||
Query,
|
||||
UploadFile,
|
||||
)
|
||||
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.passage import Passage
|
||||
|
||||
@@ -6,7 +6,6 @@ from composio.client.enums.base import EnumStringNotFound
|
||||
from composio.exceptions import ApiKeyNotProvidedError, ComposioSDKError
|
||||
from composio.tools.base.abs import InvalidClassDefinition
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user