feat: various fixes (#2320)

Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
Co-authored-by: Shubham Naik <shubham.naik10@gmail.com>
Co-authored-by: Caren Thomas <caren@letta.com>
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders
2024-12-31 10:53:33 +04:00
committed by GitHub
parent 644fff77c3
commit ece8dab05d
79 changed files with 565 additions and 783 deletions

View File

@@ -2,43 +2,20 @@
Example enviornment variable configurations for the Letta Example enviornment variable configurations for the Letta
Docker container. Un-coment the sections you want to Docker container. Un-coment the sections you want to
configure with. configure with.
Hint: You don't need to have the same LLM and
Embedding model backends (can mix and match).
########################################################## ##########################################################
########################################################## ##########################################################
OpenAI configuration OpenAI configuration
########################################################## ##########################################################
## LLM Model # OPENAI_API_KEY=sk-...
#LETTA_LLM_ENDPOINT_TYPE=openai
#LETTA_LLM_MODEL=gpt-4o-mini
## Embeddings
#LETTA_EMBEDDING_ENDPOINT_TYPE=openai
#LETTA_EMBEDDING_MODEL=text-embedding-ada-002
########################################################## ##########################################################
Ollama configuration Ollama configuration
########################################################## ##########################################################
## LLM Model # OLLAMA_BASE_URL="http://host.docker.internal:11434"
#LETTA_LLM_ENDPOINT=http://host.docker.internal:11434
#LETTA_LLM_ENDPOINT_TYPE=ollama
#LETTA_LLM_MODEL=dolphin2.2-mistral:7b-q6_K
#LETTA_LLM_CONTEXT_WINDOW=8192
## Embeddings
#LETTA_EMBEDDING_ENDPOINT=http://host.docker.internal:11434
#LETTA_EMBEDDING_ENDPOINT_TYPE=ollama
#LETTA_EMBEDDING_MODEL=mxbai-embed-large
#LETTA_EMBEDDING_DIM=512
########################################################## ##########################################################
vLLM configuration vLLM configuration
########################################################## ##########################################################
## LLM Model # VLLM_API_BASE="http://host.docker.internal:8000"
#LETTA_LLM_ENDPOINT=http://host.docker.internal:8000
#LETTA_LLM_ENDPOINT_TYPE=vllm
#LETTA_LLM_MODEL=ehartford/dolphin-2.2.1-mistral-7b
#LETTA_LLM_CONTEXT_WINDOW=8192

View File

@@ -1,42 +0,0 @@
name: "Letta Web OpenAPI Compatibility Checker"
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
validate-openapi:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: "Setup Python, Poetry and Dependencies"
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev"
- name: Checkout letta web
uses: actions/checkout@v4
with:
repository: letta-ai/letta-web
token: ${{ secrets.PULLER_TOKEN }}
path: letta-web
- name: Run OpenAPI schema generation
run: |
bash ./letta/server/generate_openapi_schema.sh
- name: Setup letta-web
working-directory: letta-web
run: npm ci
- name: Copy OpenAPI schema
working-directory: .
run: cp openapi_letta.json letta-web/libs/letta-agents-api/letta-agents-openapi.json
- name: Validate OpenAPI schema
working-directory: letta-web
run: |
npm run agents-api:generate
npm run type-check

View File

@@ -6,6 +6,8 @@ env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
E2B_SANDBOX_TEMPLATE_ID: ${{ secrets.E2B_SANDBOX_TEMPLATE_ID }}
on: on:
push: push:
@@ -61,7 +63,7 @@ jobs:
with: with:
python-version: "3.12" python-version: "3.12"
poetry-version: "1.8.2" poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E external-tools -E tests" install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
- name: Migrate database - name: Migrate database
env: env:
LETTA_PG_PORT: 5432 LETTA_PG_PORT: 5432

View File

@@ -5,40 +5,45 @@ Revises: 3c683a662c82
Create Date: 2024-12-05 16:46:51.258831 Create Date: 2024-12-05 16:46:51.258831
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '08b2f8225812' revision: str = "08b2f8225812"
down_revision: Union[str, None] = '3c683a662c82' down_revision: Union[str, None] = "3c683a662c82"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('tools_agents', op.create_table(
sa.Column('agent_id', sa.String(), nullable=False), "tools_agents",
sa.Column('tool_id', sa.String(), nullable=False), sa.Column("agent_id", sa.String(), nullable=False),
sa.Column('tool_name', sa.String(), nullable=False), sa.Column("tool_id", sa.String(), nullable=False),
sa.Column('id', sa.String(), nullable=False), sa.Column("tool_name", sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column('_created_by_id', sa.String(), nullable=True), sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column('_last_updated_by_id', sa.String(), nullable=True), sa.Column("_created_by_id", sa.String(), nullable=True),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ), sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'), sa.ForeignKeyConstraint(
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'), ["agent_id"],
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent') ["agents.id"],
),
sa.ForeignKeyConstraint(["tool_id"], ["tools.id"], name="fk_tool_id"),
sa.PrimaryKeyConstraint("agent_id", "tool_id", "tool_name", "id"),
sa.UniqueConstraint("agent_id", "tool_name", name="unique_tool_per_agent"),
) )
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tools_agents') op.drop_table("tools_agents")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -5,18 +5,19 @@ Revises: 4e88e702f85e
Create Date: 2024-12-14 17:23:08.772554 Create Date: 2024-12-14 17:23:08.772554
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
from pgvector.sqlalchemy import Vector
import sqlalchemy as sa import sqlalchemy as sa
from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from alembic import op
from letta.orm.custom_columns import EmbeddingConfigColumn from letta.orm.custom_columns import EmbeddingConfigColumn
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '54dec07619c4' revision: str = "54dec07619c4"
down_revision: Union[str, None] = '4e88e702f85e' down_revision: Union[str, None] = "4e88e702f85e"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -24,82 +25,88 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table( op.create_table(
'agent_passages', "agent_passages",
sa.Column('id', sa.String(), nullable=False), sa.Column("id", sa.String(), nullable=False),
sa.Column('text', sa.String(), nullable=False), sa.Column("text", sa.String(), nullable=False),
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False), sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
sa.Column('metadata_', sa.JSON(), nullable=False), sa.Column("metadata_", sa.JSON(), nullable=False),
sa.Column('embedding', Vector(dim=4096), nullable=True), sa.Column("embedding", Vector(dim=4096), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column('_created_by_id', sa.String(), nullable=True), sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column('_last_updated_by_id', sa.String(), nullable=True), sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column('organization_id', sa.String(), nullable=False), sa.Column("organization_id", sa.String(), nullable=False),
sa.Column('agent_id', sa.String(), nullable=False), sa.Column("agent_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), sa.ForeignKeyConstraint(
sa.PrimaryKeyConstraint('id') ["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
) )
op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False) op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"], unique=False)
op.create_table( op.create_table(
'source_passages', "source_passages",
sa.Column('id', sa.String(), nullable=False), sa.Column("id", sa.String(), nullable=False),
sa.Column('text', sa.String(), nullable=False), sa.Column("text", sa.String(), nullable=False),
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False), sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
sa.Column('metadata_', sa.JSON(), nullable=False), sa.Column("metadata_", sa.JSON(), nullable=False),
sa.Column('embedding', Vector(dim=4096), nullable=True), sa.Column("embedding", Vector(dim=4096), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column('_created_by_id', sa.String(), nullable=True), sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column('_last_updated_by_id', sa.String(), nullable=True), sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column('organization_id', sa.String(), nullable=False), sa.Column("organization_id", sa.String(), nullable=False),
sa.Column('file_id', sa.String(), nullable=True), sa.Column("file_id", sa.String(), nullable=True),
sa.Column('source_id', sa.String(), nullable=False), sa.Column("source_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'), ["organization_id"],
sa.PrimaryKeyConstraint('id') ["organizations.id"],
),
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
) )
op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False) op.create_index("source_passages_org_idx", "source_passages", ["organization_id"], unique=False)
op.drop_table('passages') op.drop_table("passages")
op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey') op.drop_constraint("files_source_id_fkey", "files", type_="foreignkey")
op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE') op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"], ondelete="CASCADE")
op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey') op.drop_constraint("messages_agent_id_fkey", "messages", type_="foreignkey")
op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE') op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'messages', type_='foreignkey') op.drop_constraint(None, "messages", type_="foreignkey")
op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id']) op.create_foreign_key("messages_agent_id_fkey", "messages", "agents", ["agent_id"], ["id"])
op.drop_constraint(None, 'files', type_='foreignkey') op.drop_constraint(None, "files", type_="foreignkey")
op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id']) op.create_foreign_key("files_source_id_fkey", "files", "sources", ["source_id"], ["id"])
op.create_table( op.create_table(
'passages', "passages",
sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False), sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), sa.Column("text", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True), sa.Column("file_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True), sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True), sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True), sa.Column("embedding", Vector(dim=4096), autoincrement=False, nullable=True),
sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False), sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True), sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True), sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False), sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'), sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], name="passages_agent_id_fkey"),
sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'), sa.ForeignKeyConstraint(["file_id"], ["files.id"], name="passages_file_id_fkey", ondelete="CASCADE"),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'), sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], name="passages_organization_id_fkey"),
sa.PrimaryKeyConstraint('id', name='passages_pkey') sa.PrimaryKeyConstraint("id", name="passages_pkey"),
) )
op.drop_index('source_passages_org_idx', table_name='source_passages') op.drop_index("source_passages_org_idx", table_name="source_passages")
op.drop_table('source_passages') op.drop_table("source_passages")
op.drop_index('agent_passages_org_idx', table_name='agent_passages') op.drop_index("agent_passages_org_idx", table_name="agent_passages")
op.drop_table('agent_passages') op.drop_table("agent_passages")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -5,25 +5,27 @@ Revises: a91994b9752f
Create Date: 2024-12-10 15:05:32.335519 Create Date: 2024-12-10 15:05:32.335519
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'c5d964280dff' revision: str = "c5d964280dff"
down_revision: Union[str, None] = 'a91994b9752f' down_revision: Union[str, None] = "a91994b9752f"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True)) op.add_column("passages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False)) op.add_column("passages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True)) op.add_column("passages", sa.Column("_created_by_id", sa.String(), nullable=True))
op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True)) op.add_column("passages", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
# Data migration step: # Data migration step:
op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True)) op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
@@ -41,48 +43,32 @@ def upgrade() -> None:
# Set `organization_id` as non-nullable after population # Set `organization_id` as non-nullable after population
op.alter_column("passages", "organization_id", nullable=False) op.alter_column("passages", "organization_id", nullable=False)
op.alter_column('passages', 'text', op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=False)
existing_type=sa.VARCHAR(), op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
nullable=False) op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.alter_column('passages', 'embedding_config', op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
existing_type=postgresql.JSON(astext_type=sa.Text()), op.drop_index("passage_idx_user", table_name="passages")
nullable=False) op.create_foreign_key(None, "passages", "organizations", ["organization_id"], ["id"])
op.alter_column('passages', 'metadata_', op.create_foreign_key(None, "passages", "agents", ["agent_id"], ["id"])
existing_type=postgresql.JSON(astext_type=sa.Text()), op.create_foreign_key(None, "passages", "files", ["file_id"], ["id"], ondelete="CASCADE")
nullable=False) op.drop_column("passages", "user_id")
op.alter_column('passages', 'created_at',
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=False)
op.drop_index('passage_idx_user', table_name='passages')
op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id'])
op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id'])
op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE')
op.drop_column('passages', 'user_id')
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False)) op.add_column("passages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
op.drop_constraint(None, 'passages', type_='foreignkey') op.drop_constraint(None, "passages", type_="foreignkey")
op.drop_constraint(None, 'passages', type_='foreignkey') op.drop_constraint(None, "passages", type_="foreignkey")
op.drop_constraint(None, 'passages', type_='foreignkey') op.drop_constraint(None, "passages", type_="foreignkey")
op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False) op.create_index("passage_idx_user", "passages", ["user_id", "agent_id", "file_id"], unique=False)
op.alter_column('passages', 'created_at', op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
existing_type=postgresql.TIMESTAMP(timezone=True), op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
nullable=True) op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
op.alter_column('passages', 'metadata_', op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=True)
existing_type=postgresql.JSON(astext_type=sa.Text()), op.drop_column("passages", "organization_id")
nullable=True) op.drop_column("passages", "_last_updated_by_id")
op.alter_column('passages', 'embedding_config', op.drop_column("passages", "_created_by_id")
existing_type=postgresql.JSON(astext_type=sa.Text()), op.drop_column("passages", "is_deleted")
nullable=True) op.drop_column("passages", "updated_at")
op.alter_column('passages', 'text',
existing_type=sa.VARCHAR(),
nullable=True)
op.drop_column('passages', 'organization_id')
op.drop_column('passages', '_last_updated_by_id')
op.drop_column('passages', '_created_by_id')
op.drop_column('passages', 'is_deleted')
op.drop_column('passages', 'updated_at')
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -4,10 +4,7 @@ import uuid
from letta import create_client from letta import create_client
from letta.schemas.letta_message import ToolCallMessage from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import ( from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent
assert_invoked_send_message_with_keyword,
setup_agent,
)
from tests.helpers.utils import cleanup from tests.helpers.utils import cleanup
from tests.test_model_letta_perfomance import llm_config_dir from tests.test_model_letta_perfomance import llm_config_dir

