Merge branch 'main' into matt/let-649-fix-updating-agent-refresh-blocks

This commit is contained in:
Shubham Naik
2024-12-23 15:15:49 -08:00
committed by GitHub
135 changed files with 395 additions and 553 deletions

View File

@@ -1,12 +1,11 @@
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
from letta.config import LettaConfig
from letta.orm import Base
from letta.settings import settings
from sqlalchemy import engine_from_config, pool
letta_config = LettaConfig.load()

View File

@@ -5,40 +5,44 @@ Revises: 3c683a662c82
Create Date: 2024-12-05 16:46:51.258831
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '08b2f8225812'
down_revision: Union[str, None] = '3c683a662c82'
revision: str = "08b2f8225812"
down_revision: Union[str, None] = "3c683a662c82"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tools_agents',
sa.Column('agent_id', sa.String(), nullable=False),
sa.Column('tool_id', sa.String(), nullable=False),
sa.Column('tool_name', sa.String(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
sa.Column('_created_by_id', sa.String(), nullable=True),
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ),
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'),
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'),
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent')
op.create_table(
"tools_agents",
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("tool_id", sa.String(), nullable=False),
sa.Column("tool_name", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["agent_id"],
["agents.id"],
),
sa.ForeignKeyConstraint(["tool_id"], ["tools.id"], name="fk_tool_id"),
sa.PrimaryKeyConstraint("agent_id", "tool_id", "tool_name", "id"),
sa.UniqueConstraint("agent_id", "tool_name", name="unique_tool_per_agent"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tools_agents')
op.drop_table("tools_agents")
# ### end Alembic commands ###

View File

@@ -9,7 +9,6 @@ Create Date: 2024-11-22 15:42:47.209229
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -9,9 +9,8 @@ Create Date: 2024-12-04 15:59:41.708396
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "3c683a662c82"

View File

@@ -9,7 +9,6 @@ Create Date: 2024-12-13 17:19:55.796210
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -5,18 +5,18 @@ Revises: 4e88e702f85e
Create Date: 2024-12-14 17:23:08.772554
"""
from typing import Sequence, Union
from alembic import op
from pgvector.sqlalchemy import Vector
import sqlalchemy as sa
from alembic import op
from letta.orm.custom_columns import EmbeddingConfigColumn
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects import postgresql
from letta.orm.custom_columns import EmbeddingConfigColumn
# revision identifiers, used by Alembic.
revision: str = '54dec07619c4'
down_revision: Union[str, None] = '4e88e702f85e'
revision: str = "54dec07619c4"
down_revision: Union[str, None] = "4e88e702f85e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -24,82 +24,88 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
'agent_passages',
sa.Column('id', sa.String(), nullable=False),
sa.Column('text', sa.String(), nullable=False),
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
sa.Column('metadata_', sa.JSON(), nullable=False),
sa.Column('embedding', Vector(dim=4096), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
sa.Column('_created_by_id', sa.String(), nullable=True),
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
sa.Column('organization_id', sa.String(), nullable=False),
sa.Column('agent_id', sa.String(), nullable=False),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id')
"agent_passages",
sa.Column("id", sa.String(), nullable=False),
sa.Column("text", sa.String(), nullable=False),
sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
sa.Column("metadata_", sa.JSON(), nullable=False),
sa.Column("embedding", Vector(dim=4096), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("agent_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False)
op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"], unique=False)
op.create_table(
'source_passages',
sa.Column('id', sa.String(), nullable=False),
sa.Column('text', sa.String(), nullable=False),
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
sa.Column('metadata_', sa.JSON(), nullable=False),
sa.Column('embedding', Vector(dim=4096), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
sa.Column('_created_by_id', sa.String(), nullable=True),
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
sa.Column('organization_id', sa.String(), nullable=False),
sa.Column('file_id', sa.String(), nullable=True),
sa.Column('source_id', sa.String(), nullable=False),
sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
"source_passages",
sa.Column("id", sa.String(), nullable=False),
sa.Column("text", sa.String(), nullable=False),
sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
sa.Column("metadata_", sa.JSON(), nullable=False),
sa.Column("embedding", Vector(dim=4096), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("file_id", sa.String(), nullable=True),
sa.Column("source_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False)
op.drop_table('passages')
op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE')
op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE')
op.create_index("source_passages_org_idx", "source_passages", ["organization_id"], unique=False)
op.drop_table("passages")
op.drop_constraint("files_source_id_fkey", "files", type_="foreignkey")
op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"], ondelete="CASCADE")
op.drop_constraint("messages_agent_id_fkey", "messages", type_="foreignkey")
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'messages', type_='foreignkey')
op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'])
op.drop_constraint(None, 'files', type_='foreignkey')
op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'])
op.drop_constraint(None, "messages", type_="foreignkey")
op.create_foreign_key("messages_agent_id_fkey", "messages", "agents", ["agent_id"], ["id"])
op.drop_constraint(None, "files", type_="foreignkey")
op.create_foreign_key("files_source_id_fkey", "files", "sources", ["source_id"], ["id"])
op.create_table(
'passages',
sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True),
sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'),
sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'),
sa.PrimaryKeyConstraint('id', name='passages_pkey')
"passages",
sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column("text", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column("file_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("embedding", Vector(dim=4096), autoincrement=False, nullable=True),
sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], name="passages_agent_id_fkey"),
sa.ForeignKeyConstraint(["file_id"], ["files.id"], name="passages_file_id_fkey", ondelete="CASCADE"),
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], name="passages_organization_id_fkey"),
sa.PrimaryKeyConstraint("id", name="passages_pkey"),
)
op.drop_index('source_passages_org_idx', table_name='source_passages')
op.drop_table('source_passages')
op.drop_index('agent_passages_org_idx', table_name='agent_passages')
op.drop_table('agent_passages')
op.drop_index("source_passages_org_idx", table_name="source_passages")
op.drop_table("source_passages")
op.drop_index("agent_passages_org_idx", table_name="agent_passages")
op.drop_table("agent_passages")
# ### end Alembic commands ###

View File

@@ -9,9 +9,8 @@ Create Date: 2024-11-25 14:35:00.896507
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "5987401b40ae"

View File

@@ -9,9 +9,8 @@ Create Date: 2024-12-05 14:02:04.163150
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "95badb46fdf9"

View File

@@ -8,12 +8,11 @@ Create Date: 2024-10-11 14:19:19.875656
from typing import Sequence, Union
import letta.orm
import pgvector
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import letta.orm
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "9a505cc7eca9"

View File

@@ -9,7 +9,6 @@ Create Date: 2024-12-09 18:27:25.650079
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT

View File

@@ -9,7 +9,6 @@ Create Date: 2024-11-06 10:48:08.424108
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -5,25 +5,26 @@ Revises: a91994b9752f
Create Date: 2024-12-10 15:05:32.335519
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'c5d964280dff'
down_revision: Union[str, None] = 'a91994b9752f'
revision: str = "c5d964280dff"
down_revision: Union[str, None] = "a91994b9752f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True))
op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False))
op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True))
op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True))
op.add_column("passages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
op.add_column("passages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
op.add_column("passages", sa.Column("_created_by_id", sa.String(), nullable=True))
op.add_column("passages", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
# Data migration step:
op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
@@ -41,48 +42,32 @@ def upgrade() -> None:
# Set `organization_id` as non-nullable after population
op.alter_column("passages", "organization_id", nullable=False)
op.alter_column('passages', 'text',
existing_type=sa.VARCHAR(),
nullable=False)
op.alter_column('passages', 'embedding_config',
existing_type=postgresql.JSON(astext_type=sa.Text()),
nullable=False)
op.alter_column('passages', 'metadata_',
existing_type=postgresql.JSON(astext_type=sa.Text()),
nullable=False)
op.alter_column('passages', 'created_at',
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=False)
op.drop_index('passage_idx_user', table_name='passages')
op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id'])
op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id'])
op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE')
op.drop_column('passages', 'user_id')
op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=False)
op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
op.drop_index("passage_idx_user", table_name="passages")
op.create_foreign_key(None, "passages", "organizations", ["organization_id"], ["id"])
op.create_foreign_key(None, "passages", "agents", ["agent_id"], ["id"])
op.create_foreign_key(None, "passages", "files", ["file_id"], ["id"], ondelete="CASCADE")
op.drop_column("passages", "user_id")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False))
op.drop_constraint(None, 'passages', type_='foreignkey')
op.drop_constraint(None, 'passages', type_='foreignkey')
op.drop_constraint(None, 'passages', type_='foreignkey')
op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False)
op.alter_column('passages', 'created_at',
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=True)
op.alter_column('passages', 'metadata_',
existing_type=postgresql.JSON(astext_type=sa.Text()),
nullable=True)
op.alter_column('passages', 'embedding_config',
existing_type=postgresql.JSON(astext_type=sa.Text()),
nullable=True)
op.alter_column('passages', 'text',
existing_type=sa.VARCHAR(),
nullable=True)
op.drop_column('passages', 'organization_id')
op.drop_column('passages', '_last_updated_by_id')
op.drop_column('passages', '_created_by_id')
op.drop_column('passages', 'is_deleted')
op.drop_column('passages', 'updated_at')
op.add_column("passages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
op.drop_constraint(None, "passages", type_="foreignkey")
op.drop_constraint(None, "passages", type_="foreignkey")
op.drop_constraint(None, "passages", type_="foreignkey")
op.create_index("passage_idx_user", "passages", ["user_id", "agent_id", "file_id"], unique=False)
op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=True)
op.drop_column("passages", "organization_id")
op.drop_column("passages", "_last_updated_by_id")
op.drop_column("passages", "_created_by_id")
op.drop_column("passages", "is_deleted")
op.drop_column("passages", "updated_at")
# ### end Alembic commands ###

View File

@@ -9,7 +9,6 @@ Create Date: 2024-11-12 13:58:57.221081
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -9,9 +9,8 @@ Create Date: 2024-11-07 13:29:57.186107
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "cda66b6cb0d6"

View File

@@ -9,9 +9,8 @@ Create Date: 2024-12-12 10:25:31.825635
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "d05669b60ebe"

View File

@@ -8,11 +8,10 @@ Create Date: 2024-11-05 15:03:12.350096
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import letta
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "d14ae606614c"

View File

@@ -8,9 +8,8 @@ Create Date: 2024-12-07 14:28:27.643583
from typing import Sequence, Union
from sqlalchemy.dialects import postgresql
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "e1a625072dbf"

View File

@@ -9,7 +9,6 @@ Create Date: 2024-11-18 15:40:13.149438
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -9,7 +9,6 @@ Create Date: 2024-11-14 17:51:27.263561
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@@ -1,7 +1,6 @@
import typer
from swarm import Swarm
from letta import EmbeddingConfig, LLMConfig
from swarm import Swarm
"""
This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents:

View File

@@ -2,7 +2,6 @@ import json
from typing import List, Optional
import typer
from letta import AgentState, EmbeddingConfig, LLMConfig, create_client
from letta.schemas.agent import AgentType
from letta.schemas.memory import BasicBlockMemory, Block

View File

@@ -5,8 +5,8 @@ from letta import create_client
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_send_message_with_keyword,
setup_agent,
assert_invoked_send_message_with_keyword,
setup_agent,
)
from tests.helpers.utils import cleanup
from tests.test_model_letta_perfomance import llm_config_dir

View File

@@ -5,7 +5,6 @@ import uuid
from typing import Annotated, Union
import typer
from letta import LocalClient, RESTClient, create_client
from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES
from letta.config import LettaConfig

View File

@@ -3,10 +3,9 @@ import sys
from enum import Enum
from typing import Annotated, Optional
import letta.utils as utils
import questionary
import typer
import letta.utils as utils
from letta import create_client
from letta.agent import Agent, save_agent
from letta.config import LettaConfig

View File

@@ -5,11 +5,10 @@ from typing import Annotated, List, Optional
import questionary
import typer
from letta import utils
from prettytable.colortable import ColorTable, Themes
from tqdm import tqdm
from letta import utils
app = typer.Typer()

View File

@@ -13,7 +13,6 @@ from typing import Annotated, List, Optional
import questionary
import typer
from letta import create_client
from letta.data_sources.connectors import DirectoryConnector

View File

@@ -2,9 +2,8 @@ import logging
import time
from typing import Callable, Dict, Generator, List, Optional, Union
import requests
import letta.utils
import requests
from letta.constants import (
ADMIN_PREFIX,
BASE_MEMORY_TOOLS,

View File

@@ -3,14 +3,13 @@ from typing import Generator
import httpx
from httpx_sse import SSEError, connect_sse
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
ReasoningMessage,
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
)
from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.usage import LettaUsageStatistics

View File

@@ -3,12 +3,11 @@ from datetime import datetime
from typing import Optional
from IPython.display import HTML, display
from sqlalchemy.testing.plugin.plugin_base import warnings
from letta.local_llm.constants import (
ASSISTANT_MESSAGE_CLI_SYMBOL,
INNER_THOUGHTS_CLI_SYMBOL,
)
from sqlalchemy.testing.plugin.plugin_base import warnings
def pprint(messages):

View File

@@ -1,7 +1,6 @@
from typing import Dict, Iterator, List, Tuple
import typer
from letta.data_sources.connectors_helper import (
assert_all_files_exist_locally,
extract_metadata_from_files,
@@ -14,6 +13,7 @@ from letta.schemas.source import Source
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
class DataConnector:
"""
Base class for data connectors that can be extended to generate files and passages from a custom data source.

View File

@@ -3,7 +3,6 @@ from typing import Any, List, Optional
import numpy as np
import tiktoken
from letta.constants import (
EMBEDDING_TO_TOKENIZER_DEFAULT,
EMBEDDING_TO_TOKENIZER_MAP,

View File

@@ -52,12 +52,10 @@ class LettaConfigurationError(LettaError):
class LettaAgentNotFoundError(LettaError):
"""Error raised when an agent is not found."""
pass
class LettaUserNotFoundError(LettaError):
"""Error raised when a user is not found."""
pass
class LLMError(LettaError):

View File

@@ -3,7 +3,6 @@ import uuid
from typing import Optional
import requests
from letta.constants import (
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,

View File

@@ -1,8 +1,6 @@
import json
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
@@ -11,6 +9,7 @@ from letta.schemas.tool_rule import (
InitToolRule,
TerminalToolRule,
)
from pydantic import BaseModel, Field
class ToolRuleValidationError(Exception):
@@ -50,7 +49,6 @@ class ToolRulesSolver(BaseModel):
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)
def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
self.last_tool_name = tool_name
@@ -88,7 +86,7 @@ class ToolRulesSolver(BaseModel):
return any(rule.tool_name == tool_name for rule in self.tool_rules)
def validate_conditional_tool(self, rule: ConditionalToolRule):
'''
"""
Validate a conditional tool rule
Args:
@@ -96,13 +94,13 @@ class ToolRulesSolver(BaseModel):
Raises:
ToolRuleValidationError: If the rule is invalid
'''
"""
if len(rule.child_output_mapping) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
return True
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
'''
"""
Parse function response to determine which child tool to use based on the mapping
Args:
@@ -111,7 +109,7 @@ class ToolRulesSolver(BaseModel):
Returns:
str: The name of the child tool to use next
'''
"""
json_response = json.loads(last_function_response)
function_output = json_response["message"]

View File

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
from typing import List, Optional
from colorama import Fore, Style, init
from letta.constants import CLI_WARNING_PREFIX
from letta.local_llm.constants import (
ASSISTANT_MESSAGE_CLI_SYMBOL,

View File

@@ -102,13 +102,9 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
formatted_tools = []
for tool in tools:
formatted_tool = {
"name" : tool.function.name,
"description" : tool.function.description,
"input_schema" : tool.function.parameters or {
"type": "object",
"properties": {},
"required": []
}
"name": tool.function.name,
"description": tool.function.description,
"input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
}
formatted_tools.append(formatted_tool)
@@ -346,7 +342,7 @@ def anthropic_chat_completions_request(
data["tool_choice"] = {
"type": "tool", # Changed from "function" to "tool"
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
"disable_parallel_tool_use": True # Force single tool use
"disable_parallel_tool_use": True, # Force single tool use
}
# Move 'system' to the top level

View File

@@ -1,7 +1,6 @@
from collections import defaultdict
import requests
from letta.llm_api.helpers import make_post_request
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse

View File

@@ -3,7 +3,6 @@ import uuid
from typing import List, Optional, Union
import requests
from letta.local_llm.utils import count_tokens
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool

View File

@@ -2,7 +2,6 @@ import uuid
from typing import List, Optional, Tuple
import requests
from letta.constants import NON_USER_MSG_PREFIX
from letta.llm_api.helpers import make_post_request
from letta.local_llm.json_parser import clean_json_string_extra_backslash

View File

@@ -5,7 +5,6 @@ from collections import OrderedDict
from typing import Any, List, Union
import requests
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
from letta.utils import json_dumps, printd

View File

@@ -3,7 +3,6 @@ import time
from typing import List, Optional, Union
import requests
from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LettaConfigurationError, RateLimitExceededError
from letta.llm_api.anthropic import anthropic_chat_completions_request
@@ -255,12 +254,7 @@ def create(
tool_call = None
if force_tool_call is not None:
tool_call = {
"type": "function",
"function": {
"name": force_tool_call
}
}
tool_call = {"type": "function", "function": {"name": force_tool_call}}
assert functions is not None
return anthropic_chat_completions_request(

View File

@@ -1,5 +1,4 @@
import requests
from letta.utils import printd, smart_urljoin

View File

@@ -6,7 +6,6 @@ import httpx
import requests
from httpx_sse import connect_sse
from httpx_sse._exceptions import SSEError
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.llm_api.helpers import (

View File

@@ -3,7 +3,6 @@
import uuid
import requests
from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LocalLLMConnectionError, LocalLLMError
from letta.local_llm.constants import DEFAULT_WRAPPER

View File

@@ -19,9 +19,8 @@ from typing import (
)
from docstring_parser import parse
from pydantic import BaseModel, create_model
from letta.utils import json_dumps
from pydantic import BaseModel, create_model
class PydanticDataType(Enum):

View File

@@ -1,5 +1,4 @@
import yaml
from letta.utils import json_dumps, json_loads
from ...errors import LLMJSONParsingError

View File

@@ -2,15 +2,14 @@ import os
import warnings
from typing import List, Union
import requests
import tiktoken
import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
import letta.local_llm.llm_chat_completion_wrappers.chatml as chatml
import letta.local_llm.llm_chat_completion_wrappers.configurable_wrapper as configurable_wrapper
import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3
import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
import requests
import tiktoken
from letta.schemas.openai.chat_completion_request import Tool, ToolCall

View File

@@ -2,14 +2,12 @@ import os
import sys
import traceback
import questionary
import requests
import typer
from rich.console import Console
import letta.agent as agent
import letta.errors as errors
import letta.system as system
import questionary
import requests
import typer
# import benchmark
from letta import create_client
@@ -22,6 +20,7 @@ from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
# from letta.interface import CLIInterface as interface # for printing to terminal
from letta.streaming_interface import AgentRefreshStreamingInterface
from rich.console import Console
# interface = interface()

View File

@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
from letta.orm.job import Job
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents

View File

@@ -1,9 +1,6 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
from letta.orm.custom_columns import (
EmbeddingConfigColumn,
@@ -20,6 +17,8 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.tool_rule import ToolRule
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm.agents_tags import AgentsTags

View File

@@ -1,8 +1,7 @@
from letta.orm.base import Base
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.base import Base
class AgentsTags(Base):
__tablename__ = "agents_tags"

View File

@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING, Optional, Type
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import Human, Persona
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm import Organization

View File

@@ -1,8 +1,7 @@
from letta.orm.base import Base
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class BlocksAgents(Base):
"""Agents must have one or many blocks to make up their core memory."""

View File

@@ -2,14 +2,18 @@ import base64
from typing import List, Union
import numpy as np
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator
class EmbeddingConfigColumn(TypeDecorator):

View File

@@ -1,16 +1,16 @@
from typing import TYPE_CHECKING, Optional, List
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from typing import TYPE_CHECKING, List, Optional
from letta.orm.mixins import OrganizationMixin, SourceMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.source import Source
from letta.orm.passage import SourcePassage
from letta.orm.source import Source
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
"""Represents metadata for an uploaded file."""
@@ -28,4 +28,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
)

View File

@@ -1,13 +1,12 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import UserMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job as PydanticJob
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm.user import User

View File

@@ -1,13 +1,12 @@
from typing import Optional
from sqlalchemy import Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import ToolCallColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completions import ToolCall
from sqlalchemy import Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):

View File

@@ -1,11 +1,10 @@
from typing import Optional
from uuid import UUID
from letta.orm.base import Base
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
def is_valid_uuid4(uuid_string: str) -> bool:
"""Check if a string is a valid UUID4."""
@@ -31,6 +30,7 @@ class UserMixin(Base):
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
class AgentMixin(Base):
"""Mixin for models that belong to an agent."""
@@ -38,6 +38,7 @@ class AgentMixin(Base):
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
class FileMixin(Base):
"""Mixin for models that belong to a file."""

View File

@@ -1,9 +1,8 @@
from typing import TYPE_CHECKING, List, Union
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.organization import Organization as PydanticOrganization
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
@@ -38,19 +37,11 @@ class Organization(SqlalchemyBase):
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage",
back_populates="organization",
cascade="all, delete-orphan"
)
agent_passages: Mapped[List["AgentPassage"]] = relationship(
"AgentPassage",
back_populates="organization",
cascade="all, delete-orphan"
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
)
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
@property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
"""Convenience property to get all passages"""
return self.source_passages + self.agent_passages

View File

@@ -1,8 +1,5 @@
from typing import TYPE_CHECKING
from sqlalchemy import JSON, Column, Index
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
@@ -10,6 +7,8 @@ from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMix
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
from sqlalchemy import JSON, Column, Index
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
config = LettaConfig()

View File

@@ -1,10 +1,5 @@
from typing import TYPE_CHECKING, Dict, List, Optional
from sqlalchemy import JSON
from sqlalchemy import Enum as SqlEnum
from sqlalchemy import String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
@@ -12,6 +7,10 @@ from letta.schemas.sandbox_config import (
SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable,
)
from letta.schemas.sandbox_config import SandboxType
from sqlalchemy import JSON
from sqlalchemy import Enum as SqlEnum
from sqlalchemy import String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm.organization import Organization

View File

@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm import FileMetadata
from letta.orm.custom_columns import EmbeddingConfigColumn
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.source import Source as PydanticSource
from sqlalchemy import JSON
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm.agent import Agent

View File

@@ -1,8 +1,7 @@
from letta.orm.base import Base
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class SourcesAgents(Base):
"""Agents can have zero to many sources"""

View File

@@ -3,10 +3,6 @@ from enum import Enum
from functools import wraps
from typing import TYPE_CHECKING, List, Literal, Optional
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import (
@@ -16,6 +12,9 @@ from letta.orm.errors import (
UniqueConstraintViolationError,
)
from letta.orm.sqlite_functions import adapt_array
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.orm import Mapped, Session, mapped_column
if TYPE_CHECKING:
from pydantic import BaseModel

View File

@@ -1,12 +1,12 @@
import base64
import sqlite3
from typing import Optional, Union
import base64
import numpy as np
from letta.constants import MAX_EMBEDDING_DIM
from sqlalchemy import event
from sqlalchemy.engine import Engine
import sqlite3
from letta.constants import MAX_EMBEDDING_DIM
def adapt_array(arr):
"""
@@ -19,12 +19,13 @@ def adapt_array(arr):
arr = np.array(arr, dtype=np.float32)
elif not isinstance(arr, np.ndarray):
raise ValueError(f"Unsupported type: {type(arr)}")
# Convert to bytes and then base64 encode
bytes_data = arr.tobytes()
base64_data = base64.b64encode(bytes_data)
return sqlite3.Binary(base64_data)
def convert_array(text):
"""
Converts binary back to numpy array
@@ -38,23 +39,24 @@ def convert_array(text):
# Handle both bytes and sqlite3.Binary
binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text
try:
# First decode base64
decoded_data = base64.b64decode(binary_data)
# Then convert to numpy array
return np.frombuffer(decoded_data, dtype=np.float32)
except Exception as e:
except Exception:
return None
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
"""
Verifies that an embedding has the expected dimension
Args:
embedding: Input embedding array
expected_dim: Expected embedding dimension (default: 4096)
Returns:
bool: True if dimension matches, False otherwise
"""
@@ -62,28 +64,27 @@ def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EM
return False
return embedding.shape[0] == expected_dim
def validate_and_transform_embedding(
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
expected_dim: int = MAX_EMBEDDING_DIM,
dtype: np.dtype = np.float32
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], expected_dim: int = MAX_EMBEDDING_DIM, dtype: np.dtype = np.float32
) -> Optional[np.ndarray]:
"""
Validates and transforms embeddings to ensure correct dimensionality.
Args:
embedding: Input embedding in various possible formats
expected_dim: Expected embedding dimension (default 4096)
dtype: NumPy dtype for the embedding (default float32)
Returns:
np.ndarray: Validated and transformed embedding
Raises:
ValueError: If embedding dimension doesn't match expected dimension
"""
if embedding is None:
return None
# Convert to numpy array based on input type
if isinstance(embedding, (bytes, sqlite3.Binary)):
vec = convert_array(embedding)
@@ -93,48 +94,49 @@ def validate_and_transform_embedding(
vec = embedding.astype(dtype)
else:
raise ValueError(f"Unsupported embedding type: {type(embedding)}")
# Validate dimension
if vec.shape[0] != expected_dim:
raise ValueError(
f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
)
raise ValueError(f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}")
return vec
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
"""
Calculate cosine distance between two embeddings
Args:
embedding1: First embedding
embedding2: Second embedding
expected_dim: Expected embedding dimension (default 4096)
Returns:
float: Cosine distance
"""
if embedding1 is None or embedding2 is None:
return 0.0 # Maximum distance if either embedding is None
try:
vec1 = validate_and_transform_embedding(embedding1, expected_dim)
vec2 = validate_and_transform_embedding(embedding2, expected_dim)
except ValueError as e:
except ValueError:
return 0.0
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
distance = float(1.0 - similarity)
return distance
@event.listens_for(Engine, "connect")
def register_functions(dbapi_connection, connection_record):
"""Register SQLite functions"""
if isinstance(dbapi_connection, sqlite3.Connection):
dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
# Register adapters and converters for numpy arrays
sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("ARRAY", convert_array)

View File

@@ -1,13 +1,12 @@
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
# TODO everything in functions should live in this model
from letta.orm.enums import ToolSourceType
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.tool import Tool as PydanticTool
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm.organization import Organization

View File

@@ -1,8 +1,7 @@
from letta.orm import Base
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm import Base
class ToolsAgents(Base):
"""Agents can have one or many tools associated with them."""

View File

@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, List
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.user import User as PydanticUser
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from letta.orm import Job, Organization

View File

@@ -1,7 +1,5 @@
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import (
get_azure_chat_completions_endpoint,
@@ -10,6 +8,7 @@ from letta.llm_api.azure_openai import (
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from pydantic import BaseModel, Field, model_validator
class Provider(BaseModel):
@@ -27,12 +26,11 @@ class Provider(BaseModel):
def provider_tag(self) -> str:
"""String representation of the provider for display purposes"""
raise NotImplementedError
def get_handle(self, model_name: str) -> str:
return f"{self.name}/{model_name}"
class LettaProvider(Provider):
name: str = "letta"
@@ -44,7 +42,7 @@ class LettaProvider(Provider):
model_endpoint_type="openai",
model_endpoint="https://inference.memgpt.ai",
context_window=16384,
handle=self.get_handle("letta-free")
handle=self.get_handle("letta-free"),
)
]
@@ -56,7 +54,7 @@ class LettaProvider(Provider):
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024,
embedding_chunk_size=300,
handle=self.get_handle("letta-free")
handle=self.get_handle("letta-free"),
)
]
@@ -121,7 +119,13 @@ class OpenAIProvider(Provider):
# continue
configs.append(
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
LLMConfig(
model=model_name,
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
)
)
# for OpenAI, sort in reverse order
@@ -141,7 +145,7 @@ class OpenAIProvider(Provider):
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002")
handle=self.get_handle("text-embedding-ada-002"),
)
]
@@ -170,7 +174,7 @@ class AnthropicProvider(Provider):
model_endpoint_type="anthropic",
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["name"])
handle=self.get_handle(model["name"]),
)
)
return configs
@@ -203,7 +207,7 @@ class MistralProvider(Provider):
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_context_length"],
handle=self.get_handle(model["id"])
handle=self.get_handle(model["id"]),
)
)
@@ -259,7 +263,7 @@ class OllamaProvider(OpenAIProvider):
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
handle=self.get_handle(model["name"])
handle=self.get_handle(model["name"]),
)
)
return configs
@@ -335,7 +339,7 @@ class OllamaProvider(OpenAIProvider):
embedding_endpoint=self.base_url,
embedding_dim=embedding_dim,
embedding_chunk_size=300,
handle=self.get_handle(model["name"])
handle=self.get_handle(model["name"]),
)
)
return configs
@@ -356,7 +360,11 @@ class GroqProvider(OpenAIProvider):
continue
configs.append(
LLMConfig(
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
model=model["id"],
model_endpoint_type="groq",
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["id"]),
)
)
return configs
@@ -424,7 +432,7 @@ class TogetherProvider(OpenAIProvider):
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
handle=self.get_handle(model_name)
handle=self.get_handle(model_name),
)
)
@@ -505,7 +513,7 @@ class GoogleAIProvider(Provider):
model_endpoint_type="google_ai",
model_endpoint=self.base_url,
context_window=self.get_model_context_window(model),
handle=self.get_handle(model)
handle=self.get_handle(model),
)
)
return configs
@@ -529,7 +537,7 @@ class GoogleAIProvider(Provider):
embedding_endpoint=self.base_url,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model)
handle=self.get_handle(model),
)
)
return configs
@@ -570,7 +578,8 @@ class AzureProvider(Provider):
context_window_size = self.get_model_context_window(model_name)
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
configs.append(
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name)
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size),
handle=self.get_handle(model_name),
)
return configs
@@ -591,7 +600,7 @@ class AzureProvider(Provider):
embedding_endpoint=model_endpoint,
embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model_name)
handle=self.get_handle(model_name),
)
)
return configs
@@ -625,7 +634,7 @@ class VLLMChatCompletionsProvider(Provider):
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"])
handle=self.get_handle(model["id"]),
)
)
return configs
@@ -658,7 +667,7 @@ class VLLMCompletionsProvider(Provider):
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"])
handle=self.get_handle(model["id"]),
)
)
return configs