View File

@@ -12,13 +12,7 @@ from letta.schemas.file import FileMetadata
from letta.schemas.job import Job from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ( from letta.schemas.memory import ArchivalMemorySummary, BasicBlockMemory, ChatMemory, Memory, RecallMemorySummary
ArchivalMemorySummary,
BasicBlockMemory,
ChatMemory,
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization from letta.schemas.organization import Organization

View File

@@ -33,34 +33,21 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ( from letta.schemas.openai.chat_completion_request import Tool as ChatCompletionRequestTool
Tool as ChatCompletionRequestTool,
)
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
Message as ChatCompletionMessage,
)
from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import ( from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
check_supports_structured_output,
compile_memory_metadata_block,
)
from letta.services.message_manager import MessageManager from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import ( from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
get_heartbeat,
get_token_limit_warning,
package_function_response,
package_summarize_message,
package_user_message,
)
from letta.utils import ( from letta.utils import (
count_tokens, count_tokens,
get_friendly_error_msg, get_friendly_error_msg,

View File

@@ -10,12 +10,7 @@ import letta.utils as utils
from letta import create_client from letta import create_client
from letta.agent import Agent, save_agent from letta.agent import Agent, save_agent
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.constants import ( from letta.constants import CLI_WARNING_PREFIX, CORE_MEMORY_BLOCK_CHAR_LIMIT, LETTA_DIR, MIN_CONTEXT_WINDOW
CLI_WARNING_PREFIX,
CORE_MEMORY_BLOCK_CHAR_LIMIT,
LETTA_DIR,
MIN_CONTEXT_WINDOW,
)
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import OptionState from letta.schemas.enums import OptionState
@@ -23,9 +18,7 @@ from letta.schemas.memory import ChatMemory, Memory
from letta.server.server import logger as server_logger from letta.server.server import logger as server_logger
# from letta.interface import CLIInterface as interface # for printing to terminal # from letta.interface import CLIInterface as interface # for printing to terminal
from letta.streaming_interface import ( from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
StreamingRefreshCLIInterface as interface, # for printing to terminal
)
from letta.utils import open_folder_in_explorer, printd from letta.utils import open_folder_in_explorer, printd
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -5,14 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
import requests import requests
import letta.utils import letta.utils
from letta.constants import ( from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT
ADMIN_PREFIX,
BASE_MEMORY_TOOLS,
BASE_TOOLS,
DEFAULT_HUMAN,
DEFAULT_PERSONA,
FUNCTION_RETURN_CHAR_LIMIT,
)
from letta.data_sources.connectors import DataConnector from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code from letta.functions.functions import parse_source_code
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
@@ -27,13 +20,7 @@ from letta.schemas.job import Job
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ( from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary
ArchivalMemorySummary,
ChatMemory,
CreateArchivalMemory,
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completions import ToolCall from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.organization import Organization from letta.schemas.organization import Organization

View File

@@ -7,11 +7,7 @@ from httpx_sse import SSEError, connect_sse
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError from letta.errors import LLMError
from letta.schemas.enums import MessageStreamStatus from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import ( from letta.schemas.letta_message import ReasoningMessage, ToolCallMessage, ToolReturnMessage
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
)
from letta.schemas.letta_response import LettaStreamingResponse from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.usage import LettaUsageStatistics from letta.schemas.usage import LettaUsageStatistics

View File

@@ -5,10 +5,7 @@ from typing import Optional
from IPython.display import HTML, display from IPython.display import HTML, display
from sqlalchemy.testing.plugin.plugin_base import warnings from sqlalchemy.testing.plugin.plugin_base import warnings
from letta.local_llm.constants import ( from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
ASSISTANT_MESSAGE_CLI_SYMBOL,
INNER_THOUGHTS_CLI_SYMBOL,
)
def pprint(messages): def pprint(messages):

View File

@@ -2,11 +2,7 @@ from typing import Dict, Iterator, List, Tuple
import typer import typer
from letta.data_sources.connectors_helper import ( from letta.data_sources.connectors_helper import assert_all_files_exist_locally, extract_metadata_from_files, get_filenames_in_dir
assert_all_files_exist_locally,
extract_metadata_from_files,
get_filenames_in_dir,
)
from letta.embeddings import embedding_model from letta.embeddings import embedding_model
from letta.schemas.file import FileMetadata from letta.schemas.file import FileMetadata
from letta.schemas.passage import Passage from letta.schemas.passage import Passage
@@ -14,6 +10,7 @@ from letta.schemas.source import Source
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager from letta.services.source_manager import SourceManager
class DataConnector: class DataConnector:
""" """
Base class for data connectors that can be extended to generate files and passages from a custom data source. Base class for data connectors that can be extended to generate files and passages from a custom data source.

View File

@@ -4,11 +4,7 @@ from typing import Any, List, Optional
import numpy as np import numpy as np
import tiktoken import tiktoken
from letta.constants import ( from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP, MAX_EMBEDDING_DIM
EMBEDDING_TO_TOKENIZER_DEFAULT,
EMBEDDING_TO_TOKENIZER_MAP,
MAX_EMBEDDING_DIM,
)
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.utils import is_valid_url, printd from letta.utils import is_valid_url, printd

View File

@@ -52,12 +52,10 @@ class LettaConfigurationError(LettaError):
class LettaAgentNotFoundError(LettaError): class LettaAgentNotFoundError(LettaError):
"""Error raised when an agent is not found.""" """Error raised when an agent is not found."""
pass
class LettaUserNotFoundError(LettaError): class LettaUserNotFoundError(LettaError):
"""Error raised when a user is not found.""" """Error raised when a user is not found."""
pass
class LLMError(LettaError): class LLMError(LettaError):

View File

@@ -4,10 +4,7 @@ from typing import Optional
import requests import requests
from letta.constants import ( from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
)
from letta.llm_api.llm_api_tools import create from letta.llm_api.llm_api_tools import create
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.utils import json_dumps, json_loads from letta.utils import json_dumps, json_loads

View File

@@ -396,44 +396,6 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
return schema return schema
def generate_schema_from_args_schema_v1(
args_schema: Type[V1BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
) -> Dict[str, Any]:
properties = {}
required = []
for field_name, field in args_schema.__fields__.items():
if field.type_ == str:
field_type = "string"
elif field.type_ == int:
field_type = "integer"
elif field.type_ == bool:
field_type = "boolean"
else:
field_type = field.type_.__name__
properties[field_name] = {
"type": field_type,
"description": field.field_info.description,
}
if field.required:
required.append(field_name)
function_call_json = {
"name": name,
"description": description,
"parameters": {"type": "object", "properties": properties, "required": required},
}
if append_heartbeat:
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
function_call_json["parameters"]["required"].append("request_heartbeat")
return function_call_json
def generate_schema_from_args_schema_v2( def generate_schema_from_args_schema_v2(
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -441,19 +403,8 @@ def generate_schema_from_args_schema_v2(
required = [] required = []
for field_name, field in args_schema.model_fields.items(): for field_name, field in args_schema.model_fields.items():
field_type_annotation = field.annotation field_type_annotation = field.annotation
if field_type_annotation == str: properties[field_name] = type_to_json_schema_type(field_type_annotation)
field_type = "string" properties[field_name]["description"] = field.description
elif field_type_annotation == int:
field_type = "integer"
elif field_type_annotation == bool:
field_type = "boolean"
else:
field_type = field_type_annotation.__name__
properties[field_name] = {
"type": field_type,
"description": field.description,
}
if field.is_required(): if field.is_required():
required.append(field_name) required.append(field_name)

View File

@@ -4,13 +4,7 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from letta.schemas.enums import ToolRuleType from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import ( from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
class ToolRuleValidationError(Exception): class ToolRuleValidationError(Exception):
@@ -50,7 +44,6 @@ class ToolRulesSolver(BaseModel):
assert isinstance(rule, TerminalToolRule) assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule) self.terminal_tool_rules.append(rule)
def update_tool_usage(self, tool_name: str): def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called.""" """Update the internal state to track the last tool called."""
self.last_tool_name = tool_name self.last_tool_name = tool_name
@@ -88,7 +81,7 @@ class ToolRulesSolver(BaseModel):
return any(rule.tool_name == tool_name for rule in self.tool_rules) return any(rule.tool_name == tool_name for rule in self.tool_rules)
def validate_conditional_tool(self, rule: ConditionalToolRule): def validate_conditional_tool(self, rule: ConditionalToolRule):
''' """
Validate a conditional tool rule Validate a conditional tool rule
Args: Args:
@@ -96,13 +89,13 @@ class ToolRulesSolver(BaseModel):
Raises: Raises:
ToolRuleValidationError: If the rule is invalid ToolRuleValidationError: If the rule is invalid
''' """
if len(rule.child_output_mapping) == 0: if len(rule.child_output_mapping) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
return True return True
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
''' """
Parse function response to determine which child tool to use based on the mapping Parse function response to determine which child tool to use based on the mapping
Args: Args:
@@ -111,7 +104,7 @@ class ToolRulesSolver(BaseModel):
Returns: Returns:
str: The name of the child tool to use next str: The name of the child tool to use next
''' """
json_response = json.loads(last_function_response) json_response = json.loads(last_function_response)
function_output = json_response["message"] function_output = json_response["message"]

View File

@@ -5,10 +5,7 @@ from typing import List, Optional
from colorama import Fore, Style, init from colorama import Fore, Style, init
from letta.constants import CLI_WARNING_PREFIX from letta.constants import CLI_WARNING_PREFIX
from letta.local_llm.constants import ( from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
ASSISTANT_MESSAGE_CLI_SYMBOL,
INNER_THOUGHTS_CLI_SYMBOL,
)
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.utils import json_loads, printd from letta.utils import json_loads, printd

View File

@@ -5,11 +5,7 @@ from typing import List, Optional, Union
from letta.llm_api.helpers import make_post_request from letta.llm_api.helpers import make_post_request
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
ChatCompletionResponse,
Choice,
FunctionCall,
)
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
) )
@@ -102,13 +98,9 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
formatted_tools = [] formatted_tools = []
for tool in tools: for tool in tools:
formatted_tool = { formatted_tool = {
"name" : tool.function.name, "name": tool.function.name,
"description" : tool.function.description, "description": tool.function.description,
"input_schema" : tool.function.parameters or { "input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
"type": "object",
"properties": {},
"required": []
}
} }
formatted_tools.append(formatted_tool) formatted_tools.append(formatted_tool)
@@ -346,7 +338,7 @@ def anthropic_chat_completions_request(
data["tool_choice"] = { data["tool_choice"] = {
"type": "tool", # Changed from "function" to "tool" "type": "tool", # Changed from "function" to "tool"
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object "name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
"disable_parallel_tool_use": True # Force single tool use "disable_parallel_tool_use": True, # Force single tool use
} }
# Move 'system' to the top level # Move 'system' to the top level

View File

@@ -7,11 +7,7 @@ import requests
from letta.local_llm.utils import count_tokens from letta.local_llm.utils import count_tokens
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
ChatCompletionResponse,
Choice,
FunctionCall,
)
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
) )
@@ -276,10 +272,7 @@ def convert_tools_to_cohere_format(tools: List[Tool], inner_thoughts_in_kwargs:
if inner_thoughts_in_kwargs: if inner_thoughts_in_kwargs:
# NOTE: since Cohere doesn't allow "text" in the response when a tool call happens, if we want # NOTE: since Cohere doesn't allow "text" in the response when a tool call happens, if we want
# a simultaneous CoT + tool call we need to put it inside a kwarg # a simultaneous CoT + tool call we need to put it inside a kwarg
from letta.local_llm.constants import ( from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
for cohere_tool in tools_dict_list: for cohere_tool in tools_dict_list:
cohere_tool["parameter_definitions"][INNER_THOUGHTS_KWARG] = { cohere_tool["parameter_definitions"][INNER_THOUGHTS_KWARG] = {

View File

@@ -8,14 +8,7 @@ from letta.llm_api.helpers import make_post_request
from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens from letta.local_llm.utils import count_tokens
from letta.schemas.openai.chat_completion_request import Tool from letta.schemas.openai.chat_completion_request import Tool
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
ChatCompletionResponse,
Choice,
FunctionCall,
Message,
ToolCall,
UsageStatistics,
)
from letta.utils import get_tool_call_id, get_utc_time, json_dumps from letta.utils import get_tool_call_id, get_utc_time, json_dumps
@@ -230,10 +223,7 @@ def convert_tools_to_google_ai_format(tools: List[Tool], inner_thoughts_in_kwarg
param_fields["type"] = param_fields["type"].upper() param_fields["type"] = param_fields["type"].upper()
# Add inner thoughts # Add inner thoughts
if inner_thoughts_in_kwargs: if inner_thoughts_in_kwargs:
from letta.local_llm.constants import ( from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = { func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
"type": "STRING", "type": "STRING",

View File

@@ -8,38 +8,22 @@ from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LettaConfigurationError, RateLimitExceededError from letta.errors import LettaConfigurationError, RateLimitExceededError
from letta.llm_api.anthropic import anthropic_chat_completions_request from letta.llm_api.anthropic import anthropic_chat_completions_request
from letta.llm_api.azure_openai import azure_openai_chat_completions_request from letta.llm_api.azure_openai import azure_openai_chat_completions_request
from letta.llm_api.google_ai import ( from letta.llm_api.google_ai import convert_tools_to_google_ai_format, google_ai_chat_completions_request
convert_tools_to_google_ai_format, from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
google_ai_chat_completions_request,
)
from letta.llm_api.helpers import (
add_inner_thoughts_to_functions,
unpack_all_inner_thoughts_from_kwargs,
)
from letta.llm_api.openai import ( from letta.llm_api.openai import (
build_openai_chat_completions_request, build_openai_chat_completions_request,
openai_chat_completions_process_stream, openai_chat_completions_process_stream,
openai_chat_completions_request, openai_chat_completions_request,
) )
from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import ( from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ( from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
ChatCompletionRequest,
Tool,
cast_message_to_subtype,
)
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import ModelSettings from letta.settings import ModelSettings
from letta.streaming_interface import ( from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"] LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"]

View File

@@ -9,28 +9,15 @@ from httpx_sse._exceptions import SSEError
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError from letta.errors import LLMError
from letta.llm_api.helpers import ( from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
add_inner_thoughts_to_functions, from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
convert_to_structured_output,
make_post_request,
)
from letta.local_llm.constants import (
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as _Message from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import ( from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
FunctionCall as ToolFunctionChoiceFunctionCall, from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, cast_message_to_subtype
)
from letta.schemas.openai.chat_completion_request import (
Tool,
ToolFunctionChoice,
cast_message_to_subtype,
)
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import (
ChatCompletionChunkResponse, ChatCompletionChunkResponse,
ChatCompletionResponse, ChatCompletionResponse,
@@ -41,10 +28,7 @@ from letta.schemas.openai.chat_completion_response import (
UsageStatistics, UsageStatistics,
) )
from letta.schemas.openai.embedding_response import EmbeddingResponse from letta.schemas.openai.embedding_response import EmbeddingResponse
from letta.streaming_interface import ( from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
from letta.utils import get_tool_call_id, smart_urljoin from letta.utils import get_tool_call_id, smart_urljoin
OPENAI_SSE_DONE = "[DONE]" OPENAI_SSE_DONE = "[DONE]"

View File

@@ -8,10 +8,7 @@ from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LocalLLMConnectionError, LocalLLMError from letta.errors import LocalLLMConnectionError, LocalLLMError
from letta.local_llm.constants import DEFAULT_WRAPPER from letta.local_llm.constants import DEFAULT_WRAPPER
from letta.local_llm.function_parser import patch_function from letta.local_llm.function_parser import patch_function
from letta.local_llm.grammars.gbnf_grammar_generator import ( from letta.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
create_dynamic_model_from_function,
generate_gbnf_grammar_and_documentation,
)
from letta.local_llm.koboldcpp.api import get_koboldcpp_completion from letta.local_llm.koboldcpp.api import get_koboldcpp_completion
from letta.local_llm.llamacpp.api import get_llamacpp_completion from letta.local_llm.llamacpp.api import get_llamacpp_completion
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
@@ -20,17 +17,9 @@ from letta.local_llm.ollama.api import get_ollama_completion
from letta.local_llm.utils import count_tokens, get_available_wrappers from letta.local_llm.utils import count_tokens, get_available_wrappers
from letta.local_llm.vllm.api import get_vllm_completion from letta.local_llm.vllm.api import get_vllm_completion
from letta.local_llm.webui.api import get_webui_completion from letta.local_llm.webui.api import get_webui_completion
from letta.local_llm.webui.legacy_api import ( from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
get_webui_completion as get_webui_completion_legacy,
)
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics
ChatCompletionResponse,
Choice,
Message,
ToolCall,
UsageStatistics,
)
from letta.utils import get_tool_call_id, get_utc_time, json_dumps from letta.utils import get_tool_call_id, get_utc_time, json_dumps
has_shown_warning = False has_shown_warning = False

View File

@@ -1,7 +1,5 @@
# import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros # import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
from letta.local_llm.llm_chat_completion_wrappers.chatml import ( from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
ChatMLInnerMonologueWrapper,
)
DEFAULT_ENDPOINTS = { DEFAULT_ENDPOINTS = {
# Local # Local

View File

@@ -5,18 +5,7 @@ from copy import copy
from enum import Enum from enum import Enum
from inspect import getdoc, isclass from inspect import getdoc, isclass
from types import NoneType from types import NoneType
from typing import ( from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin
Any,
Callable,
List,
Optional,
Tuple,
Type,
Union,
_GenericAlias,
get_args,
get_origin,
)
from docstring_parser import parse from docstring_parser import parse
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model

View File

@@ -1,8 +1,6 @@
from letta.errors import LLMJSONParsingError from letta.errors import LLMJSONParsingError
from letta.local_llm.json_parser import clean_json from letta.local_llm.json_parser import clean_json
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import ( from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
LLMChatCompletionWrapper,
)
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.utils import json_dumps, json_loads from letta.utils import json_dumps, json_loads
@@ -75,10 +73,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
func_str += f"\n description: {schema['description']}" func_str += f"\n description: {schema['description']}"
func_str += f"\n params:" func_str += f"\n params:"
if add_inner_thoughts: if add_inner_thoughts:
from letta.local_llm.constants import ( from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}" func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
for param_k, param_v in schema["parameters"]["properties"].items(): for param_k, param_v in schema["parameters"]["properties"].items():

View File

@@ -1,8 +1,6 @@
from letta.errors import LLMJSONParsingError from letta.errors import LLMJSONParsingError
from letta.local_llm.json_parser import clean_json from letta.local_llm.json_parser import clean_json
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import ( from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
LLMChatCompletionWrapper,
)
from letta.utils import json_dumps, json_loads from letta.utils import json_dumps, json_loads
PREFIX_HINT = """# Reminders: PREFIX_HINT = """# Reminders:
@@ -74,10 +72,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
func_str += f"\n description: {schema['description']}" func_str += f"\n description: {schema['description']}"
func_str += "\n params:" func_str += "\n params:"
if add_inner_thoughts: if add_inner_thoughts:
from letta.local_llm.constants import ( from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}" func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
for param_k, param_v in schema["parameters"]["properties"].items(): for param_k, param_v in schema["parameters"]["properties"].items():

View File

@@ -2,9 +2,7 @@ import json
import os import os
from letta.constants import LETTA_DIR from letta.constants import LETTA_DIR
from letta.local_llm.settings.deterministic_mirostat import ( from letta.local_llm.settings.deterministic_mirostat import settings as det_miro_settings
settings as det_miro_settings,
)
from letta.local_llm.settings.simple import settings as simple_settings from letta.local_llm.settings.simple import settings as simple_settings
DEFAULT = "simple" DEFAULT = "simple"

View File

@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
from letta.orm.job import Job from letta.orm.job import Job
from letta.orm.message import Message from letta.orm.message import Message
from letta.orm.organization import Organization from letta.orm.organization import Organization
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents from letta.orm.sources_agents import SourcesAgents

View File

@@ -5,11 +5,7 @@ from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block from letta.orm.block import Block
from letta.orm.custom_columns import ( from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
EmbeddingConfigColumn,
LLMConfigColumn,
ToolRulesColumn,
)
from letta.orm.message import Message from letta.orm.message import Message
from letta.orm.mixins import OrganizationMixin from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization from letta.orm.organization import Organization

View File

@@ -2,13 +2,7 @@ from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import Boolean, DateTime, String, func, text from sqlalchemy import Boolean, DateTime, String, func, text
from sqlalchemy.orm import ( from sqlalchemy.orm import DeclarativeBase, Mapped, declarative_mixin, declared_attr, mapped_column
DeclarativeBase,
Mapped,
declarative_mixin,
declared_attr,
mapped_column,
)
class Base(DeclarativeBase): class Base(DeclarativeBase):

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import Integer, String from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -9,8 +9,9 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
from letta.orm.organization import Organization from letta.orm.organization import Organization
from letta.orm.source import Source
from letta.orm.passage import SourcePassage from letta.orm.passage import SourcePassage
from letta.orm.source import Source
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
"""Represents metadata for an uploaded file.""" """Represents metadata for an uploaded file."""
@@ -28,4 +29,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
# relationships # relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin") source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan") source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
)

View File

@@ -31,6 +31,7 @@ class UserMixin(Base):
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id")) user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
class AgentMixin(Base): class AgentMixin(Base):
"""Mixin for models that belong to an agent.""" """Mixin for models that belong to an agent."""
@@ -38,6 +39,7 @@ class AgentMixin(Base):
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE")) agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
class FileMixin(Base): class FileMixin(Base):
"""Mixin for models that belong to a file.""" """Mixin for models that belong to a file."""

View File

@@ -38,19 +38,11 @@ class Organization(SqlalchemyBase):
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan") messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
source_passages: Mapped[List["SourcePassage"]] = relationship( source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage", "SourcePassage", back_populates="organization", cascade="all, delete-orphan"
back_populates="organization",
cascade="all, delete-orphan"
)
agent_passages: Mapped[List["AgentPassage"]] = relationship(
"AgentPassage",
back_populates="organization",
cascade="all, delete-orphan"
) )
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
@property @property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
"""Convenience property to get all passages""" """Convenience property to get all passages"""
return self.source_passages + self.agent_passages return self.source_passages + self.agent_passages

View File

@@ -8,9 +8,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import ( from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable,
)
from letta.schemas.sandbox_config import SandboxType from letta.schemas.sandbox_config import SandboxType
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -9,12 +9,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import ( from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
DatabaseTimeoutError,
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
)
from letta.orm.sqlite_functions import adapt_array from letta.orm.sqlite_functions import adapt_array
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -1,13 +1,14 @@
import base64
import sqlite3
from typing import Optional, Union from typing import Optional, Union
import base64
import numpy as np import numpy as np
from sqlalchemy import event from sqlalchemy import event
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
import sqlite3
from letta.constants import MAX_EMBEDDING_DIM from letta.constants import MAX_EMBEDDING_DIM
def adapt_array(arr): def adapt_array(arr):
""" """
Converts numpy array to binary for SQLite storage Converts numpy array to binary for SQLite storage
@@ -19,12 +20,13 @@ def adapt_array(arr):
arr = np.array(arr, dtype=np.float32) arr = np.array(arr, dtype=np.float32)
elif not isinstance(arr, np.ndarray): elif not isinstance(arr, np.ndarray):
raise ValueError(f"Unsupported type: {type(arr)}") raise ValueError(f"Unsupported type: {type(arr)}")
# Convert to bytes and then base64 encode # Convert to bytes and then base64 encode
bytes_data = arr.tobytes() bytes_data = arr.tobytes()
base64_data = base64.b64encode(bytes_data) base64_data = base64.b64encode(bytes_data)
return sqlite3.Binary(base64_data) return sqlite3.Binary(base64_data)
def convert_array(text): def convert_array(text):
""" """
Converts binary back to numpy array Converts binary back to numpy array
@@ -38,23 +40,24 @@ def convert_array(text):
# Handle both bytes and sqlite3.Binary # Handle both bytes and sqlite3.Binary
binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text
try: try:
# First decode base64 # First decode base64
decoded_data = base64.b64decode(binary_data) decoded_data = base64.b64decode(binary_data)
# Then convert to numpy array # Then convert to numpy array
return np.frombuffer(decoded_data, dtype=np.float32) return np.frombuffer(decoded_data, dtype=np.float32)
except Exception as e: except Exception:
return None return None
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool: def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
""" """
Verifies that an embedding has the expected dimension Verifies that an embedding has the expected dimension
Args: Args:
embedding: Input embedding array embedding: Input embedding array
expected_dim: Expected embedding dimension (default: 4096) expected_dim: Expected embedding dimension (default: 4096)
Returns: Returns:
bool: True if dimension matches, False otherwise bool: True if dimension matches, False otherwise
""" """
@@ -62,28 +65,27 @@ def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EM
return False return False
return embedding.shape[0] == expected_dim return embedding.shape[0] == expected_dim
def validate_and_transform_embedding( def validate_and_transform_embedding(
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], expected_dim: int = MAX_EMBEDDING_DIM, dtype: np.dtype = np.float32
expected_dim: int = MAX_EMBEDDING_DIM,
dtype: np.dtype = np.float32
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
""" """
Validates and transforms embeddings to ensure correct dimensionality. Validates and transforms embeddings to ensure correct dimensionality.
Args: Args:
embedding: Input embedding in various possible formats embedding: Input embedding in various possible formats
expected_dim: Expected embedding dimension (default 4096) expected_dim: Expected embedding dimension (default 4096)
dtype: NumPy dtype for the embedding (default float32) dtype: NumPy dtype for the embedding (default float32)
Returns: Returns:
np.ndarray: Validated and transformed embedding np.ndarray: Validated and transformed embedding
Raises: Raises:
ValueError: If embedding dimension doesn't match expected dimension ValueError: If embedding dimension doesn't match expected dimension
""" """
if embedding is None: if embedding is None:
return None return None
# Convert to numpy array based on input type # Convert to numpy array based on input type
if isinstance(embedding, (bytes, sqlite3.Binary)): if isinstance(embedding, (bytes, sqlite3.Binary)):
vec = convert_array(embedding) vec = convert_array(embedding)
@@ -93,48 +95,49 @@ def validate_and_transform_embedding(
vec = embedding.astype(dtype) vec = embedding.astype(dtype)
else: else:
raise ValueError(f"Unsupported embedding type: {type(embedding)}") raise ValueError(f"Unsupported embedding type: {type(embedding)}")
# Validate dimension # Validate dimension
if vec.shape[0] != expected_dim: if vec.shape[0] != expected_dim:
raise ValueError( raise ValueError(f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}")
f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
)
return vec return vec
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM): def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
""" """
Calculate cosine distance between two embeddings Calculate cosine distance between two embeddings
Args: Args:
embedding1: First embedding embedding1: First embedding
embedding2: Second embedding embedding2: Second embedding
expected_dim: Expected embedding dimension (default 4096) expected_dim: Expected embedding dimension (default 4096)
Returns: Returns:
float: Cosine distance float: Cosine distance
""" """
if embedding1 is None or embedding2 is None: if embedding1 is None or embedding2 is None:
return 0.0 # Maximum distance if either embedding is None return 0.0 # Maximum distance if either embedding is None
try: try:
vec1 = validate_and_transform_embedding(embedding1, expected_dim) vec1 = validate_and_transform_embedding(embedding1, expected_dim)
vec2 = validate_and_transform_embedding(embedding2, expected_dim) vec2 = validate_and_transform_embedding(embedding2, expected_dim)
except ValueError as e: except ValueError:
return 0.0 return 0.0
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
distance = float(1.0 - similarity) distance = float(1.0 - similarity)
return distance return distance
@event.listens_for(Engine, "connect") @event.listens_for(Engine, "connect")
def register_functions(dbapi_connection, connection_record): def register_functions(dbapi_connection, connection_record):
"""Register SQLite functions""" """Register SQLite functions"""
if isinstance(dbapi_connection, sqlite3.Connection): if isinstance(dbapi_connection, sqlite3.Connection):
dbapi_connection.create_function("cosine_distance", 2, cosine_distance) dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
# Register adapters and converters for numpy arrays # Register adapters and converters for numpy arrays
sqlite3.register_adapter(np.ndarray, adapt_array) sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("ARRAY", convert_array) sqlite3.register_converter("ARRAY", convert_array)

View File

@@ -3,10 +3,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import ( from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
get_azure_chat_completions_endpoint,
get_azure_embeddings_endpoint,
)
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
@@ -27,12 +24,11 @@ class Provider(BaseModel):
def provider_tag(self) -> str: def provider_tag(self) -> str:
"""String representation of the provider for display purposes""" """String representation of the provider for display purposes"""
raise NotImplementedError raise NotImplementedError
def get_handle(self, model_name: str) -> str: def get_handle(self, model_name: str) -> str:
return f"{self.name}/{model_name}" return f"{self.name}/{model_name}"
class LettaProvider(Provider): class LettaProvider(Provider):
name: str = "letta" name: str = "letta"
@@ -44,7 +40,7 @@ class LettaProvider(Provider):
model_endpoint_type="openai", model_endpoint_type="openai",
model_endpoint="https://inference.memgpt.ai", model_endpoint="https://inference.memgpt.ai",
context_window=16384, context_window=16384,
handle=self.get_handle("letta-free") handle=self.get_handle("letta-free"),
) )
] ]
@@ -56,7 +52,7 @@ class LettaProvider(Provider):
embedding_endpoint="https://embeddings.memgpt.ai", embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024, embedding_dim=1024,
embedding_chunk_size=300, embedding_chunk_size=300,
handle=self.get_handle("letta-free") handle=self.get_handle("letta-free"),
) )
] ]
@@ -121,7 +117,13 @@ class OpenAIProvider(Provider):
# continue # continue
configs.append( configs.append(
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name)) LLMConfig(
model=model_name,
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
)
) )
# for OpenAI, sort in reverse order # for OpenAI, sort in reverse order
@@ -141,7 +143,7 @@ class OpenAIProvider(Provider):
embedding_endpoint="https://api.openai.com/v1", embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536, embedding_dim=1536,
embedding_chunk_size=300, embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002") handle=self.get_handle("text-embedding-ada-002"),
) )
] ]
@@ -170,7 +172,7 @@ class AnthropicProvider(Provider):
model_endpoint_type="anthropic", model_endpoint_type="anthropic",
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=model["context_window"], context_window=model["context_window"],
handle=self.get_handle(model["name"]) handle=self.get_handle(model["name"]),
) )
) )
return configs return configs
@@ -203,7 +205,7 @@ class MistralProvider(Provider):
model_endpoint_type="openai", model_endpoint_type="openai",
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=model["max_context_length"], context_window=model["max_context_length"],
handle=self.get_handle(model["id"]) handle=self.get_handle(model["id"]),
) )
) )
@@ -259,7 +261,7 @@ class OllamaProvider(OpenAIProvider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter, model_wrapper=self.default_prompt_formatter,
context_window=context_window, context_window=context_window,
handle=self.get_handle(model["name"]) handle=self.get_handle(model["name"]),
) )
) )
return configs return configs
@@ -335,7 +337,7 @@ class OllamaProvider(OpenAIProvider):
embedding_endpoint=self.base_url, embedding_endpoint=self.base_url,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
embedding_chunk_size=300, embedding_chunk_size=300,
handle=self.get_handle(model["name"]) handle=self.get_handle(model["name"]),
) )
) )
return configs return configs
@@ -356,7 +358,11 @@ class GroqProvider(OpenAIProvider):
continue continue
configs.append( configs.append(
LLMConfig( LLMConfig(
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"]) model=model["id"],
model_endpoint_type="groq",
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["id"]),
) )
) )
return configs return configs
@@ -424,7 +430,7 @@ class TogetherProvider(OpenAIProvider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter, model_wrapper=self.default_prompt_formatter,
context_window=context_window_size, context_window=context_window_size,
handle=self.get_handle(model_name) handle=self.get_handle(model_name),
) )
) )
@@ -505,7 +511,7 @@ class GoogleAIProvider(Provider):
model_endpoint_type="google_ai", model_endpoint_type="google_ai",
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=self.get_model_context_window(model), context_window=self.get_model_context_window(model),
handle=self.get_handle(model) handle=self.get_handle(model),
) )
) )
return configs return configs
@@ -529,7 +535,7 @@ class GoogleAIProvider(Provider):
embedding_endpoint=self.base_url, embedding_endpoint=self.base_url,
embedding_dim=768, embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048 embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model) handle=self.get_handle(model),
) )
) )
return configs return configs
@@ -559,9 +565,7 @@ class AzureProvider(Provider):
return values return values
def list_llm_models(self) -> List[LLMConfig]: def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.azure_openai import ( from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list
azure_openai_get_chat_completion_model_list,
)
model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version) model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
configs = [] configs = []
@@ -570,7 +574,8 @@ class AzureProvider(Provider):
context_window_size = self.get_model_context_window(model_name) context_window_size = self.get_model_context_window(model_name)
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version) model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
configs.append( configs.append(
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name) LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size),
handle=self.get_handle(model_name),
) )
return configs return configs
@@ -591,7 +596,7 @@ class AzureProvider(Provider):
embedding_endpoint=model_endpoint, embedding_endpoint=model_endpoint,
embedding_dim=768, embedding_dim=768,
embedding_chunk_size=300, # NOTE: max is 2048 embedding_chunk_size=300, # NOTE: max is 2048
handle=self.get_handle(model_name) handle=self.get_handle(model_name),
) )
) )
return configs return configs
@@ -625,7 +630,7 @@ class VLLMChatCompletionsProvider(Provider):
model_endpoint_type="openai", model_endpoint_type="openai",
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=model["max_model_len"], context_window=model["max_model_len"],
handle=self.get_handle(model["id"]) handle=self.get_handle(model["id"]),
) )
) )
return configs return configs
@@ -658,7 +663,7 @@ class VLLMCompletionsProvider(Provider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter, model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"], context_window=model["max_model_len"],
handle=self.get_handle(model["id"]) handle=self.get_handle(model["id"]),
) )
) )
return configs return configs

View File

@@ -119,6 +119,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.") context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.") embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
from_template: Optional[str] = Field(None, description="The template id used to configure the agent") from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
project_id: Optional[str] = Field(None, description="The project id that the agent will be associated with.")
@field_validator("name") @field_validator("name")
@classmethod @classmethod

View File

@@ -23,8 +23,26 @@ class LettaResponse(BaseModel):
usage (LettaUsageStatistics): The usage statistics usage (LettaUsageStatistics): The usage statistics
""" """
messages: List[LettaMessageUnion] = Field(..., description="The messages returned by the agent.") messages: List[LettaMessageUnion] = Field(
usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.") ...,
description="The messages returned by the agent.",
json_schema_extra={
"items": {
"oneOf": [
{"x-ref-name": "SystemMessage"},
{"x-ref-name": "UserMessage"},
{"x-ref-name": "ReasoningMessage"},
{"x-ref-name": "ToolCallMessage"},
{"x-ref-name": "ToolReturnMessage"},
{"x-ref-name": "AssistantMessage"},
],
"discriminator": {"propertyName": "message_type"},
}
},
)
usage: LettaUsageStatistics = Field(
..., description="The usage statistics of the agent.", json_schema_extra={"x-ref-name": "LettaUsageStatistics"}
)
def __str__(self): def __str__(self):
return json_dumps( return json_dumps(

View File

@@ -6,24 +6,13 @@ from typing import List, Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from letta.constants import ( from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
TOOL_CALL_ID_MAX_LEN,
)
from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import ( from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, SystemMessage
AssistantMessage, from letta.schemas.letta_message import ToolCall as LettaToolCall
ToolCall as LettaToolCall, from letta.schemas.letta_message import ToolCallMessage, ToolReturnMessage, UserMessage
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
LettaMessage,
SystemMessage,
UserMessage,
)
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.utils import get_utc_time, is_utc_datetime, json_dumps from letta.utils import get_utc_time, is_utc_datetime, json_dumps

View File

@@ -13,7 +13,7 @@ class OrganizationBase(LettaBase):
class Organization(OrganizationBase): class Organization(OrganizationBase):
id: str = OrganizationBase.generate_id_field() id: str = OrganizationBase.generate_id_field()
name: str = Field(create_random_username(), description="The name of the organization.") name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.") created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")

View File

@@ -4,10 +4,7 @@ from pydantic import Field, model_validator
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
from letta.functions.functions import derive_openai_json_schema from letta.functions.functions import derive_openai_json_schema
from letta.functions.helpers import ( from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
generate_composio_tool_wrapper,
generate_langchain_tool_wrapper,
)
from letta.functions.schema_generator import generate_schema_from_args_schema_v2 from letta.functions.schema_generator import generate_schema_from_args_schema_v2
from letta.schemas.letta_base import LettaBase from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall from letta.schemas.openai.chat_completions import ToolCall

View File

@@ -25,6 +25,7 @@ class ConditionalToolRule(BaseToolRule):
""" """
A ToolRule that conditionally maps to different child tools based on the output. A ToolRule that conditionally maps to different child tools based on the output.
""" """
type: ToolRuleType = ToolRuleType.conditional type: ToolRuleType = ToolRuleType.conditional
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.") default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping") child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")