View File

@@ -1,8 +1,6 @@
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
@@ -15,6 +13,7 @@ from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ToolRule
from letta.utils import create_random_username
from pydantic import BaseModel, Field, field_validator
class AgentType(str, Enum):

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
from letta.schemas.letta_base import LettaBase
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
# block of the LLM context

View File

@@ -1,9 +1,8 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from pydantic import Field
class FileMetadataBase(LettaBase):

View File

@@ -1,10 +1,9 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.enums import JobStatus
from letta.schemas.letta_base import OrmMetadataBase
from pydantic import Field
class JobBase(OrmMetadataBase):

View File

@@ -1,9 +1,8 @@
from typing import List
from pydantic import BaseModel, Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.message import MessageCreate
from pydantic import BaseModel, Field
class LettaRequest(BaseModel):

View File

@@ -3,12 +3,11 @@ import json
import re
from typing import List, Union
from pydantic import BaseModel, Field
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.usage import LettaUsageStatistics
from letta.utils import json_dumps
from pydantic import BaseModel, Field
# TODO: consider moving into own file

View File

@@ -4,8 +4,6 @@ import warnings
from datetime import datetime, timezone
from typing import List, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from letta.constants import (
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
@@ -16,16 +14,15 @@ from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
AssistantMessage,
ToolCall as LettaToolCall,
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
LettaMessage,
ReasoningMessage,
SystemMessage,
UserMessage,
)
from letta.schemas.letta_message import ToolCall as LettaToolCall
from letta.schemas.letta_message import ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.utils import get_utc_time, is_utc_datetime, json_dumps
from pydantic import BaseModel, Field, field_validator
def add_inner_thoughts_to_tool_call(

View File

@@ -1,10 +1,9 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.utils import create_random_username, get_utc_time
from pydantic import Field
class OrganizationBase(LettaBase):

View File

@@ -1,12 +1,11 @@
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from letta.constants import MAX_EMBEDDING_DIM
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import OrmMetadataBase
from letta.utils import get_utc_time
from pydantic import Field, field_validator
class PassageBase(OrmMetadataBase):

View File

@@ -3,11 +3,10 @@ import json
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, model_validator
from letta.schemas.agent import AgentState
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.settings import tool_settings
from pydantic import BaseModel, Field, model_validator
# Sandbox Config

View File

@@ -1,10 +1,9 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import LettaBase
from pydantic import Field
class BaseSource(LettaBase):

View File

@@ -1,7 +1,5 @@
from typing import Dict, List, Optional
from pydantic import Field, model_validator
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
from letta.functions.functions import derive_openai_json_schema
from letta.functions.helpers import (
@@ -11,6 +9,7 @@ from letta.functions.helpers import (
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall
from pydantic import Field, model_validator
class BaseTool(LettaBase):

View File

@@ -1,9 +1,8 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import Field
from letta.schemas.enums import ToolRuleType
from letta.schemas.letta_base import LettaBase
from pydantic import Field
class BaseToolRule(LettaBase):
@@ -25,6 +24,7 @@ class ConditionalToolRule(BaseToolRule):
"""
A ToolRule that conditionally maps to different child tools based on the output.
"""
type: ToolRuleType = ToolRuleType.conditional
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")

View File

@@ -1,4 +1,5 @@
from typing import Literal
from pydantic import BaseModel, Field
@@ -12,6 +13,7 @@ class LettaUsageStatistics(BaseModel):
total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent.
"""
message_type: Literal["usage_statistics"] = "usage_statistics"
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")

View File

@@ -1,10 +1,9 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.services.organization_manager import OrganizationManager
from pydantic import Field
class UserBase(LettaBase):

View File

@@ -8,9 +8,6 @@ from typing import Optional
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
@@ -47,6 +44,8 @@ from letta.server.rest_api.routers.v1.users import (
from letta.server.rest_api.static_files import mount_static_files
from letta.server.server import SyncServer
from letta.settings import settings
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
# TODO(ethan)
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work

View File

@@ -2,11 +2,10 @@ from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from pydantic import BaseModel, Field
from letta.log import get_logger
from letta.server.rest_api.interface import QueuingInterface
from letta.server.server import SyncServer
from pydantic import BaseModel, Field
logger = get_logger(__name__)
router = APIRouter()

View File

@@ -2,7 +2,6 @@ import uuid
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from letta.server.server import SyncServer
security = HTTPBearer()

View File

@@ -12,14 +12,14 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
LegacyFunctionCallMessage,
LegacyLettaMessage,
LettaMessage,
ReasoningMessage,
ToolCall,
ToolCallDelta,
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
LegacyFunctionCallMessage,
LegacyLettaMessage,
LettaMessage,
)
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse

View File

@@ -1,7 +1,6 @@
from typing import List
from fastapi import APIRouter, Body, HTTPException, Path, Query
from letta.constants import DEFAULT_PRESET
from letta.schemas.openai.openai import AssistantFile, OpenAIAssistant
from letta.server.rest_api.routers.openai.assistants.schemas import (

View File

@@ -1,7 +1,5 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from letta.schemas.openai.openai import (
MessageRoleType,
OpenAIMessage,
@@ -9,6 +7,7 @@ from letta.schemas.openai.openai import (
ToolCall,
ToolCallOutput,
)
from pydantic import BaseModel, Field
class CreateAssistantRequest(BaseModel):

View File

@@ -2,9 +2,8 @@ import json
from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ToolCall, LettaMessage
from letta.schemas.letta_message import LettaMessage, ToolCall
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,

View File

@@ -14,8 +14,6 @@ from fastapi import (
status,
)
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.orm.errors import NoResultFound
@@ -49,6 +47,7 @@ from letta.schemas.user import User
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
from letta.server.server import SyncServer
from pydantic import Field
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally

View File

@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response
from letta.orm.errors import NoResultFound
from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.server.rest_api.utils import get_letta_server

View File

@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from fastapi import APIRouter
from letta.cli.cli import version
from letta.schemas.health import Health

View File

@@ -1,7 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from letta.orm.errors import NoResultFound
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job

View File

@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, List
from fastapi import APIRouter, Depends
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.server.rest_api.utils import get_letta_server

View File

@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from letta.schemas.organization import Organization, OrganizationCreate
from letta.server.rest_api.utils import get_letta_server

View File

@@ -1,7 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, Query
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar

View File

@@ -11,7 +11,6 @@ from fastapi import (
Query,
UploadFile,
)
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.passage import Passage

View File

@@ -6,7 +6,6 @@ from composio.client.enums.base import EnumStringNotFound
from composio.exceptions import ApiKeyNotProvidedError, ComposioSDKError
from composio.tools.base.abs import InvalidClassDefinition
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.errors import LettaToolCreateError
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.letta_message import ToolReturnMessage

Some files were not shown because too many files have changed in this diff Show More