View File

@@ -1,4 +1,5 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -12,6 +13,7 @@ class LettaUsageStatistics(BaseModel):
total_tokens (int): The total number of tokens processed by the agent. total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent. step_count (int): The number of steps taken by the agent.
""" """
message_type: Literal["usage_statistics"] = "usage_statistics" message_type: Literal["usage_statistics"] = "usage_statistics"
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.") completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.") prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")

View File

@@ -15,35 +15,19 @@ from letta.__init__ import __version__
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
from letta.log import get_logger from letta.log import get_logger
from letta.orm.errors import ( from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
DatabaseTimeoutError,
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
)
from letta.schemas.letta_response import LettaResponse
from letta.server.constants import REST_DEFAULT_PORT from letta.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import ( from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
setup_auth_router, # TODO: probably remove right?
)
from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.routers.openai.assistants.assistants import ( from letta.server.rest_api.routers.openai.assistants.assistants import router as openai_assistants_router
router as openai_assistants_router, from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
)
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import (
router as openai_chat_completions_router,
)
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM # from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
from letta.server.rest_api.routers.v1.organizations import ( from letta.server.rest_api.routers.v1.organizations import router as organizations_router
router as organizations_router, from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
)
from letta.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from letta.server.rest_api.static_files import mount_static_files from letta.server.rest_api.static_files import mount_static_files
from letta.server.server import SyncServer from letta.server.server import SyncServer
from letta.settings import settings from letta.settings import settings
@@ -83,9 +67,6 @@ def generate_openapi_schema(app: FastAPI):
openai_docs["info"]["title"] = "OpenAI Assistants API" openai_docs["info"]["title"] = "OpenAI Assistants API"
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")} letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
letta_docs["info"]["title"] = "Letta API" letta_docs["info"]["title"] = "Letta API"
letta_docs["components"]["schemas"]["LettaResponse"] = {
"properties": LettaResponse.model_json_schema(ref_template="#/components/schemas/LettaResponse/properties/{model}")["$defs"]
}
# Split the API docs into Letta API, and OpenAI Assistants compatible API # Split the API docs into Letta API, and OpenAI Assistants compatible API
for name, docs in [ for name, docs in [

View File

@@ -12,22 +12,19 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageStreamStatus from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import ( from letta.schemas.letta_message import (
AssistantMessage, AssistantMessage,
LegacyFunctionCallMessage,
LegacyLettaMessage,
LettaMessage,
ReasoningMessage,
ToolCall, ToolCall,
ToolCallDelta, ToolCallDelta,
ToolCallMessage, ToolCallMessage,
ToolReturnMessage, ToolReturnMessage,
ReasoningMessage,
LegacyFunctionCallMessage,
LegacyLettaMessage,
LettaMessage,
) )
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.streaming_interface import AgentChunkStreamingInterface from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import ( from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
FunctionArgumentsStreamHandler,
JSONInnerThoughtsExtractor,
)
from letta.utils import is_utc_datetime from letta.utils import is_utc_datetime

View File

@@ -2,13 +2,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from letta.schemas.openai.openai import ( from letta.schemas.openai.openai import MessageRoleType, OpenAIMessage, OpenAIThread, ToolCall, ToolCallOutput
MessageRoleType,
OpenAIMessage,
OpenAIThread,
ToolCall,
ToolCallOutput,
)
class CreateAssistantRequest(BaseModel): class CreateAssistantRequest(BaseModel):

View File

@@ -4,14 +4,9 @@ from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ToolCall, LettaMessage from letta.schemas.letta_message import LettaMessage, ToolCall
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, UsageStatistics
ChatCompletionResponse,
Choice,
Message,
UsageStatistics,
)
# TODO this belongs in a controller! # TODO this belongs in a controller!
from letta.server.rest_api.routers.v1.agents import send_message_to_agent from letta.server.rest_api.routers.v1.agents import send_message_to_agent

View File

@@ -3,9 +3,7 @@ from letta.server.rest_api.routers.v1.blocks import router as blocks_router
from letta.server.rest_api.routers.v1.health import router as health_router from letta.server.rest_api.routers.v1.health import router as health_router
from letta.server.rest_api.routers.v1.jobs import router as jobs_router from letta.server.rest_api.routers.v1.jobs import router as jobs_router
from letta.server.rest_api.routers.v1.llms import router as llm_router from letta.server.rest_api.routers.v1.llms import router as llm_router
from letta.server.rest_api.routers.v1.sandbox_configs import ( from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
router as sandbox_configs_router,
)
from letta.server.rest_api.routers.v1.sources import router as sources_router from letta.server.rest_api.routers.v1.sources import router as sources_router
from letta.server.rest_api.routers.v1.tools import router as tools_router from letta.server.rest_api.routers.v1.tools import router as tools_router

View File

@@ -3,16 +3,7 @@ import warnings
from datetime import datetime from datetime import datetime
from typing import List, Optional, Union from typing import List, Optional, Union
from fastapi import ( from fastapi import APIRouter, BackgroundTasks, Body, Depends, Header, HTTPException, Query, status
APIRouter,
BackgroundTasks,
Body,
Depends,
Header,
HTTPException,
Query,
status,
)
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import Field from pydantic import Field
@@ -20,27 +11,13 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger from letta.log import get_logger
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate
Block,
BlockUpdate,
CreateBlock,
)
from letta.schemas.enums import MessageStreamStatus from letta.schemas.enums import MessageStreamStatus
from letta.schemas.job import Job, JobStatus, JobUpdate from letta.schemas.job import Job, JobStatus, JobUpdate
from letta.schemas.letta_message import ( from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, LettaMessageUnion
LegacyLettaMessage,
LettaMessage,
LettaMessageUnion,
)
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ( from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, CreateArchivalMemory, Memory, RecallMemorySummary
ArchivalMemorySummary,
ContextWindowOverview,
CreateArchivalMemory,
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.passage import Passage from letta.schemas.passage import Passage
from letta.schemas.source import Source from letta.schemas.source import Source
@@ -193,7 +170,7 @@ def get_agent_state(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent") @router.delete("/{agent_id}", response_model=None, operation_id="delete_agent")
def delete_agent( def delete_agent(
agent_id: str, agent_id: str,
server: "SyncServer" = Depends(get_letta_server), server: "SyncServer" = Depends(get_letta_server),
@@ -204,7 +181,8 @@ def delete_agent(
""" """
actor = server.user_manager.get_user_or_default(user_id=user_id) actor = server.user_manager.get_user_or_default(user_id=user_id)
try: try:
return server.agent_manager.delete_agent(agent_id=agent_id, actor=actor) server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"})
except NoResultFound: except NoResultFound:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.") raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
@@ -343,7 +321,12 @@ def update_agent_memory_block(
actor = server.user_manager.get_user_or_default(user_id=user_id) actor = server.user_manager.get_user_or_default(user_id=user_id)
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
return server.block_manager.update_block(block.id, block_update=block_update, actor=actor) block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
# This should also trigger a system prompt change in the agent
server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False)
return block
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary") @router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")

View File

@@ -5,11 +5,7 @@ from fastapi import APIRouter, Depends, Query
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
from letta.schemas.sandbox_config import ( from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
SandboxEnvironmentVariableCreate,
SandboxEnvironmentVariableUpdate,
SandboxType,
)
from letta.server.rest_api.utils import get_letta_server, get_user_id from letta.server.rest_api.utils import get_letta_server, get_user_id
from letta.server.server import SyncServer from letta.server.server import SyncServer

View File

@@ -2,15 +2,7 @@ import os
import tempfile import tempfile
from typing import List, Optional from typing import List, Optional
from fastapi import ( from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, UploadFile
APIRouter,
BackgroundTasks,
Depends,
Header,
HTTPException,
Query,
UploadFile,
)
from letta.schemas.file import FileMetadata from letta.schemas.file import FileMetadata
from letta.schemas.job import Job from letta.schemas.job import Job

View File

@@ -102,6 +102,7 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
def get_current_interface() -> StreamingServerInterface: def get_current_interface() -> StreamingServerInterface:
return StreamingServerInterface return StreamingServerInterface
def log_error_to_sentry(e): def log_error_to_sentry(e):
import traceback import traceback

View File

@@ -49,15 +49,11 @@ from letta.schemas.enums import JobStatus
from letta.schemas.job import Job, JobUpdate from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ( from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
ArchivalMemorySummary,
ContextWindowOverview,
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
from letta.schemas.organization import Organization from letta.schemas.organization import Organization
from letta.schemas.passage import Passage from letta.schemas.passage import Passage
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
from letta.schemas.source import Source from letta.schemas.source import Source
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics from letta.schemas.usage import LettaUsageStatistics
@@ -303,6 +299,17 @@ class SyncServer(Server):
self.block_manager.add_default_blocks(actor=self.default_user) self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user) self.tool_manager.upsert_base_tools(actor=self.default_user)
# Add composio keys to the tool sandbox env vars of the org
if tool_settings.composio_api_key:
manager = SandboxConfigManager(tool_settings)
sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.default_user)
manager.create_sandbox_env_var(
SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=tool_settings.composio_api_key),
sandbox_config_id=sandbox_config.id,
actor=self.default_user,
)
# collect providers (always has Letta as a default) # collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()] self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key: if model_settings.openai_api_key:

View File

@@ -279,7 +279,7 @@ class AgentManager:
return agent.to_pydantic() return agent.to_pydantic()
@enforce_types @enforce_types
def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
""" """
Deletes an agent and its associated relationships. Deletes an agent and its associated relationships.
Ensures proper permission checks and cascades where applicable. Ensures proper permission checks and cascades where applicable.
@@ -288,15 +288,13 @@ class AgentManager:
agent_id: ID of the agent to be deleted. agent_id: ID of the agent to be deleted.
actor: User performing the action. actor: User performing the action.
Returns: Raises:
PydanticAgentState: The deleted agent state NoResultFound: If agent doesn't exist
""" """
with self.session_maker() as session: with self.session_maker() as session:
# Retrieve the agent # Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
agent_state = agent.to_pydantic()
agent.hard_delete(session) agent.hard_delete(session)
return agent_state
# ====================================================================================================================== # ======================================================================================================================
# In Context Messages Management # In Context Messages Management

View File

@@ -1,21 +1,15 @@
from typing import List, Optional
from datetime import datetime from datetime import datetime
import numpy as np from typing import List, Optional
from sqlalchemy import select, union_all, literal
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model, parse_and_chunk_text from letta.embeddings import embedding_model, parse_and_chunk_text
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
from letta.orm.passage import AgentPassage, SourcePassage from letta.orm.passage import AgentPassage, SourcePassage
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types from letta.utils import enforce_types
class PassageManager: class PassageManager:
"""Manager class to handle business logic related to Passages.""" """Manager class to handle business logic related to Passages."""

View File

@@ -9,11 +9,7 @@ from letta.schemas.sandbox_config import LocalSandboxConfig
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
from letta.schemas.sandbox_config import ( from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
SandboxEnvironmentVariableCreate,
SandboxEnvironmentVariableUpdate,
SandboxType,
)
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types, printd from letta.utils import enforce_types, printd

View File

@@ -127,7 +127,7 @@ class ToolExecutionSandbox:
if local_configs.use_venv: if local_configs.use_venv:
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path) return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
else: else:
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path) return self.run_local_dir_sandbox_runpy(sbx_config, env, temp_file_path)
except Exception as e: except Exception as e:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}") logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
@@ -200,7 +200,7 @@ class ToolExecutionSandbox:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
raise e raise e
def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str) -> SandboxRunResult: def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
status = "success" status = "success"
agent_state, stderr = None, None agent_state, stderr = None, None
@@ -213,8 +213,8 @@ class ToolExecutionSandbox:
try: try:
# Execute the temp file # Execute the temp file
with self.temporary_env_vars(env_vars): with self.temporary_env_vars(env):
result = runpy.run_path(temp_file_path, init_globals=env_vars) result = runpy.run_path(temp_file_path, init_globals=env)
# Fetch the result # Fetch the result
func_result = result.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME) func_result = result.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME)
@@ -277,6 +277,10 @@ class ToolExecutionSandbox:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user) sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config) sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
if not sbx or self.force_recreate: if not sbx or self.force_recreate:
if not sbx:
logger.info(f"No running e2b sandbox found with the same state: {sbx_config}")
else:
logger.info(f"Force recreated e2b sandbox with state: {sbx_config}")
sbx = self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config) sbx = self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
# Since this sandbox was used, we extend its lifecycle by the timeout # Since this sandbox was used, we extend its lifecycle by the timeout
@@ -292,6 +296,8 @@ class ToolExecutionSandbox:
func_return, agent_state = self.parse_best_effort(execution.results[0].text) func_return, agent_state = self.parse_best_effort(execution.results[0].text)
elif execution.error: elif execution.error:
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}") logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
logger.error(f"E2B Sandbox configurations: {sbx_config}")
logger.error(f"E2B Sandbox ID: {sbx.sandbox_id}")
func_return = get_friendly_error_msg( func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
) )

View File

@@ -60,7 +60,13 @@ class ModelSettings(BaseSettings):
openllm_api_key: Optional[str] = None openllm_api_key: Optional[str] = None
cors_origins = ["http://letta.localhost", "http://localhost:8283", "http://localhost:8083", "http://localhost:3000"] cors_origins = [
"http://letta.localhost",
"http://localhost:8283",
"http://localhost:8083",
"http://localhost:3000",
"http://localhost:4200",
]
class Settings(BaseSettings): class Settings(BaseSettings):

View File

@@ -9,15 +9,9 @@ from rich.live import Live
from rich.markup import escape from rich.markup import escape
from letta.interface import CLIInterface from letta.interface import CLIInterface
from letta.local_llm.constants import ( from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
ASSISTANT_MESSAGE_CLI_SYMBOL,
INNER_THOUGHTS_CLI_SYMBOL,
)
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse
ChatCompletionChunkResponse,
ChatCompletionResponse,
)
# init(autoreset=True) # init(autoreset=True)

View File

@@ -1120,6 +1120,7 @@ def sanitize_filename(filename: str) -> str:
# Return the sanitized filename # Return the sanitized filename
return sanitized_filename return sanitized_filename
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str): def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

14
poetry.lock generated
View File

@@ -726,13 +726,13 @@ test = ["pytest"]
[[package]] [[package]]
name = "composio-core" name = "composio-core"
version = "0.6.3" version = "0.6.7"
description = "Core package to act as a bridge between composio platform and other services." description = "Core package to act as a bridge between composio platform and other services."
optional = false optional = false
python-versions = "<4,>=3.9" python-versions = "<4,>=3.9"
files = [ files = [
{file = "composio_core-0.6.3-py3-none-any.whl", hash = "sha256:981a9856781b791242f947a9685a18974d8a012ac7fab2c09438e1b19610d6a2"}, {file = "composio_core-0.6.7-py3-none-any.whl", hash = "sha256:03cedeffe417b919d1021c1bc4751f54bd05829b52ff3285f7984e14bdf91efe"},
{file = "composio_core-0.6.3.tar.gz", hash = "sha256:13098b20d8832e74453ca194889305c935432156fc07be91dfddf76561ad591b"}, {file = "composio_core-0.6.7.tar.gz", hash = "sha256:b87f0b804d87945b4eae556468b9efc75f751d256bbf2c20fb8ae5b6a31a2818"},
] ]
[package.dependencies] [package.dependencies]
@@ -762,13 +762,13 @@ tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "tra
[[package]] [[package]]
name = "composio-langchain" name = "composio-langchain"
version = "0.6.3" version = "0.6.7"
description = "Use Composio to get an array of tools with your LangChain agent." description = "Use Composio to get an array of tools with your LangChain agent."
optional = false optional = false
python-versions = "<4,>=3.9" python-versions = "<4,>=3.9"
files = [ files = [
{file = "composio_langchain-0.6.3-py3-none-any.whl", hash = "sha256:0e749a1603dc0562293412d0a6429f88b75152b01a313cca859732070d762a6b"}, {file = "composio_langchain-0.6.7-py3-none-any.whl", hash = "sha256:f8653b6a7e6b03a61b679a096e278744d3009ebaf3741d7e24e5120a364f212e"},
{file = "composio_langchain-0.6.3.tar.gz", hash = "sha256:2036f94bfe60974b31f2be0bfdb33dd75a1d43435f275141219b3376587bf49d"}, {file = "composio_langchain-0.6.7.tar.gz", hash = "sha256:adeab3a87b0e6eb7e96048cef6b988dbe699b6a493a82fac2d371ab940e7e54e"},
] ]
[package.dependencies] [package.dependencies]
@@ -6246,4 +6246,4 @@ tests = ["wikipedia"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "<4.0,>=3.10" python-versions = "<4.0,>=3.10"
content-hash = "4a7cf176579d5dc15648979542da152ec98290f1e9f39039cfe9baf73bc1076f" content-hash = "1c52219049a4470dd54a45318b22495a4cafa29e93a1c5369a0d54da71990adb"

82
project.json Normal file
View File

@@ -0,0 +1,82 @@
{
"name": "core",
"$schema": "../../node_modules/nx/schemas/project-schema.json",
"projectType": "application",
"sourceRoot": "apps/core",
"targets": {
"lock": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry lock --no-update",
"cwd": "apps/core"
}
},
"add": {
"executor": "@nxlv/python:add",
"options": {}
},
"update": {
"executor": "@nxlv/python:update",
"options": {}
},
"remove": {
"executor": "@nxlv/python:remove",
"options": {}
},
"dev": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry run letta server",
"cwd": "apps/core"
}
},
"build": {
"executor": "@nxlv/python:build",
"outputs": ["{projectRoot}/dist"],
"options": {
"outputPath": "apps/core/dist",
"publish": false,
"lockedVersions": true,
"bundleLocalDependencies": true
}
},
"install": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry install --all-extras",
"cwd": "apps/core"
}
},
"lint": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry run isort --profile black . && poetry run black . && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .",
"cwd": "apps/core"
}
},
"database:migrate": {
"executor": "@nxlv/python:run-commands",
"options": {
"command": "poetry run alembic upgrade head",
"cwd": "apps/core"
}
},
"test": {
"executor": "@nxlv/python:run-commands",
"outputs": [
"{workspaceRoot}/reports/apps/core/unittests",
"{workspaceRoot}/coverage/apps/core"
],
"options": {
"command": "poetry run pytest tests/",
"cwd": "apps/core"
}
}
},
"tags": [],
"release": {
"version": {
"generator": "@nxlv/python:release-version"
}
}
}

View File

@@ -59,8 +59,8 @@ nltk = "^3.8.1"
jinja2 = "^3.1.4" jinja2 = "^3.1.4"
locust = {version = "^2.31.5", optional = true} locust = {version = "^2.31.5", optional = true}
wikipedia = {version = "^1.4.0", optional = true} wikipedia = {version = "^1.4.0", optional = true}
composio-langchain = "^0.6.3" composio-langchain = "^0.6.7"
composio-core = "^0.6.3" composio-core = "^0.6.7"
alembic = "^1.13.3" alembic = "^1.13.3"
pyhumps = "^3.8.0" pyhumps = "^3.8.0"
psycopg2 = {version = "^2.9.10", optional = true} psycopg2 = {version = "^2.9.10", optional = true}
@@ -85,7 +85,7 @@ qdrant = ["qdrant-client"]
cloud-tool-sandbox = ["e2b-code-interpreter"] cloud-tool-sandbox = ["e2b-code-interpreter"]
external-tools = ["docker", "langchain", "wikipedia", "langchain-community"] external-tools = ["docker", "langchain", "wikipedia", "langchain-community"]
tests = ["wikipedia"] tests = ["wikipedia"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "llama-index-embeddings-ollama", "docker", "langchain", "wikipedia", "langchain-community", "locust"] all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "^24.4.2" black = "^24.4.2"
@@ -100,3 +100,11 @@ extend-exclude = "examples/*"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.isort]
profile = "black"
line_length = 140
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true

View File

@@ -14,12 +14,7 @@ from letta.agent import Agent
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.embeddings import embedding_model from letta.embeddings import embedding_model
from letta.errors import ( from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
InvalidInnerMonologueError,
InvalidToolCallError,
MissingInnerMonologueError,
MissingToolCallError,
)
from letta.llm_api.llm_api_tools import create from letta.llm_api.llm_api_tools import create
from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
@@ -28,12 +23,7 @@ from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCall
from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory from letta.schemas.memory import ChatMemory
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message
ChatCompletionResponse,
Choice,
FunctionCall,
Message,
)
from letta.utils import get_human_text, get_persona_text, json_dumps from letta.utils import get_human_text, get_persona_text, json_dumps
from tests.helpers.utils import cleanup from tests.helpers.utils import cleanup

View File

@@ -5,12 +5,7 @@ import pytest
from letta import create_client from letta import create_client
from letta.schemas.letta_message import ToolCallMessage from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ( from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
from tests.helpers.endpoints_helper import ( from tests.helpers.endpoints_helper import (
assert_invoked_function_call, assert_invoked_function_call,
assert_invoked_send_message_with_keyword, assert_invoked_send_message_with_keyword,

View File

@@ -0,0 +1,28 @@
import pytest
from fastapi.testclient import TestClient
from letta.server.rest_api.app import app
@pytest.fixture
def client():
return TestClient(app)
def test_list_composio_apps(client):
response = client.get("/v1/tools/composio/apps")
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_composio_actions_by_app(client):
response = client.get("/v1/tools/composio/apps/github/actions")
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_add_composio_tool(client):
response = client.post("/v1/tools/composio/GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
assert response.status_code == 200
assert "id" in response.json()
assert "name" in response.json()

View File

@@ -212,9 +212,7 @@ def clear_core_memory_tool(test_user):
@pytest.fixture @pytest.fixture
def external_codebase_tool(test_user): def external_codebase_tool(test_user):
from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import ( from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import adjust_menu_prices
adjust_menu_prices,
)
tool = create_tool_from_func(adjust_menu_prices) tool = create_tool_from_func(adjust_menu_prices)
tool = ToolManager().create_or_update_tool(tool, test_user) tool = ToolManager().create_or_update_tool(tool, test_user)
@@ -353,6 +351,14 @@ def test_local_sandbox_e2e_composio_star_github(mock_e2b_api_key_none, check_com
assert result.func_return["details"] == "Action executed successfully" assert result.func_return["details"] == "Action executed successfully"
@pytest.mark.local_sandbox
def test_local_sandbox_e2e_composio_star_github_without_setting_db_env_vars(
mock_e2b_api_key_none, check_composio_key_set, composio_github_star_tool, test_user
):
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run()
assert result.func_return["details"] == "Action executed successfully"
@pytest.mark.local_sandbox @pytest.mark.local_sandbox
def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sandbox_config, external_codebase_tool, test_user): def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sandbox_config, external_codebase_tool, test_user):
# Set the args # Set the args
@@ -458,7 +464,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e
config = manager.create_or_update_sandbox_config(config_create, test_user) config = manager.create_or_update_sandbox_config(config_create, test_user)
# Run the custom sandbox once, assert nothing returns because missing env variable # Run the custom sandbox once, assert nothing returns because missing env variable
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user, force_recreate=True) sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user)
result = sandbox.run() result = sandbox.run()
# response should be None # response should be None
assert result.func_return is None assert result.func_return is None

View File

@@ -5,10 +5,7 @@ import sys
import pexpect import pexpect
import pytest import pytest
from letta.local_llm.constants import ( from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
ASSISTANT_MESSAGE_CLI_SYMBOL,
INNER_THOUGHTS_CLI_SYMBOL,
)
original_letta_path = os.path.expanduser("~/.letta") original_letta_path = os.path.expanduser("~/.letta")
backup_letta_path = os.path.expanduser("~/.letta_backup") backup_letta_path = os.path.expanduser("~/.letta_backup")

View File

@@ -43,7 +43,7 @@ def run_server():
@pytest.fixture( @pytest.fixture(
params=[{"server": False}, {"server": True}], # whether to use REST API server params=[{"server": False}, {"server": True}], # whether to use REST API server
# params=[{"server": True}], # whether to use REST API server # params=[{"server": False}], # whether to use REST API server
scope="module", scope="module",
) )
def client(request): def client(request):
@@ -341,7 +341,9 @@ def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState): def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)""" """Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
send_system_message_response = client.send_message(agent_id=agent.id, message="Event occurred: The user just logged off.", role="system") send_system_message_response = client.send_message(
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
)
assert send_system_message_response, "Sending message failed" assert send_system_message_response, "Sending message failed"
@@ -390,7 +392,7 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]):
""" """
Always throw an error. Always throw an error.
""" """
return 5/0 return 5 / 0
tool = client.create_or_update_tool(func=always_error) tool = client.create_or_update_tool(func=always_error)
agent = client.create_agent(tool_ids=[tool.id]) agent = client.create_agent(tool_ids=[tool.id])
@@ -406,12 +408,13 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]):
assert response_message, "ToolReturnMessage message not found in response" assert response_message, "ToolReturnMessage message not found in response"
assert response_message.status == "error" assert response_message.status == "error"
if isinstance(client, RESTClient): if isinstance(client, RESTClient):
assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero" assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero"
else: else:
response_json = json.loads(response_message.tool_return) response_json = json.loads(response_message.tool_return)
assert response_json['status'] == "Failed" assert response_json["status"] == "Failed"
assert response_json['message'] == "Error executing function always_error: ZeroDivisionError: division by zero" assert response_json["message"] == "Error executing function always_error: ZeroDivisionError: division by zero"
client.delete_agent(agent_id=agent.id) client.delete_agent(agent_id=agent.id)

View File

@@ -9,14 +9,7 @@ import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.schemas.block import CreateBlock from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ( from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
LettaMessage,
ReasoningMessage,
SystemMessage,
ToolCallMessage,
ToolReturnMessage,
UserMessage,
)
from letta.schemas.user import User from letta.schemas.user import User
utils.DEBUG = True utils.DEBUG = True

View File

@@ -2,12 +2,7 @@ import pytest
from letta.helpers import ToolRulesSolver from letta.helpers import ToolRulesSolver
from letta.helpers.tool_rule_solver import ToolRuleValidationError from letta.helpers.tool_rule_solver import ToolRuleValidationError
from letta.schemas.tool_rule import ( from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule
)
# Constants for tool names used in the tests # Constants for tool names used in the tests
START_TOOL = "start_tool" START_TOOL = "start_tool"
@@ -113,11 +108,7 @@ def test_conditional_tool_rule():
# Setup: Define a conditional tool rule # Setup: Define a conditional tool rule
init_rule = InitToolRule(tool_name=START_TOOL) init_rule = InitToolRule(tool_name=START_TOOL)
terminal_rule = TerminalToolRule(tool_name=END_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL)
rule = ConditionalToolRule( rule = ConditionalToolRule(tool_name=START_TOOL, default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL})
tool_name=START_TOOL,
default_child=None,
child_output_mapping={True: END_TOOL, False: START_TOOL}
)
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
# Action & Assert: Verify the rule properties # Action & Assert: Verify the rule properties
@@ -126,8 +117,12 @@ def test_conditional_tool_rule():
# Step 2: After using 'start_tool' # Step 2: After using 'start_tool'
solver.update_tool_usage(START_TOOL) solver.update_tool_usage(START_TOOL)
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'" assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'" END_TOOL
], "After 'start_tool' returns true, should allow 'end_tool'"
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [
START_TOOL
], "After 'start_tool' returns false, should allow 'start_tool'"
# Step 3: After using 'end_tool' # Step 3: After using 'end_tool'
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal" assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
@@ -137,11 +132,7 @@ def test_invalid_conditional_tool_rule():
# Setup: Define an invalid conditional tool rule # Setup: Define an invalid conditional tool rule
init_rule = InitToolRule(tool_name=START_TOOL) init_rule = InitToolRule(tool_name=START_TOOL)
terminal_rule = TerminalToolRule(tool_name=END_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL)
invalid_rule_1 = ConditionalToolRule( invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
tool_name=START_TOOL,
default_child=END_TOOL,
child_output_mapping={}
)
# Test 1: Missing child output mapping # Test 1: Missing child output mapping
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."): with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):

View File

@@ -8,6 +8,7 @@ def adjust_menu_prices(percentage: float) -> str:
str: A formatted string summarizing the price adjustments. str: A formatted string summarizing the price adjustments.
""" """
import cowsay import cowsay
from core.menu import Menu, MenuItem # Import a class from the codebase from core.menu import Menu, MenuItem # Import a class from the codebase
from core.utils import format_currency # Use a utility function to test imports from core.utils import format_currency # Use a utility function to test imports

View File

@@ -5,6 +5,7 @@ import pytest
from letta.functions.functions import derive_openai_json_schema from letta.functions.functions import derive_openai_json_schema
from letta.llm_api.helpers import convert_to_structured_output, make_post_request from letta.llm_api.helpers import convert_to_structured_output, make_post_request
from letta.schemas.tool import ToolCreate
def _clean_diff(d1, d2): def _clean_diff(d1, d2):
@@ -176,3 +177,38 @@ def test_valid_schemas_via_openai(openai_model: str, structured_output: bool):
_openai_payload(openai_model, schema, structured_output) _openai_payload(openai_model, schema, structured_output)
else: else:
_openai_payload(openai_model, schema, structured_output) _openai_payload(openai_model, schema, structured_output)
@pytest.mark.parametrize("openai_model", ["gpt-4o-mini"])
@pytest.mark.parametrize("structured_output", [True])
def test_composio_tool_schema_generation(openai_model: str, structured_output: bool):
"""Test that we can generate the schemas for some Composio tools."""
if not os.getenv("COMPOSIO_API_KEY"):
pytest.skip("COMPOSIO_API_KEY not set")
try:
import composio
except ImportError:
pytest.skip("Composio not installed")
for action_name in [
"CAL_GET_AVAILABLE_SLOTS_INFO", # has an array arg, needs to be converted properly
]:
try:
tool_create = ToolCreate.from_composio(action_name=action_name)
except composio.exceptions.ComposioSDKError:
# e.g. "composio.exceptions.ComposioSDKError: No connected account found for app `CAL`; Run `composio add cal` to fix this"
pytest.skip(f"Composio account not configured to use action_name {action_name}")
print(tool_create)
assert tool_create.json_schema
schema = tool_create.json_schema
try:
_openai_payload(openai_model, schema, structured_output)
print(f"Successfully called OpenAI using schema {schema} generated from {action_name}")
except:
print(f"Failed to call OpenAI using schema {schema} generated from {action_name}")
raise

View File

@@ -1,12 +1,7 @@
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from composio.client.collections import ( from composio.client.collections import ActionModel, ActionParametersModel, ActionResponseModel, AppModel
ActionModel,
ActionParametersModel,
ActionResponseModel,
AppModel,
)
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.tool import ToolCreate, ToolUpdate