feat: various fixes (#2320)
Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Matt Zhou <mattzh1314@gmail.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Caren Thomas <caren@letta.com> Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
29
.env.example
29
.env.example
@@ -2,43 +2,20 @@
|
|||||||
Example enviornment variable configurations for the Letta
|
Example enviornment variable configurations for the Letta
|
||||||
Docker container. Un-coment the sections you want to
|
Docker container. Un-coment the sections you want to
|
||||||
configure with.
|
configure with.
|
||||||
|
|
||||||
Hint: You don't need to have the same LLM and
|
|
||||||
Embedding model backends (can mix and match).
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
OpenAI configuration
|
OpenAI configuration
|
||||||
##########################################################
|
##########################################################
|
||||||
## LLM Model
|
# OPENAI_API_KEY=sk-...
|
||||||
#LETTA_LLM_ENDPOINT_TYPE=openai
|
|
||||||
#LETTA_LLM_MODEL=gpt-4o-mini
|
|
||||||
## Embeddings
|
|
||||||
#LETTA_EMBEDDING_ENDPOINT_TYPE=openai
|
|
||||||
#LETTA_EMBEDDING_MODEL=text-embedding-ada-002
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
Ollama configuration
|
Ollama configuration
|
||||||
##########################################################
|
##########################################################
|
||||||
## LLM Model
|
# OLLAMA_BASE_URL="http://host.docker.internal:11434"
|
||||||
#LETTA_LLM_ENDPOINT=http://host.docker.internal:11434
|
|
||||||
#LETTA_LLM_ENDPOINT_TYPE=ollama
|
|
||||||
#LETTA_LLM_MODEL=dolphin2.2-mistral:7b-q6_K
|
|
||||||
#LETTA_LLM_CONTEXT_WINDOW=8192
|
|
||||||
## Embeddings
|
|
||||||
#LETTA_EMBEDDING_ENDPOINT=http://host.docker.internal:11434
|
|
||||||
#LETTA_EMBEDDING_ENDPOINT_TYPE=ollama
|
|
||||||
#LETTA_EMBEDDING_MODEL=mxbai-embed-large
|
|
||||||
#LETTA_EMBEDDING_DIM=512
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
vLLM configuration
|
vLLM configuration
|
||||||
##########################################################
|
##########################################################
|
||||||
## LLM Model
|
# VLLM_API_BASE="http://host.docker.internal:8000"
|
||||||
#LETTA_LLM_ENDPOINT=http://host.docker.internal:8000
|
|
||||||
#LETTA_LLM_ENDPOINT_TYPE=vllm
|
|
||||||
#LETTA_LLM_MODEL=ehartford/dolphin-2.2.1-mistral-7b
|
|
||||||
#LETTA_LLM_CONTEXT_WINDOW=8192
|
|
||||||
|
|||||||
42
.github/workflows/letta-web-openapi-saftey.yml
vendored
42
.github/workflows/letta-web-openapi-saftey.yml
vendored
@@ -1,42 +0,0 @@
|
|||||||
name: "Letta Web OpenAPI Compatibility Checker"
|
|
||||||
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main ]
|
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
validate-openapi:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
- name: "Setup Python, Poetry and Dependencies"
|
|
||||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
|
||||||
with:
|
|
||||||
python-version: "3.12"
|
|
||||||
poetry-version: "1.8.2"
|
|
||||||
install-args: "-E dev"
|
|
||||||
- name: Checkout letta web
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
repository: letta-ai/letta-web
|
|
||||||
token: ${{ secrets.PULLER_TOKEN }}
|
|
||||||
path: letta-web
|
|
||||||
- name: Run OpenAPI schema generation
|
|
||||||
run: |
|
|
||||||
bash ./letta/server/generate_openapi_schema.sh
|
|
||||||
- name: Setup letta-web
|
|
||||||
working-directory: letta-web
|
|
||||||
run: npm ci
|
|
||||||
- name: Copy OpenAPI schema
|
|
||||||
working-directory: .
|
|
||||||
run: cp openapi_letta.json letta-web/libs/letta-agents-api/letta-agents-openapi.json
|
|
||||||
- name: Validate OpenAPI schema
|
|
||||||
working-directory: letta-web
|
|
||||||
run: |
|
|
||||||
npm run agents-api:generate
|
|
||||||
npm run type-check
|
|
||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -6,6 +6,8 @@ env:
|
|||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||||
|
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
|
||||||
|
E2B_SANDBOX_TEMPLATE_ID: ${{ secrets.E2B_SANDBOX_TEMPLATE_ID }}
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -61,7 +63,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.12"
|
||||||
poetry-version: "1.8.2"
|
poetry-version: "1.8.2"
|
||||||
install-args: "-E dev -E postgres -E external-tools -E tests"
|
install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
|
||||||
- name: Migrate database
|
- name: Migrate database
|
||||||
env:
|
env:
|
||||||
LETTA_PG_PORT: 5432
|
LETTA_PG_PORT: 5432
|
||||||
|
|||||||
@@ -5,40 +5,45 @@ Revises: 3c683a662c82
|
|||||||
Create Date: 2024-12-05 16:46:51.258831
|
Create Date: 2024-12-05 16:46:51.258831
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '08b2f8225812'
|
revision: str = "08b2f8225812"
|
||||||
down_revision: Union[str, None] = '3c683a662c82'
|
down_revision: Union[str, None] = "3c683a662c82"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('tools_agents',
|
op.create_table(
|
||||||
sa.Column('agent_id', sa.String(), nullable=False),
|
"tools_agents",
|
||||||
sa.Column('tool_id', sa.String(), nullable=False),
|
sa.Column("agent_id", sa.String(), nullable=False),
|
||||||
sa.Column('tool_name', sa.String(), nullable=False),
|
sa.Column("tool_id", sa.String(), nullable=False),
|
||||||
sa.Column('id', sa.String(), nullable=False),
|
sa.Column("tool_name", sa.String(), nullable=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ),
|
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||||
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'),
|
sa.ForeignKeyConstraint(
|
||||||
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'),
|
["agent_id"],
|
||||||
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent')
|
["agents.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["tool_id"], ["tools.id"], name="fk_tool_id"),
|
||||||
|
sa.PrimaryKeyConstraint("agent_id", "tool_id", "tool_name", "id"),
|
||||||
|
sa.UniqueConstraint("agent_id", "tool_name", name="unique_tool_per_agent"),
|
||||||
)
|
)
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table('tools_agents')
|
op.drop_table("tools_agents")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -5,18 +5,19 @@ Revises: 4e88e702f85e
|
|||||||
Create Date: 2024-12-14 17:23:08.772554
|
Create Date: 2024-12-14 17:23:08.772554
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
from pgvector.sqlalchemy import Vector
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '54dec07619c4'
|
revision: str = "54dec07619c4"
|
||||||
down_revision: Union[str, None] = '4e88e702f85e'
|
down_revision: Union[str, None] = "4e88e702f85e"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -24,82 +25,88 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'agent_passages',
|
"agent_passages",
|
||||||
sa.Column('id', sa.String(), nullable=False),
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
sa.Column('text', sa.String(), nullable=False),
|
sa.Column("text", sa.String(), nullable=False),
|
||||||
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
|
||||||
sa.Column('metadata_', sa.JSON(), nullable=False),
|
sa.Column("metadata_", sa.JSON(), nullable=False),
|
||||||
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
sa.Column("embedding", Vector(dim=4096), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||||
sa.Column('organization_id', sa.String(), nullable=False),
|
sa.Column("organization_id", sa.String(), nullable=False),
|
||||||
sa.Column('agent_id', sa.String(), nullable=False),
|
sa.Column("agent_id", sa.String(), nullable=False),
|
||||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
|
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
sa.ForeignKeyConstraint(
|
||||||
sa.PrimaryKeyConstraint('id')
|
["organization_id"],
|
||||||
|
["organizations.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False)
|
op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"], unique=False)
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'source_passages',
|
"source_passages",
|
||||||
sa.Column('id', sa.String(), nullable=False),
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
sa.Column('text', sa.String(), nullable=False),
|
sa.Column("text", sa.String(), nullable=False),
|
||||||
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
sa.Column("embedding_config", EmbeddingConfigColumn(), nullable=False),
|
||||||
sa.Column('metadata_', sa.JSON(), nullable=False),
|
sa.Column("metadata_", sa.JSON(), nullable=False),
|
||||||
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
sa.Column("embedding", Vector(dim=4096), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||||
sa.Column('organization_id', sa.String(), nullable=False),
|
sa.Column("organization_id", sa.String(), nullable=False),
|
||||||
sa.Column('file_id', sa.String(), nullable=True),
|
sa.Column("file_id", sa.String(), nullable=True),
|
||||||
sa.Column('source_id', sa.String(), nullable=False),
|
sa.Column("source_id", sa.String(), nullable=False),
|
||||||
sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'),
|
sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"),
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
sa.ForeignKeyConstraint(
|
||||||
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'),
|
["organization_id"],
|
||||||
sa.PrimaryKeyConstraint('id')
|
["organizations.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False)
|
op.create_index("source_passages_org_idx", "source_passages", ["organization_id"], unique=False)
|
||||||
op.drop_table('passages')
|
op.drop_table("passages")
|
||||||
op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
|
op.drop_constraint("files_source_id_fkey", "files", type_="foreignkey")
|
||||||
op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE')
|
op.create_foreign_key(None, "files", "sources", ["source_id"], ["id"], ondelete="CASCADE")
|
||||||
op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
|
op.drop_constraint("messages_agent_id_fkey", "messages", type_="foreignkey")
|
||||||
op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE')
|
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_constraint(None, 'messages', type_='foreignkey')
|
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||||
op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'])
|
op.create_foreign_key("messages_agent_id_fkey", "messages", "agents", ["agent_id"], ["id"])
|
||||||
op.drop_constraint(None, 'files', type_='foreignkey')
|
op.drop_constraint(None, "files", type_="foreignkey")
|
||||||
op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'])
|
op.create_foreign_key("files_source_id_fkey", "files", "sources", ["source_id"], ["id"])
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'passages',
|
"passages",
|
||||||
sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False),
|
sa.Column("text", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
sa.Column("file_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True),
|
sa.Column("embedding", Vector(dim=4096), autoincrement=False, nullable=True),
|
||||||
sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
sa.Column("embedding_config", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||||
sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
sa.Column("metadata_", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||||
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
|
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
|
||||||
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
|
sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||||
sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
|
sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
|
||||||
sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'),
|
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], name="passages_agent_id_fkey"),
|
||||||
sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'),
|
sa.ForeignKeyConstraint(["file_id"], ["files.id"], name="passages_file_id_fkey", ondelete="CASCADE"),
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'),
|
sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], name="passages_organization_id_fkey"),
|
||||||
sa.PrimaryKeyConstraint('id', name='passages_pkey')
|
sa.PrimaryKeyConstraint("id", name="passages_pkey"),
|
||||||
)
|
)
|
||||||
op.drop_index('source_passages_org_idx', table_name='source_passages')
|
op.drop_index("source_passages_org_idx", table_name="source_passages")
|
||||||
op.drop_table('source_passages')
|
op.drop_table("source_passages")
|
||||||
op.drop_index('agent_passages_org_idx', table_name='agent_passages')
|
op.drop_index("agent_passages_org_idx", table_name="agent_passages")
|
||||||
op.drop_table('agent_passages')
|
op.drop_table("agent_passages")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -5,25 +5,27 @@ Revises: a91994b9752f
|
|||||||
Create Date: 2024-12-10 15:05:32.335519
|
Create Date: 2024-12-10 15:05:32.335519
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = 'c5d964280dff'
|
revision: str = "c5d964280dff"
|
||||||
down_revision: Union[str, None] = 'a91994b9752f'
|
down_revision: Union[str, None] = "a91994b9752f"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True))
|
op.add_column("passages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||||
op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False))
|
op.add_column("passages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||||
op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True))
|
op.add_column("passages", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||||
op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True))
|
op.add_column("passages", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||||
|
|
||||||
# Data migration step:
|
# Data migration step:
|
||||||
op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
|
op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
|
||||||
@@ -41,48 +43,32 @@ def upgrade() -> None:
|
|||||||
# Set `organization_id` as non-nullable after population
|
# Set `organization_id` as non-nullable after population
|
||||||
op.alter_column("passages", "organization_id", nullable=False)
|
op.alter_column("passages", "organization_id", nullable=False)
|
||||||
|
|
||||||
op.alter_column('passages', 'text',
|
op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=False)
|
||||||
existing_type=sa.VARCHAR(),
|
op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||||
nullable=False)
|
op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||||
op.alter_column('passages', 'embedding_config',
|
op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
||||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
op.drop_index("passage_idx_user", table_name="passages")
|
||||||
nullable=False)
|
op.create_foreign_key(None, "passages", "organizations", ["organization_id"], ["id"])
|
||||||
op.alter_column('passages', 'metadata_',
|
op.create_foreign_key(None, "passages", "agents", ["agent_id"], ["id"])
|
||||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
op.create_foreign_key(None, "passages", "files", ["file_id"], ["id"], ondelete="CASCADE")
|
||||||
nullable=False)
|
op.drop_column("passages", "user_id")
|
||||||
op.alter_column('passages', 'created_at',
|
|
||||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
|
||||||
nullable=False)
|
|
||||||
op.drop_index('passage_idx_user', table_name='passages')
|
|
||||||
op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id'])
|
|
||||||
op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id'])
|
|
||||||
op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE')
|
|
||||||
op.drop_column('passages', 'user_id')
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False))
|
op.add_column("passages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
op.drop_constraint(None, "passages", type_="foreignkey")
|
||||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
op.drop_constraint(None, "passages", type_="foreignkey")
|
||||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
op.drop_constraint(None, "passages", type_="foreignkey")
|
||||||
op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False)
|
op.create_index("passage_idx_user", "passages", ["user_id", "agent_id", "file_id"], unique=False)
|
||||||
op.alter_column('passages', 'created_at',
|
op.alter_column("passages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
|
||||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
op.alter_column("passages", "metadata_", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||||
nullable=True)
|
op.alter_column("passages", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||||
op.alter_column('passages', 'metadata_',
|
op.alter_column("passages", "text", existing_type=sa.VARCHAR(), nullable=True)
|
||||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
op.drop_column("passages", "organization_id")
|
||||||
nullable=True)
|
op.drop_column("passages", "_last_updated_by_id")
|
||||||
op.alter_column('passages', 'embedding_config',
|
op.drop_column("passages", "_created_by_id")
|
||||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
op.drop_column("passages", "is_deleted")
|
||||||
nullable=True)
|
op.drop_column("passages", "updated_at")
|
||||||
op.alter_column('passages', 'text',
|
|
||||||
existing_type=sa.VARCHAR(),
|
|
||||||
nullable=True)
|
|
||||||
op.drop_column('passages', 'organization_id')
|
|
||||||
op.drop_column('passages', '_last_updated_by_id')
|
|
||||||
op.drop_column('passages', '_created_by_id')
|
|
||||||
op.drop_column('passages', 'is_deleted')
|
|
||||||
op.drop_column('passages', 'updated_at')
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -4,10 +4,7 @@ import uuid
|
|||||||
from letta import create_client
|
from letta import create_client
|
||||||
from letta.schemas.letta_message import ToolCallMessage
|
from letta.schemas.letta_message import ToolCallMessage
|
||||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||||
from tests.helpers.endpoints_helper import (
|
from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent
|
||||||
assert_invoked_send_message_with_keyword,
|
|
||||||
setup_agent,
|
|
||||||
)
|
|
||||||
from tests.helpers.utils import cleanup
|
from tests.helpers.utils import cleanup
|
||||||
from tests.test_model_letta_perfomance import llm_config_dir
|
from tests.test_model_letta_perfomance import llm_config_dir
|
||||||
|
|
||||||
|
|||||||
@@ -12,13 +12,7 @@ from letta.schemas.file import FileMetadata
|
|||||||
from letta.schemas.job import Job
|
from letta.schemas.job import Job
|
||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.memory import (
|
from letta.schemas.memory import ArchivalMemorySummary, BasicBlockMemory, ChatMemory, Memory, RecallMemorySummary
|
||||||
ArchivalMemorySummary,
|
|
||||||
BasicBlockMemory,
|
|
||||||
ChatMemory,
|
|
||||||
Memory,
|
|
||||||
RecallMemorySummary,
|
|
||||||
)
|
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
from letta.schemas.organization import Organization
|
from letta.schemas.organization import Organization
|
||||||
|
|||||||
@@ -33,34 +33,21 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
|||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_request import (
|
from letta.schemas.openai.chat_completion_request import Tool as ChatCompletionRequestTool
|
||||||
Tool as ChatCompletionRequestTool,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
|
||||||
Message as ChatCompletionMessage,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
from letta.schemas.tool_rule import TerminalToolRule
|
from letta.schemas.tool_rule import TerminalToolRule
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.services.agent_manager import AgentManager
|
from letta.services.agent_manager import AgentManager
|
||||||
from letta.services.block_manager import BlockManager
|
from letta.services.block_manager import BlockManager
|
||||||
from letta.services.helpers.agent_manager_helper import (
|
from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
|
||||||
check_supports_structured_output,
|
|
||||||
compile_memory_metadata_block,
|
|
||||||
)
|
|
||||||
from letta.services.message_manager import MessageManager
|
from letta.services.message_manager import MessageManager
|
||||||
from letta.services.passage_manager import PassageManager
|
from letta.services.passage_manager import PassageManager
|
||||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||||
from letta.system import (
|
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
|
||||||
get_heartbeat,
|
|
||||||
get_token_limit_warning,
|
|
||||||
package_function_response,
|
|
||||||
package_summarize_message,
|
|
||||||
package_user_message,
|
|
||||||
)
|
|
||||||
from letta.utils import (
|
from letta.utils import (
|
||||||
count_tokens,
|
count_tokens,
|
||||||
get_friendly_error_msg,
|
get_friendly_error_msg,
|
||||||
|
|||||||
@@ -10,12 +10,7 @@ import letta.utils as utils
|
|||||||
from letta import create_client
|
from letta import create_client
|
||||||
from letta.agent import Agent, save_agent
|
from letta.agent import Agent, save_agent
|
||||||
from letta.config import LettaConfig
|
from letta.config import LettaConfig
|
||||||
from letta.constants import (
|
from letta.constants import CLI_WARNING_PREFIX, CORE_MEMORY_BLOCK_CHAR_LIMIT, LETTA_DIR, MIN_CONTEXT_WINDOW
|
||||||
CLI_WARNING_PREFIX,
|
|
||||||
CORE_MEMORY_BLOCK_CHAR_LIMIT,
|
|
||||||
LETTA_DIR,
|
|
||||||
MIN_CONTEXT_WINDOW,
|
|
||||||
)
|
|
||||||
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
|
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.enums import OptionState
|
from letta.schemas.enums import OptionState
|
||||||
@@ -23,9 +18,7 @@ from letta.schemas.memory import ChatMemory, Memory
|
|||||||
from letta.server.server import logger as server_logger
|
from letta.server.server import logger as server_logger
|
||||||
|
|
||||||
# from letta.interface import CLIInterface as interface # for printing to terminal
|
# from letta.interface import CLIInterface as interface # for printing to terminal
|
||||||
from letta.streaming_interface import (
|
from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
|
||||||
StreamingRefreshCLIInterface as interface, # for printing to terminal
|
|
||||||
)
|
|
||||||
from letta.utils import open_folder_in_explorer, printd
|
from letta.utils import open_folder_in_explorer, printd
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|||||||
@@ -5,14 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
import letta.utils
|
import letta.utils
|
||||||
from letta.constants import (
|
from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT
|
||||||
ADMIN_PREFIX,
|
|
||||||
BASE_MEMORY_TOOLS,
|
|
||||||
BASE_TOOLS,
|
|
||||||
DEFAULT_HUMAN,
|
|
||||||
DEFAULT_PERSONA,
|
|
||||||
FUNCTION_RETURN_CHAR_LIMIT,
|
|
||||||
)
|
|
||||||
from letta.data_sources.connectors import DataConnector
|
from letta.data_sources.connectors import DataConnector
|
||||||
from letta.functions.functions import parse_source_code
|
from letta.functions.functions import parse_source_code
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
@@ -27,13 +20,7 @@ from letta.schemas.job import Job
|
|||||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.memory import (
|
from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary
|
||||||
ArchivalMemorySummary,
|
|
||||||
ChatMemory,
|
|
||||||
CreateArchivalMemory,
|
|
||||||
Memory,
|
|
||||||
RecallMemorySummary,
|
|
||||||
)
|
|
||||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||||
from letta.schemas.openai.chat_completions import ToolCall
|
from letta.schemas.openai.chat_completions import ToolCall
|
||||||
from letta.schemas.organization import Organization
|
from letta.schemas.organization import Organization
|
||||||
|
|||||||
@@ -7,11 +7,7 @@ from httpx_sse import SSEError, connect_sse
|
|||||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||||
from letta.errors import LLMError
|
from letta.errors import LLMError
|
||||||
from letta.schemas.enums import MessageStreamStatus
|
from letta.schemas.enums import MessageStreamStatus
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import ReasoningMessage, ToolCallMessage, ToolReturnMessage
|
||||||
ToolCallMessage,
|
|
||||||
ToolReturnMessage,
|
|
||||||
ReasoningMessage,
|
|
||||||
)
|
|
||||||
from letta.schemas.letta_response import LettaStreamingResponse
|
from letta.schemas.letta_response import LettaStreamingResponse
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ from typing import Optional
|
|||||||
from IPython.display import HTML, display
|
from IPython.display import HTML, display
|
||||||
from sqlalchemy.testing.plugin.plugin_base import warnings
|
from sqlalchemy.testing.plugin.plugin_base import warnings
|
||||||
|
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
|
||||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
|
||||||
INNER_THOUGHTS_CLI_SYMBOL,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pprint(messages):
|
def pprint(messages):
|
||||||
|
|||||||
@@ -2,11 +2,7 @@ from typing import Dict, Iterator, List, Tuple
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from letta.data_sources.connectors_helper import (
|
from letta.data_sources.connectors_helper import assert_all_files_exist_locally, extract_metadata_from_files, get_filenames_in_dir
|
||||||
assert_all_files_exist_locally,
|
|
||||||
extract_metadata_from_files,
|
|
||||||
get_filenames_in_dir,
|
|
||||||
)
|
|
||||||
from letta.embeddings import embedding_model
|
from letta.embeddings import embedding_model
|
||||||
from letta.schemas.file import FileMetadata
|
from letta.schemas.file import FileMetadata
|
||||||
from letta.schemas.passage import Passage
|
from letta.schemas.passage import Passage
|
||||||
@@ -14,6 +10,7 @@ from letta.schemas.source import Source
|
|||||||
from letta.services.passage_manager import PassageManager
|
from letta.services.passage_manager import PassageManager
|
||||||
from letta.services.source_manager import SourceManager
|
from letta.services.source_manager import SourceManager
|
||||||
|
|
||||||
|
|
||||||
class DataConnector:
|
class DataConnector:
|
||||||
"""
|
"""
|
||||||
Base class for data connectors that can be extended to generate files and passages from a custom data source.
|
Base class for data connectors that can be extended to generate files and passages from a custom data source.
|
||||||
|
|||||||
@@ -4,11 +4,7 @@ from typing import Any, List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
from letta.constants import (
|
from letta.constants import EMBEDDING_TO_TOKENIZER_DEFAULT, EMBEDDING_TO_TOKENIZER_MAP, MAX_EMBEDDING_DIM
|
||||||
EMBEDDING_TO_TOKENIZER_DEFAULT,
|
|
||||||
EMBEDDING_TO_TOKENIZER_MAP,
|
|
||||||
MAX_EMBEDDING_DIM,
|
|
||||||
)
|
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.utils import is_valid_url, printd
|
from letta.utils import is_valid_url, printd
|
||||||
|
|
||||||
|
|||||||
@@ -52,12 +52,10 @@ class LettaConfigurationError(LettaError):
|
|||||||
|
|
||||||
class LettaAgentNotFoundError(LettaError):
|
class LettaAgentNotFoundError(LettaError):
|
||||||
"""Error raised when an agent is not found."""
|
"""Error raised when an agent is not found."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LettaUserNotFoundError(LettaError):
|
class LettaUserNotFoundError(LettaError):
|
||||||
"""Error raised when a user is not found."""
|
"""Error raised when a user is not found."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LLMError(LettaError):
|
class LLMError(LettaError):
|
||||||
|
|||||||
@@ -4,10 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from letta.constants import (
|
from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
|
||||||
MESSAGE_CHATGPT_FUNCTION_MODEL,
|
|
||||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
|
|
||||||
)
|
|
||||||
from letta.llm_api.llm_api_tools import create
|
from letta.llm_api.llm_api_tools import create
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.utils import json_dumps, json_loads
|
from letta.utils import json_dumps, json_loads
|
||||||
|
|||||||
@@ -396,44 +396,6 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def generate_schema_from_args_schema_v1(
|
|
||||||
args_schema: Type[V1BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
properties = {}
|
|
||||||
required = []
|
|
||||||
for field_name, field in args_schema.__fields__.items():
|
|
||||||
if field.type_ == str:
|
|
||||||
field_type = "string"
|
|
||||||
elif field.type_ == int:
|
|
||||||
field_type = "integer"
|
|
||||||
elif field.type_ == bool:
|
|
||||||
field_type = "boolean"
|
|
||||||
else:
|
|
||||||
field_type = field.type_.__name__
|
|
||||||
|
|
||||||
properties[field_name] = {
|
|
||||||
"type": field_type,
|
|
||||||
"description": field.field_info.description,
|
|
||||||
}
|
|
||||||
if field.required:
|
|
||||||
required.append(field_name)
|
|
||||||
|
|
||||||
function_call_json = {
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"parameters": {"type": "object", "properties": properties, "required": required},
|
|
||||||
}
|
|
||||||
|
|
||||||
if append_heartbeat:
|
|
||||||
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
|
||||||
}
|
|
||||||
function_call_json["parameters"]["required"].append("request_heartbeat")
|
|
||||||
|
|
||||||
return function_call_json
|
|
||||||
|
|
||||||
|
|
||||||
def generate_schema_from_args_schema_v2(
|
def generate_schema_from_args_schema_v2(
|
||||||
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
|
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -441,19 +403,8 @@ def generate_schema_from_args_schema_v2(
|
|||||||
required = []
|
required = []
|
||||||
for field_name, field in args_schema.model_fields.items():
|
for field_name, field in args_schema.model_fields.items():
|
||||||
field_type_annotation = field.annotation
|
field_type_annotation = field.annotation
|
||||||
if field_type_annotation == str:
|
properties[field_name] = type_to_json_schema_type(field_type_annotation)
|
||||||
field_type = "string"
|
properties[field_name]["description"] = field.description
|
||||||
elif field_type_annotation == int:
|
|
||||||
field_type = "integer"
|
|
||||||
elif field_type_annotation == bool:
|
|
||||||
field_type = "boolean"
|
|
||||||
else:
|
|
||||||
field_type = field_type_annotation.__name__
|
|
||||||
|
|
||||||
properties[field_name] = {
|
|
||||||
"type": field_type,
|
|
||||||
"description": field.description,
|
|
||||||
}
|
|
||||||
if field.is_required():
|
if field.is_required():
|
||||||
required.append(field_name)
|
required.append(field_name)
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,7 @@ from typing import List, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from letta.schemas.enums import ToolRuleType
|
from letta.schemas.enums import ToolRuleType
|
||||||
from letta.schemas.tool_rule import (
|
from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||||
BaseToolRule,
|
|
||||||
ChildToolRule,
|
|
||||||
ConditionalToolRule,
|
|
||||||
InitToolRule,
|
|
||||||
TerminalToolRule,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolRuleValidationError(Exception):
|
class ToolRuleValidationError(Exception):
|
||||||
@@ -50,7 +44,6 @@ class ToolRulesSolver(BaseModel):
|
|||||||
assert isinstance(rule, TerminalToolRule)
|
assert isinstance(rule, TerminalToolRule)
|
||||||
self.terminal_tool_rules.append(rule)
|
self.terminal_tool_rules.append(rule)
|
||||||
|
|
||||||
|
|
||||||
def update_tool_usage(self, tool_name: str):
|
def update_tool_usage(self, tool_name: str):
|
||||||
"""Update the internal state to track the last tool called."""
|
"""Update the internal state to track the last tool called."""
|
||||||
self.last_tool_name = tool_name
|
self.last_tool_name = tool_name
|
||||||
@@ -88,7 +81,7 @@ class ToolRulesSolver(BaseModel):
|
|||||||
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
||||||
|
|
||||||
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
||||||
'''
|
"""
|
||||||
Validate a conditional tool rule
|
Validate a conditional tool rule
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -96,13 +89,13 @@ class ToolRulesSolver(BaseModel):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ToolRuleValidationError: If the rule is invalid
|
ToolRuleValidationError: If the rule is invalid
|
||||||
'''
|
"""
|
||||||
if len(rule.child_output_mapping) == 0:
|
if len(rule.child_output_mapping) == 0:
|
||||||
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
|
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
|
||||||
'''
|
"""
|
||||||
Parse function response to determine which child tool to use based on the mapping
|
Parse function response to determine which child tool to use based on the mapping
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -111,7 +104,7 @@ class ToolRulesSolver(BaseModel):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The name of the child tool to use next
|
str: The name of the child tool to use next
|
||||||
'''
|
"""
|
||||||
json_response = json.loads(last_function_response)
|
json_response = json.loads(last_function_response)
|
||||||
function_output = json_response["message"]
|
function_output = json_response["message"]
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ from typing import List, Optional
|
|||||||
from colorama import Fore, Style, init
|
from colorama import Fore, Style, init
|
||||||
|
|
||||||
from letta.constants import CLI_WARNING_PREFIX
|
from letta.constants import CLI_WARNING_PREFIX
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
|
||||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
|
||||||
INNER_THOUGHTS_CLI_SYMBOL,
|
|
||||||
)
|
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.utils import json_loads, printd
|
from letta.utils import json_loads, printd
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,7 @@ from typing import List, Optional, Union
|
|||||||
from letta.llm_api.helpers import make_post_request
|
from letta.llm_api.helpers import make_post_request
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
|
||||||
ChatCompletionResponse,
|
|
||||||
Choice,
|
|
||||||
FunctionCall,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
||||||
)
|
)
|
||||||
@@ -102,13 +98,9 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
|||||||
formatted_tools = []
|
formatted_tools = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
formatted_tool = {
|
formatted_tool = {
|
||||||
"name" : tool.function.name,
|
"name": tool.function.name,
|
||||||
"description" : tool.function.description,
|
"description": tool.function.description,
|
||||||
"input_schema" : tool.function.parameters or {
|
"input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"required": []
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
formatted_tools.append(formatted_tool)
|
formatted_tools.append(formatted_tool)
|
||||||
|
|
||||||
@@ -346,7 +338,7 @@ def anthropic_chat_completions_request(
|
|||||||
data["tool_choice"] = {
|
data["tool_choice"] = {
|
||||||
"type": "tool", # Changed from "function" to "tool"
|
"type": "tool", # Changed from "function" to "tool"
|
||||||
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
|
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
|
||||||
"disable_parallel_tool_use": True # Force single tool use
|
"disable_parallel_tool_use": True, # Force single tool use
|
||||||
}
|
}
|
||||||
|
|
||||||
# Move 'system' to the top level
|
# Move 'system' to the top level
|
||||||
|
|||||||
@@ -7,11 +7,7 @@ import requests
|
|||||||
from letta.local_llm.utils import count_tokens
|
from letta.local_llm.utils import count_tokens
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
|
||||||
ChatCompletionResponse,
|
|
||||||
Choice,
|
|
||||||
FunctionCall,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
||||||
)
|
)
|
||||||
@@ -276,10 +272,7 @@ def convert_tools_to_cohere_format(tools: List[Tool], inner_thoughts_in_kwargs:
|
|||||||
if inner_thoughts_in_kwargs:
|
if inner_thoughts_in_kwargs:
|
||||||
# NOTE: since Cohere doesn't allow "text" in the response when a tool call happens, if we want
|
# NOTE: since Cohere doesn't allow "text" in the response when a tool call happens, if we want
|
||||||
# a simultaneous CoT + tool call we need to put it inside a kwarg
|
# a simultaneous CoT + tool call we need to put it inside a kwarg
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||||
INNER_THOUGHTS_KWARG,
|
|
||||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
|
||||||
)
|
|
||||||
|
|
||||||
for cohere_tool in tools_dict_list:
|
for cohere_tool in tools_dict_list:
|
||||||
cohere_tool["parameter_definitions"][INNER_THOUGHTS_KWARG] = {
|
cohere_tool["parameter_definitions"][INNER_THOUGHTS_KWARG] = {
|
||||||
|
|||||||
@@ -8,14 +8,7 @@ from letta.llm_api.helpers import make_post_request
|
|||||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||||
from letta.local_llm.utils import count_tokens
|
from letta.local_llm.utils import count_tokens
|
||||||
from letta.schemas.openai.chat_completion_request import Tool
|
from letta.schemas.openai.chat_completion_request import Tool
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||||
ChatCompletionResponse,
|
|
||||||
Choice,
|
|
||||||
FunctionCall,
|
|
||||||
Message,
|
|
||||||
ToolCall,
|
|
||||||
UsageStatistics,
|
|
||||||
)
|
|
||||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||||
|
|
||||||
|
|
||||||
@@ -230,10 +223,7 @@ def convert_tools_to_google_ai_format(tools: List[Tool], inner_thoughts_in_kwarg
|
|||||||
param_fields["type"] = param_fields["type"].upper()
|
param_fields["type"] = param_fields["type"].upper()
|
||||||
# Add inner thoughts
|
# Add inner thoughts
|
||||||
if inner_thoughts_in_kwargs:
|
if inner_thoughts_in_kwargs:
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||||
INNER_THOUGHTS_KWARG,
|
|
||||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
|
||||||
)
|
|
||||||
|
|
||||||
func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
|
func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
|
||||||
"type": "STRING",
|
"type": "STRING",
|
||||||
|
|||||||
@@ -8,38 +8,22 @@ from letta.constants import CLI_WARNING_PREFIX
|
|||||||
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
||||||
from letta.llm_api.anthropic import anthropic_chat_completions_request
|
from letta.llm_api.anthropic import anthropic_chat_completions_request
|
||||||
from letta.llm_api.azure_openai import azure_openai_chat_completions_request
|
from letta.llm_api.azure_openai import azure_openai_chat_completions_request
|
||||||
from letta.llm_api.google_ai import (
|
from letta.llm_api.google_ai import convert_tools_to_google_ai_format, google_ai_chat_completions_request
|
||||||
convert_tools_to_google_ai_format,
|
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
|
||||||
google_ai_chat_completions_request,
|
|
||||||
)
|
|
||||||
from letta.llm_api.helpers import (
|
|
||||||
add_inner_thoughts_to_functions,
|
|
||||||
unpack_all_inner_thoughts_from_kwargs,
|
|
||||||
)
|
|
||||||
from letta.llm_api.openai import (
|
from letta.llm_api.openai import (
|
||||||
build_openai_chat_completions_request,
|
build_openai_chat_completions_request,
|
||||||
openai_chat_completions_process_stream,
|
openai_chat_completions_process_stream,
|
||||||
openai_chat_completions_request,
|
openai_chat_completions_request,
|
||||||
)
|
)
|
||||||
from letta.local_llm.chat_completion_proxy import get_chat_completion
|
from letta.local_llm.chat_completion_proxy import get_chat_completion
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||||
INNER_THOUGHTS_KWARG,
|
|
||||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
|
||||||
)
|
|
||||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_request import (
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
|
||||||
ChatCompletionRequest,
|
|
||||||
Tool,
|
|
||||||
cast_message_to_subtype,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||||
from letta.settings import ModelSettings
|
from letta.settings import ModelSettings
|
||||||
from letta.streaming_interface import (
|
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||||
AgentChunkStreamingInterface,
|
|
||||||
AgentRefreshStreamingInterface,
|
|
||||||
)
|
|
||||||
|
|
||||||
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"]
|
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"]
|
||||||
|
|
||||||
|
|||||||
@@ -9,28 +9,15 @@ from httpx_sse._exceptions import SSEError
|
|||||||
|
|
||||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||||
from letta.errors import LLMError
|
from letta.errors import LLMError
|
||||||
from letta.llm_api.helpers import (
|
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
|
||||||
add_inner_thoughts_to_functions,
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||||
convert_to_structured_output,
|
|
||||||
make_post_request,
|
|
||||||
)
|
|
||||||
from letta.local_llm.constants import (
|
|
||||||
INNER_THOUGHTS_KWARG,
|
|
||||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
|
||||||
)
|
|
||||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as _Message
|
from letta.schemas.message import Message as _Message
|
||||||
from letta.schemas.message import MessageRole as _MessageRole
|
from letta.schemas.message import MessageRole as _MessageRole
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||||
from letta.schemas.openai.chat_completion_request import (
|
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
||||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, cast_message_to_subtype
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_request import (
|
|
||||||
Tool,
|
|
||||||
ToolFunctionChoice,
|
|
||||||
cast_message_to_subtype,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
ChatCompletionChunkResponse,
|
ChatCompletionChunkResponse,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
@@ -41,10 +28,7 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
UsageStatistics,
|
UsageStatistics,
|
||||||
)
|
)
|
||||||
from letta.schemas.openai.embedding_response import EmbeddingResponse
|
from letta.schemas.openai.embedding_response import EmbeddingResponse
|
||||||
from letta.streaming_interface import (
|
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||||
AgentChunkStreamingInterface,
|
|
||||||
AgentRefreshStreamingInterface,
|
|
||||||
)
|
|
||||||
from letta.utils import get_tool_call_id, smart_urljoin
|
from letta.utils import get_tool_call_id, smart_urljoin
|
||||||
|
|
||||||
OPENAI_SSE_DONE = "[DONE]"
|
OPENAI_SSE_DONE = "[DONE]"
|
||||||
|
|||||||
@@ -8,10 +8,7 @@ from letta.constants import CLI_WARNING_PREFIX
|
|||||||
from letta.errors import LocalLLMConnectionError, LocalLLMError
|
from letta.errors import LocalLLMConnectionError, LocalLLMError
|
||||||
from letta.local_llm.constants import DEFAULT_WRAPPER
|
from letta.local_llm.constants import DEFAULT_WRAPPER
|
||||||
from letta.local_llm.function_parser import patch_function
|
from letta.local_llm.function_parser import patch_function
|
||||||
from letta.local_llm.grammars.gbnf_grammar_generator import (
|
from letta.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
|
||||||
create_dynamic_model_from_function,
|
|
||||||
generate_gbnf_grammar_and_documentation,
|
|
||||||
)
|
|
||||||
from letta.local_llm.koboldcpp.api import get_koboldcpp_completion
|
from letta.local_llm.koboldcpp.api import get_koboldcpp_completion
|
||||||
from letta.local_llm.llamacpp.api import get_llamacpp_completion
|
from letta.local_llm.llamacpp.api import get_llamacpp_completion
|
||||||
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
|
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
|
||||||
@@ -20,17 +17,9 @@ from letta.local_llm.ollama.api import get_ollama_completion
|
|||||||
from letta.local_llm.utils import count_tokens, get_available_wrappers
|
from letta.local_llm.utils import count_tokens, get_available_wrappers
|
||||||
from letta.local_llm.vllm.api import get_vllm_completion
|
from letta.local_llm.vllm.api import get_vllm_completion
|
||||||
from letta.local_llm.webui.api import get_webui_completion
|
from letta.local_llm.webui.api import get_webui_completion
|
||||||
from letta.local_llm.webui.legacy_api import (
|
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||||
get_webui_completion as get_webui_completion_legacy,
|
|
||||||
)
|
|
||||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics
|
||||||
ChatCompletionResponse,
|
|
||||||
Choice,
|
|
||||||
Message,
|
|
||||||
ToolCall,
|
|
||||||
UsageStatistics,
|
|
||||||
)
|
|
||||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||||
|
|
||||||
has_shown_warning = False
|
has_shown_warning = False
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
# import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
|
# import letta.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
|
||||||
from letta.local_llm.llm_chat_completion_wrappers.chatml import (
|
from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
|
||||||
ChatMLInnerMonologueWrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_ENDPOINTS = {
|
DEFAULT_ENDPOINTS = {
|
||||||
# Local
|
# Local
|
||||||
|
|||||||
@@ -5,18 +5,7 @@ from copy import copy
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import getdoc, isclass
|
from inspect import getdoc, isclass
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
from typing import (
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
_GenericAlias,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
)
|
|
||||||
|
|
||||||
from docstring_parser import parse
|
from docstring_parser import parse
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from letta.errors import LLMJSONParsingError
|
from letta.errors import LLMJSONParsingError
|
||||||
from letta.local_llm.json_parser import clean_json
|
from letta.local_llm.json_parser import clean_json
|
||||||
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import (
|
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
|
||||||
LLMChatCompletionWrapper,
|
|
||||||
)
|
|
||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.utils import json_dumps, json_loads
|
from letta.utils import json_dumps, json_loads
|
||||||
|
|
||||||
@@ -75,10 +73,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
|||||||
func_str += f"\n description: {schema['description']}"
|
func_str += f"\n description: {schema['description']}"
|
||||||
func_str += f"\n params:"
|
func_str += f"\n params:"
|
||||||
if add_inner_thoughts:
|
if add_inner_thoughts:
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||||
INNER_THOUGHTS_KWARG,
|
|
||||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
|
||||||
)
|
|
||||||
|
|
||||||
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
|
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
|
||||||
for param_k, param_v in schema["parameters"]["properties"].items():
|
for param_k, param_v in schema["parameters"]["properties"].items():
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from letta.errors import LLMJSONParsingError
|
from letta.errors import LLMJSONParsingError
|
||||||
from letta.local_llm.json_parser import clean_json
|
from letta.local_llm.json_parser import clean_json
|
||||||
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import (
|
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
|
||||||
LLMChatCompletionWrapper,
|
|
||||||
)
|
|
||||||
from letta.utils import json_dumps, json_loads
|
from letta.utils import json_dumps, json_loads
|
||||||
|
|
||||||
PREFIX_HINT = """# Reminders:
|
PREFIX_HINT = """# Reminders:
|
||||||
@@ -74,10 +72,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
|||||||
func_str += f"\n description: {schema['description']}"
|
func_str += f"\n description: {schema['description']}"
|
||||||
func_str += "\n params:"
|
func_str += "\n params:"
|
||||||
if add_inner_thoughts:
|
if add_inner_thoughts:
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||||
INNER_THOUGHTS_KWARG,
|
|
||||||
INNER_THOUGHTS_KWARG_DESCRIPTION,
|
|
||||||
)
|
|
||||||
|
|
||||||
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
|
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
|
||||||
for param_k, param_v in schema["parameters"]["properties"].items():
|
for param_k, param_v in schema["parameters"]["properties"].items():
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from letta.constants import LETTA_DIR
|
from letta.constants import LETTA_DIR
|
||||||
from letta.local_llm.settings.deterministic_mirostat import (
|
from letta.local_llm.settings.deterministic_mirostat import settings as det_miro_settings
|
||||||
settings as det_miro_settings,
|
|
||||||
)
|
|
||||||
from letta.local_llm.settings.simple import settings as simple_settings
|
from letta.local_llm.settings.simple import settings as simple_settings
|
||||||
|
|
||||||
DEFAULT = "simple"
|
DEFAULT = "simple"
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
|
|||||||
from letta.orm.job import Job
|
from letta.orm.job import Job
|
||||||
from letta.orm.message import Message
|
from letta.orm.message import Message
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage
|
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||||
from letta.orm.source import Source
|
from letta.orm.source import Source
|
||||||
from letta.orm.sources_agents import SourcesAgents
|
from letta.orm.sources_agents import SourcesAgents
|
||||||
|
|||||||
@@ -5,11 +5,7 @@ from sqlalchemy import JSON, String, UniqueConstraint
|
|||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.block import Block
|
from letta.orm.block import Block
|
||||||
from letta.orm.custom_columns import (
|
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
|
||||||
EmbeddingConfigColumn,
|
|
||||||
LLMConfigColumn,
|
|
||||||
ToolRulesColumn,
|
|
||||||
)
|
|
||||||
from letta.orm.message import Message
|
from letta.orm.message import Message
|
||||||
from letta.orm.mixins import OrganizationMixin
|
from letta.orm.mixins import OrganizationMixin
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
|
|||||||
@@ -2,13 +2,7 @@ from datetime import datetime
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import Boolean, DateTime, String, func, text
|
from sqlalchemy import Boolean, DateTime, String, func, text
|
||||||
from sqlalchemy.orm import (
|
from sqlalchemy.orm import DeclarativeBase, Mapped, declarative_mixin, declared_attr, mapped_column
|
||||||
DeclarativeBase,
|
|
||||||
Mapped,
|
|
||||||
declarative_mixin,
|
|
||||||
declared_attr,
|
|
||||||
mapped_column,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Optional, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import Integer, String
|
from sqlalchemy import Integer, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
@@ -9,8 +9,9 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.source import Source
|
|
||||||
from letta.orm.passage import SourcePassage
|
from letta.orm.passage import SourcePassage
|
||||||
|
from letta.orm.source import Source
|
||||||
|
|
||||||
|
|
||||||
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||||
"""Represents metadata for an uploaded file."""
|
"""Represents metadata for an uploaded file."""
|
||||||
@@ -28,4 +29,6 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
|||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||||
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
||||||
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||||
|
"SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class UserMixin(Base):
|
|||||||
|
|
||||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||||
|
|
||||||
|
|
||||||
class AgentMixin(Base):
|
class AgentMixin(Base):
|
||||||
"""Mixin for models that belong to an agent."""
|
"""Mixin for models that belong to an agent."""
|
||||||
|
|
||||||
@@ -38,6 +39,7 @@ class AgentMixin(Base):
|
|||||||
|
|
||||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
|
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
|
||||||
|
|
||||||
|
|
||||||
class FileMixin(Base):
|
class FileMixin(Base):
|
||||||
"""Mixin for models that belong to a file."""
|
"""Mixin for models that belong to a file."""
|
||||||
|
|
||||||
|
|||||||
@@ -38,19 +38,11 @@ class Organization(SqlalchemyBase):
|
|||||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||||
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||||
"SourcePassage",
|
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
|
||||||
back_populates="organization",
|
|
||||||
cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
agent_passages: Mapped[List["AgentPassage"]] = relationship(
|
|
||||||
"AgentPassage",
|
|
||||||
back_populates="organization",
|
|
||||||
cascade="all, delete-orphan"
|
|
||||||
)
|
)
|
||||||
|
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||||
"""Convenience property to get all passages"""
|
"""Convenience property to get all passages"""
|
||||||
return self.source_passages + self.agent_passages
|
return self.source_passages + self.agent_passages
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|||||||
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
|
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||||
from letta.schemas.sandbox_config import (
|
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
|
||||||
SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable,
|
|
||||||
)
|
|
||||||
from letta.schemas.sandbox_config import SandboxType
|
from letta.schemas.sandbox_config import SandboxType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -9,12 +9,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
|||||||
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||||
from letta.orm.errors import (
|
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||||
DatabaseTimeoutError,
|
|
||||||
ForeignKeyConstraintViolationError,
|
|
||||||
NoResultFound,
|
|
||||||
UniqueConstraintViolationError,
|
|
||||||
)
|
|
||||||
from letta.orm.sqlite_functions import adapt_array
|
from letta.orm.sqlite_functions import adapt_array
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
|
import base64
|
||||||
|
import sqlite3
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import base64
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
from letta.constants import MAX_EMBEDDING_DIM
|
from letta.constants import MAX_EMBEDDING_DIM
|
||||||
|
|
||||||
|
|
||||||
def adapt_array(arr):
|
def adapt_array(arr):
|
||||||
"""
|
"""
|
||||||
Converts numpy array to binary for SQLite storage
|
Converts numpy array to binary for SQLite storage
|
||||||
@@ -25,6 +26,7 @@ def adapt_array(arr):
|
|||||||
base64_data = base64.b64encode(bytes_data)
|
base64_data = base64.b64encode(bytes_data)
|
||||||
return sqlite3.Binary(base64_data)
|
return sqlite3.Binary(base64_data)
|
||||||
|
|
||||||
|
|
||||||
def convert_array(text):
|
def convert_array(text):
|
||||||
"""
|
"""
|
||||||
Converts binary back to numpy array
|
Converts binary back to numpy array
|
||||||
@@ -44,9 +46,10 @@ def convert_array(text):
|
|||||||
decoded_data = base64.b64decode(binary_data)
|
decoded_data = base64.b64decode(binary_data)
|
||||||
# Then convert to numpy array
|
# Then convert to numpy array
|
||||||
return np.frombuffer(decoded_data, dtype=np.float32)
|
return np.frombuffer(decoded_data, dtype=np.float32)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
|
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
|
||||||
"""
|
"""
|
||||||
Verifies that an embedding has the expected dimension
|
Verifies that an embedding has the expected dimension
|
||||||
@@ -62,10 +65,9 @@ def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EM
|
|||||||
return False
|
return False
|
||||||
return embedding.shape[0] == expected_dim
|
return embedding.shape[0] == expected_dim
|
||||||
|
|
||||||
|
|
||||||
def validate_and_transform_embedding(
|
def validate_and_transform_embedding(
|
||||||
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
|
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], expected_dim: int = MAX_EMBEDDING_DIM, dtype: np.dtype = np.float32
|
||||||
expected_dim: int = MAX_EMBEDDING_DIM,
|
|
||||||
dtype: np.dtype = np.float32
|
|
||||||
) -> Optional[np.ndarray]:
|
) -> Optional[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Validates and transforms embeddings to ensure correct dimensionality.
|
Validates and transforms embeddings to ensure correct dimensionality.
|
||||||
@@ -96,12 +98,11 @@ def validate_and_transform_embedding(
|
|||||||
|
|
||||||
# Validate dimension
|
# Validate dimension
|
||||||
if vec.shape[0] != expected_dim:
|
if vec.shape[0] != expected_dim:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}")
|
||||||
f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return vec
|
return vec
|
||||||
|
|
||||||
|
|
||||||
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
|
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
|
||||||
"""
|
"""
|
||||||
Calculate cosine distance between two embeddings
|
Calculate cosine distance between two embeddings
|
||||||
@@ -121,7 +122,7 @@ def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
|
|||||||
try:
|
try:
|
||||||
vec1 = validate_and_transform_embedding(embedding1, expected_dim)
|
vec1 = validate_and_transform_embedding(embedding1, expected_dim)
|
||||||
vec2 = validate_and_transform_embedding(embedding2, expected_dim)
|
vec2 = validate_and_transform_embedding(embedding2, expected_dim)
|
||||||
except ValueError as e:
|
except ValueError:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||||
@@ -129,12 +130,14 @@ def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
|
|||||||
|
|
||||||
return distance
|
return distance
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(Engine, "connect")
|
@event.listens_for(Engine, "connect")
|
||||||
def register_functions(dbapi_connection, connection_record):
|
def register_functions(dbapi_connection, connection_record):
|
||||||
"""Register SQLite functions"""
|
"""Register SQLite functions"""
|
||||||
if isinstance(dbapi_connection, sqlite3.Connection):
|
if isinstance(dbapi_connection, sqlite3.Connection):
|
||||||
dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
|
dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
|
||||||
|
|
||||||
|
|
||||||
# Register adapters and converters for numpy arrays
|
# Register adapters and converters for numpy arrays
|
||||||
sqlite3.register_adapter(np.ndarray, adapt_array)
|
sqlite3.register_adapter(np.ndarray, adapt_array)
|
||||||
sqlite3.register_converter("ARRAY", convert_array)
|
sqlite3.register_converter("ARRAY", convert_array)
|
||||||
|
|||||||
@@ -3,10 +3,7 @@ from typing import List, Optional
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||||
from letta.llm_api.azure_openai import (
|
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
||||||
get_azure_chat_completions_endpoint,
|
|
||||||
get_azure_embeddings_endpoint,
|
|
||||||
)
|
|
||||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
@@ -32,7 +29,6 @@ class Provider(BaseModel):
|
|||||||
return f"{self.name}/{model_name}"
|
return f"{self.name}/{model_name}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LettaProvider(Provider):
|
class LettaProvider(Provider):
|
||||||
|
|
||||||
name: str = "letta"
|
name: str = "letta"
|
||||||
@@ -44,7 +40,7 @@ class LettaProvider(Provider):
|
|||||||
model_endpoint_type="openai",
|
model_endpoint_type="openai",
|
||||||
model_endpoint="https://inference.memgpt.ai",
|
model_endpoint="https://inference.memgpt.ai",
|
||||||
context_window=16384,
|
context_window=16384,
|
||||||
handle=self.get_handle("letta-free")
|
handle=self.get_handle("letta-free"),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -56,7 +52,7 @@ class LettaProvider(Provider):
|
|||||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||||
embedding_dim=1024,
|
embedding_dim=1024,
|
||||||
embedding_chunk_size=300,
|
embedding_chunk_size=300,
|
||||||
handle=self.get_handle("letta-free")
|
handle=self.get_handle("letta-free"),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -121,7 +117,13 @@ class OpenAIProvider(Provider):
|
|||||||
# continue
|
# continue
|
||||||
|
|
||||||
configs.append(
|
configs.append(
|
||||||
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
|
LLMConfig(
|
||||||
|
model=model_name,
|
||||||
|
model_endpoint_type="openai",
|
||||||
|
model_endpoint=self.base_url,
|
||||||
|
context_window=context_window_size,
|
||||||
|
handle=self.get_handle(model_name),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# for OpenAI, sort in reverse order
|
# for OpenAI, sort in reverse order
|
||||||
@@ -141,7 +143,7 @@ class OpenAIProvider(Provider):
|
|||||||
embedding_endpoint="https://api.openai.com/v1",
|
embedding_endpoint="https://api.openai.com/v1",
|
||||||
embedding_dim=1536,
|
embedding_dim=1536,
|
||||||
embedding_chunk_size=300,
|
embedding_chunk_size=300,
|
||||||
handle=self.get_handle("text-embedding-ada-002")
|
handle=self.get_handle("text-embedding-ada-002"),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -170,7 +172,7 @@ class AnthropicProvider(Provider):
|
|||||||
model_endpoint_type="anthropic",
|
model_endpoint_type="anthropic",
|
||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
context_window=model["context_window"],
|
context_window=model["context_window"],
|
||||||
handle=self.get_handle(model["name"])
|
handle=self.get_handle(model["name"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -203,7 +205,7 @@ class MistralProvider(Provider):
|
|||||||
model_endpoint_type="openai",
|
model_endpoint_type="openai",
|
||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
context_window=model["max_context_length"],
|
context_window=model["max_context_length"],
|
||||||
handle=self.get_handle(model["id"])
|
handle=self.get_handle(model["id"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -259,7 +261,7 @@ class OllamaProvider(OpenAIProvider):
|
|||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
model_wrapper=self.default_prompt_formatter,
|
model_wrapper=self.default_prompt_formatter,
|
||||||
context_window=context_window,
|
context_window=context_window,
|
||||||
handle=self.get_handle(model["name"])
|
handle=self.get_handle(model["name"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -335,7 +337,7 @@ class OllamaProvider(OpenAIProvider):
|
|||||||
embedding_endpoint=self.base_url,
|
embedding_endpoint=self.base_url,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
embedding_chunk_size=300,
|
embedding_chunk_size=300,
|
||||||
handle=self.get_handle(model["name"])
|
handle=self.get_handle(model["name"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -356,7 +358,11 @@ class GroqProvider(OpenAIProvider):
|
|||||||
continue
|
continue
|
||||||
configs.append(
|
configs.append(
|
||||||
LLMConfig(
|
LLMConfig(
|
||||||
model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
|
model=model["id"],
|
||||||
|
model_endpoint_type="groq",
|
||||||
|
model_endpoint=self.base_url,
|
||||||
|
context_window=model["context_window"],
|
||||||
|
handle=self.get_handle(model["id"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -424,7 +430,7 @@ class TogetherProvider(OpenAIProvider):
|
|||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
model_wrapper=self.default_prompt_formatter,
|
model_wrapper=self.default_prompt_formatter,
|
||||||
context_window=context_window_size,
|
context_window=context_window_size,
|
||||||
handle=self.get_handle(model_name)
|
handle=self.get_handle(model_name),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -505,7 +511,7 @@ class GoogleAIProvider(Provider):
|
|||||||
model_endpoint_type="google_ai",
|
model_endpoint_type="google_ai",
|
||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
context_window=self.get_model_context_window(model),
|
context_window=self.get_model_context_window(model),
|
||||||
handle=self.get_handle(model)
|
handle=self.get_handle(model),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -529,7 +535,7 @@ class GoogleAIProvider(Provider):
|
|||||||
embedding_endpoint=self.base_url,
|
embedding_endpoint=self.base_url,
|
||||||
embedding_dim=768,
|
embedding_dim=768,
|
||||||
embedding_chunk_size=300, # NOTE: max is 2048
|
embedding_chunk_size=300, # NOTE: max is 2048
|
||||||
handle=self.get_handle(model)
|
handle=self.get_handle(model),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -559,9 +565,7 @@ class AzureProvider(Provider):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def list_llm_models(self) -> List[LLMConfig]:
|
def list_llm_models(self) -> List[LLMConfig]:
|
||||||
from letta.llm_api.azure_openai import (
|
from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list
|
||||||
azure_openai_get_chat_completion_model_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
|
model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
|
||||||
configs = []
|
configs = []
|
||||||
@@ -570,7 +574,8 @@ class AzureProvider(Provider):
|
|||||||
context_window_size = self.get_model_context_window(model_name)
|
context_window_size = self.get_model_context_window(model_name)
|
||||||
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
|
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
|
||||||
configs.append(
|
configs.append(
|
||||||
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size), handle=self.get_handle(model_name)
|
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size),
|
||||||
|
handle=self.get_handle(model_name),
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
@@ -591,7 +596,7 @@ class AzureProvider(Provider):
|
|||||||
embedding_endpoint=model_endpoint,
|
embedding_endpoint=model_endpoint,
|
||||||
embedding_dim=768,
|
embedding_dim=768,
|
||||||
embedding_chunk_size=300, # NOTE: max is 2048
|
embedding_chunk_size=300, # NOTE: max is 2048
|
||||||
handle=self.get_handle(model_name)
|
handle=self.get_handle(model_name),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -625,7 +630,7 @@ class VLLMChatCompletionsProvider(Provider):
|
|||||||
model_endpoint_type="openai",
|
model_endpoint_type="openai",
|
||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
context_window=model["max_model_len"],
|
context_window=model["max_model_len"],
|
||||||
handle=self.get_handle(model["id"])
|
handle=self.get_handle(model["id"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
@@ -658,7 +663,7 @@ class VLLMCompletionsProvider(Provider):
|
|||||||
model_endpoint=self.base_url,
|
model_endpoint=self.base_url,
|
||||||
model_wrapper=self.default_prompt_formatter,
|
model_wrapper=self.default_prompt_formatter,
|
||||||
context_window=model["max_model_len"],
|
context_window=model["max_model_len"],
|
||||||
handle=self.get_handle(model["id"])
|
handle=self.get_handle(model["id"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
|
|||||||
@@ -119,6 +119,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|||||||
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
||||||
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
|
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
|
||||||
from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
|
from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
|
||||||
|
project_id: Optional[str] = Field(None, description="The project id that the agent will be associated with.")
|
||||||
|
|
||||||
@field_validator("name")
|
@field_validator("name")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -23,8 +23,26 @@ class LettaResponse(BaseModel):
|
|||||||
usage (LettaUsageStatistics): The usage statistics
|
usage (LettaUsageStatistics): The usage statistics
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages: List[LettaMessageUnion] = Field(..., description="The messages returned by the agent.")
|
messages: List[LettaMessageUnion] = Field(
|
||||||
usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.")
|
...,
|
||||||
|
description="The messages returned by the agent.",
|
||||||
|
json_schema_extra={
|
||||||
|
"items": {
|
||||||
|
"oneOf": [
|
||||||
|
{"x-ref-name": "SystemMessage"},
|
||||||
|
{"x-ref-name": "UserMessage"},
|
||||||
|
{"x-ref-name": "ReasoningMessage"},
|
||||||
|
{"x-ref-name": "ToolCallMessage"},
|
||||||
|
{"x-ref-name": "ToolReturnMessage"},
|
||||||
|
{"x-ref-name": "AssistantMessage"},
|
||||||
|
],
|
||||||
|
"discriminator": {"propertyName": "message_type"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
usage: LettaUsageStatistics = Field(
|
||||||
|
..., description="The usage statistics of the agent.", json_schema_extra={"x-ref-name": "LettaUsageStatistics"}
|
||||||
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return json_dumps(
|
return json_dumps(
|
||||||
|
|||||||
@@ -6,24 +6,13 @@ from typing import List, Literal, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from letta.constants import (
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||||
DEFAULT_MESSAGE_TOOL,
|
|
||||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
|
||||||
TOOL_CALL_ID_MAX_LEN,
|
|
||||||
)
|
|
||||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.letta_base import OrmMetadataBase
|
from letta.schemas.letta_base import OrmMetadataBase
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, SystemMessage
|
||||||
AssistantMessage,
|
from letta.schemas.letta_message import ToolCall as LettaToolCall
|
||||||
ToolCall as LettaToolCall,
|
from letta.schemas.letta_message import ToolCallMessage, ToolReturnMessage, UserMessage
|
||||||
ToolCallMessage,
|
|
||||||
ToolReturnMessage,
|
|
||||||
ReasoningMessage,
|
|
||||||
LettaMessage,
|
|
||||||
SystemMessage,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||||
from letta.utils import get_utc_time, is_utc_datetime, json_dumps
|
from letta.utils import get_utc_time, is_utc_datetime, json_dumps
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class OrganizationBase(LettaBase):
|
|||||||
|
|
||||||
class Organization(OrganizationBase):
|
class Organization(OrganizationBase):
|
||||||
id: str = OrganizationBase.generate_id_field()
|
id: str = OrganizationBase.generate_id_field()
|
||||||
name: str = Field(create_random_username(), description="The name of the organization.")
|
name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
|
||||||
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
|
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,7 @@ from pydantic import Field, model_validator
|
|||||||
|
|
||||||
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
||||||
from letta.functions.functions import derive_openai_json_schema
|
from letta.functions.functions import derive_openai_json_schema
|
||||||
from letta.functions.helpers import (
|
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
|
||||||
generate_composio_tool_wrapper,
|
|
||||||
generate_langchain_tool_wrapper,
|
|
||||||
)
|
|
||||||
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
|
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
|
||||||
from letta.schemas.letta_base import LettaBase
|
from letta.schemas.letta_base import LettaBase
|
||||||
from letta.schemas.openai.chat_completions import ToolCall
|
from letta.schemas.openai.chat_completions import ToolCall
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class ConditionalToolRule(BaseToolRule):
|
|||||||
"""
|
"""
|
||||||
A ToolRule that conditionally maps to different child tools based on the output.
|
A ToolRule that conditionally maps to different child tools based on the output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ToolRuleType = ToolRuleType.conditional
|
type: ToolRuleType = ToolRuleType.conditional
|
||||||
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
|
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
|
||||||
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
|
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ class LettaUsageStatistics(BaseModel):
|
|||||||
total_tokens (int): The total number of tokens processed by the agent.
|
total_tokens (int): The total number of tokens processed by the agent.
|
||||||
step_count (int): The number of steps taken by the agent.
|
step_count (int): The number of steps taken by the agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
message_type: Literal["usage_statistics"] = "usage_statistics"
|
message_type: Literal["usage_statistics"] = "usage_statistics"
|
||||||
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
|
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
|
||||||
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
|
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
|
||||||
|
|||||||
@@ -15,35 +15,19 @@ from letta.__init__ import __version__
|
|||||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||||
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
|
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.errors import (
|
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||||
DatabaseTimeoutError,
|
|
||||||
ForeignKeyConstraintViolationError,
|
|
||||||
NoResultFound,
|
|
||||||
UniqueConstraintViolationError,
|
|
||||||
)
|
|
||||||
from letta.schemas.letta_response import LettaResponse
|
|
||||||
from letta.server.constants import REST_DEFAULT_PORT
|
from letta.server.constants import REST_DEFAULT_PORT
|
||||||
|
|
||||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||||
from letta.server.rest_api.auth.index import (
|
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||||
setup_auth_router, # TODO: probably remove right?
|
|
||||||
)
|
|
||||||
from letta.server.rest_api.interface import StreamingServerInterface
|
from letta.server.rest_api.interface import StreamingServerInterface
|
||||||
from letta.server.rest_api.routers.openai.assistants.assistants import (
|
from letta.server.rest_api.routers.openai.assistants.assistants import router as openai_assistants_router
|
||||||
router as openai_assistants_router,
|
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
|
||||||
)
|
|
||||||
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import (
|
|
||||||
router as openai_chat_completions_router,
|
|
||||||
)
|
|
||||||
|
|
||||||
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
||||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||||
from letta.server.rest_api.routers.v1.organizations import (
|
from letta.server.rest_api.routers.v1.organizations import router as organizations_router
|
||||||
router as organizations_router,
|
from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin
|
||||||
)
|
|
||||||
from letta.server.rest_api.routers.v1.users import (
|
|
||||||
router as users_router, # TODO: decide on admin
|
|
||||||
)
|
|
||||||
from letta.server.rest_api.static_files import mount_static_files
|
from letta.server.rest_api.static_files import mount_static_files
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
@@ -83,9 +67,6 @@ def generate_openapi_schema(app: FastAPI):
|
|||||||
openai_docs["info"]["title"] = "OpenAI Assistants API"
|
openai_docs["info"]["title"] = "OpenAI Assistants API"
|
||||||
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
|
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
|
||||||
letta_docs["info"]["title"] = "Letta API"
|
letta_docs["info"]["title"] = "Letta API"
|
||||||
letta_docs["components"]["schemas"]["LettaResponse"] = {
|
|
||||||
"properties": LettaResponse.model_json_schema(ref_template="#/components/schemas/LettaResponse/properties/{model}")["$defs"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Split the API docs into Letta API, and OpenAI Assistants compatible API
|
# Split the API docs into Letta API, and OpenAI Assistants compatible API
|
||||||
for name, docs in [
|
for name, docs in [
|
||||||
|
|||||||
@@ -12,22 +12,19 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
|||||||
from letta.schemas.enums import MessageStreamStatus
|
from letta.schemas.enums import MessageStreamStatus
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
|
LegacyFunctionCallMessage,
|
||||||
|
LegacyLettaMessage,
|
||||||
|
LettaMessage,
|
||||||
|
ReasoningMessage,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolReturnMessage,
|
ToolReturnMessage,
|
||||||
ReasoningMessage,
|
|
||||||
LegacyFunctionCallMessage,
|
|
||||||
LegacyLettaMessage,
|
|
||||||
LettaMessage,
|
|
||||||
)
|
)
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||||
from letta.streaming_utils import (
|
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
|
||||||
FunctionArgumentsStreamHandler,
|
|
||||||
JSONInnerThoughtsExtractor,
|
|
||||||
)
|
|
||||||
from letta.utils import is_utc_datetime
|
from letta.utils import is_utc_datetime
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from letta.schemas.openai.openai import (
|
from letta.schemas.openai.openai import MessageRoleType, OpenAIMessage, OpenAIThread, ToolCall, ToolCallOutput
|
||||||
MessageRoleType,
|
|
||||||
OpenAIMessage,
|
|
||||||
OpenAIThread,
|
|
||||||
ToolCall,
|
|
||||||
ToolCallOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CreateAssistantRequest(BaseModel):
|
class CreateAssistantRequest(BaseModel):
|
||||||
|
|||||||
@@ -4,14 +4,9 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||||
|
|
||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.letta_message import ToolCall, LettaMessage
|
from letta.schemas.letta_message import LettaMessage, ToolCall
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, UsageStatistics
|
||||||
ChatCompletionResponse,
|
|
||||||
Choice,
|
|
||||||
Message,
|
|
||||||
UsageStatistics,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO this belongs in a controller!
|
# TODO this belongs in a controller!
|
||||||
from letta.server.rest_api.routers.v1.agents import send_message_to_agent
|
from letta.server.rest_api.routers.v1.agents import send_message_to_agent
|
||||||
|
|||||||
@@ -3,9 +3,7 @@ from letta.server.rest_api.routers.v1.blocks import router as blocks_router
|
|||||||
from letta.server.rest_api.routers.v1.health import router as health_router
|
from letta.server.rest_api.routers.v1.health import router as health_router
|
||||||
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
||||||
from letta.server.rest_api.routers.v1.llms import router as llm_router
|
from letta.server.rest_api.routers.v1.llms import router as llm_router
|
||||||
from letta.server.rest_api.routers.v1.sandbox_configs import (
|
from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
|
||||||
router as sandbox_configs_router,
|
|
||||||
)
|
|
||||||
from letta.server.rest_api.routers.v1.sources import router as sources_router
|
from letta.server.rest_api.routers.v1.sources import router as sources_router
|
||||||
from letta.server.rest_api.routers.v1.tools import router as tools_router
|
from letta.server.rest_api.routers.v1.tools import router as tools_router
|
||||||
|
|
||||||
|
|||||||
@@ -3,16 +3,7 @@ import warnings
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import APIRouter, BackgroundTasks, Body, Depends, Header, HTTPException, Query, status
|
||||||
APIRouter,
|
|
||||||
BackgroundTasks,
|
|
||||||
Body,
|
|
||||||
Depends,
|
|
||||||
Header,
|
|
||||||
HTTPException,
|
|
||||||
Query,
|
|
||||||
status,
|
|
||||||
)
|
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@@ -20,27 +11,13 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||||
from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate
|
from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate
|
||||||
Block,
|
|
||||||
BlockUpdate,
|
|
||||||
CreateBlock,
|
|
||||||
)
|
|
||||||
from letta.schemas.enums import MessageStreamStatus
|
from letta.schemas.enums import MessageStreamStatus
|
||||||
from letta.schemas.job import Job, JobStatus, JobUpdate
|
from letta.schemas.job import Job, JobStatus, JobUpdate
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, LettaMessageUnion
|
||||||
LegacyLettaMessage,
|
|
||||||
LettaMessage,
|
|
||||||
LettaMessageUnion,
|
|
||||||
)
|
|
||||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.memory import (
|
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, CreateArchivalMemory, Memory, RecallMemorySummary
|
||||||
ArchivalMemorySummary,
|
|
||||||
ContextWindowOverview,
|
|
||||||
CreateArchivalMemory,
|
|
||||||
Memory,
|
|
||||||
RecallMemorySummary,
|
|
||||||
)
|
|
||||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||||
from letta.schemas.passage import Passage
|
from letta.schemas.passage import Passage
|
||||||
from letta.schemas.source import Source
|
from letta.schemas.source import Source
|
||||||
@@ -193,7 +170,7 @@ def get_agent_state(
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent")
|
@router.delete("/{agent_id}", response_model=None, operation_id="delete_agent")
|
||||||
def delete_agent(
|
def delete_agent(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
server: "SyncServer" = Depends(get_letta_server),
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
@@ -204,7 +181,8 @@ def delete_agent(
|
|||||||
"""
|
"""
|
||||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||||
try:
|
try:
|
||||||
return server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
|
server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
|
||||||
|
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"})
|
||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
|
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
|
||||||
|
|
||||||
@@ -343,7 +321,12 @@ def update_agent_memory_block(
|
|||||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||||
|
|
||||||
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||||
return server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
|
block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
|
||||||
|
|
||||||
|
# This should also trigger a system prompt change in the agent
|
||||||
|
server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False)
|
||||||
|
|
||||||
|
return block
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
|
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
|
||||||
|
|||||||
@@ -5,11 +5,7 @@ from fastapi import APIRouter, Depends, Query
|
|||||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||||
from letta.schemas.sandbox_config import (
|
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
|
||||||
SandboxEnvironmentVariableCreate,
|
|
||||||
SandboxEnvironmentVariableUpdate,
|
|
||||||
SandboxType,
|
|
||||||
)
|
|
||||||
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, UploadFile
|
||||||
APIRouter,
|
|
||||||
BackgroundTasks,
|
|
||||||
Depends,
|
|
||||||
Header,
|
|
||||||
HTTPException,
|
|
||||||
Query,
|
|
||||||
UploadFile,
|
|
||||||
)
|
|
||||||
|
|
||||||
from letta.schemas.file import FileMetadata
|
from letta.schemas.file import FileMetadata
|
||||||
from letta.schemas.job import Job
|
from letta.schemas.job import Job
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
|
|||||||
def get_current_interface() -> StreamingServerInterface:
|
def get_current_interface() -> StreamingServerInterface:
|
||||||
return StreamingServerInterface
|
return StreamingServerInterface
|
||||||
|
|
||||||
|
|
||||||
def log_error_to_sentry(e):
|
def log_error_to_sentry(e):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|||||||
@@ -49,15 +49,11 @@ from letta.schemas.enums import JobStatus
|
|||||||
from letta.schemas.job import Job, JobUpdate
|
from letta.schemas.job import Job, JobUpdate
|
||||||
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
|
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.memory import (
|
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
|
||||||
ArchivalMemorySummary,
|
|
||||||
ContextWindowOverview,
|
|
||||||
Memory,
|
|
||||||
RecallMemorySummary,
|
|
||||||
)
|
|
||||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||||
from letta.schemas.organization import Organization
|
from letta.schemas.organization import Organization
|
||||||
from letta.schemas.passage import Passage
|
from letta.schemas.passage import Passage
|
||||||
|
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
|
||||||
from letta.schemas.source import Source
|
from letta.schemas.source import Source
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
@@ -303,6 +299,17 @@ class SyncServer(Server):
|
|||||||
self.block_manager.add_default_blocks(actor=self.default_user)
|
self.block_manager.add_default_blocks(actor=self.default_user)
|
||||||
self.tool_manager.upsert_base_tools(actor=self.default_user)
|
self.tool_manager.upsert_base_tools(actor=self.default_user)
|
||||||
|
|
||||||
|
# Add composio keys to the tool sandbox env vars of the org
|
||||||
|
if tool_settings.composio_api_key:
|
||||||
|
manager = SandboxConfigManager(tool_settings)
|
||||||
|
sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.default_user)
|
||||||
|
|
||||||
|
manager.create_sandbox_env_var(
|
||||||
|
SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=tool_settings.composio_api_key),
|
||||||
|
sandbox_config_id=sandbox_config.id,
|
||||||
|
actor=self.default_user,
|
||||||
|
)
|
||||||
|
|
||||||
# collect providers (always has Letta as a default)
|
# collect providers (always has Letta as a default)
|
||||||
self._enabled_providers: List[Provider] = [LettaProvider()]
|
self._enabled_providers: List[Provider] = [LettaProvider()]
|
||||||
if model_settings.openai_api_key:
|
if model_settings.openai_api_key:
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class AgentManager:
|
|||||||
return agent.to_pydantic()
|
return agent.to_pydantic()
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
|
||||||
"""
|
"""
|
||||||
Deletes an agent and its associated relationships.
|
Deletes an agent and its associated relationships.
|
||||||
Ensures proper permission checks and cascades where applicable.
|
Ensures proper permission checks and cascades where applicable.
|
||||||
@@ -288,15 +288,13 @@ class AgentManager:
|
|||||||
agent_id: ID of the agent to be deleted.
|
agent_id: ID of the agent to be deleted.
|
||||||
actor: User performing the action.
|
actor: User performing the action.
|
||||||
|
|
||||||
Returns:
|
Raises:
|
||||||
PydanticAgentState: The deleted agent state
|
NoResultFound: If agent doesn't exist
|
||||||
"""
|
"""
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
# Retrieve the agent
|
# Retrieve the agent
|
||||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||||
agent_state = agent.to_pydantic()
|
|
||||||
agent.hard_delete(session)
|
agent.hard_delete(session)
|
||||||
return agent_state
|
|
||||||
|
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# In Context Messages Management
|
# In Context Messages Management
|
||||||
|
|||||||
@@ -1,21 +1,15 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import numpy as np
|
from typing import List, Optional
|
||||||
|
|
||||||
from sqlalchemy import select, union_all, literal
|
|
||||||
|
|
||||||
from letta.constants import MAX_EMBEDDING_DIM
|
|
||||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.passage import AgentPassage, SourcePassage
|
from letta.orm.passage import AgentPassage, SourcePassage
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
|
||||||
from letta.schemas.passage import Passage as PydanticPassage
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
"""Manager class to handle business logic related to Passages."""
|
"""Manager class to handle business logic related to Passages."""
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,7 @@ from letta.schemas.sandbox_config import LocalSandboxConfig
|
|||||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||||
from letta.schemas.sandbox_config import (
|
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
|
||||||
SandboxEnvironmentVariableCreate,
|
|
||||||
SandboxEnvironmentVariableUpdate,
|
|
||||||
SandboxType,
|
|
||||||
)
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.utils import enforce_types, printd
|
from letta.utils import enforce_types, printd
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ class ToolExecutionSandbox:
|
|||||||
if local_configs.use_venv:
|
if local_configs.use_venv:
|
||||||
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
|
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
|
||||||
else:
|
else:
|
||||||
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path)
|
return self.run_local_dir_sandbox_runpy(sbx_config, env, temp_file_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
||||||
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
|
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
|
||||||
@@ -200,7 +200,7 @@ class ToolExecutionSandbox:
|
|||||||
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
||||||
status = "success"
|
status = "success"
|
||||||
agent_state, stderr = None, None
|
agent_state, stderr = None, None
|
||||||
|
|
||||||
@@ -213,8 +213,8 @@ class ToolExecutionSandbox:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Execute the temp file
|
# Execute the temp file
|
||||||
with self.temporary_env_vars(env_vars):
|
with self.temporary_env_vars(env):
|
||||||
result = runpy.run_path(temp_file_path, init_globals=env_vars)
|
result = runpy.run_path(temp_file_path, init_globals=env)
|
||||||
|
|
||||||
# Fetch the result
|
# Fetch the result
|
||||||
func_result = result.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME)
|
func_result = result.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME)
|
||||||
@@ -277,6 +277,10 @@ class ToolExecutionSandbox:
|
|||||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
||||||
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
|
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
|
||||||
if not sbx or self.force_recreate:
|
if not sbx or self.force_recreate:
|
||||||
|
if not sbx:
|
||||||
|
logger.info(f"No running e2b sandbox found with the same state: {sbx_config}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Force recreated e2b sandbox with state: {sbx_config}")
|
||||||
sbx = self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
|
sbx = self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
|
||||||
|
|
||||||
# Since this sandbox was used, we extend its lifecycle by the timeout
|
# Since this sandbox was used, we extend its lifecycle by the timeout
|
||||||
@@ -292,6 +296,8 @@ class ToolExecutionSandbox:
|
|||||||
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
|
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
|
||||||
elif execution.error:
|
elif execution.error:
|
||||||
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
|
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
|
||||||
|
logger.error(f"E2B Sandbox configurations: {sbx_config}")
|
||||||
|
logger.error(f"E2B Sandbox ID: {sbx.sandbox_id}")
|
||||||
func_return = get_friendly_error_msg(
|
func_return = get_friendly_error_msg(
|
||||||
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
|
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -60,7 +60,13 @@ class ModelSettings(BaseSettings):
|
|||||||
openllm_api_key: Optional[str] = None
|
openllm_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
cors_origins = ["http://letta.localhost", "http://localhost:8283", "http://localhost:8083", "http://localhost:3000"]
|
cors_origins = [
|
||||||
|
"http://letta.localhost",
|
||||||
|
"http://localhost:8283",
|
||||||
|
"http://localhost:8083",
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:4200",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
|||||||
@@ -9,15 +9,9 @@ from rich.live import Live
|
|||||||
from rich.markup import escape
|
from rich.markup import escape
|
||||||
|
|
||||||
from letta.interface import CLIInterface
|
from letta.interface import CLIInterface
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
|
||||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
|
||||||
INNER_THOUGHTS_CLI_SYMBOL,
|
|
||||||
)
|
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse
|
||||||
ChatCompletionChunkResponse,
|
|
||||||
ChatCompletionResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
# init(autoreset=True)
|
# init(autoreset=True)
|
||||||
|
|
||||||
|
|||||||
@@ -1120,6 +1120,7 @@ def sanitize_filename(filename: str) -> str:
|
|||||||
# Return the sanitized filename
|
# Return the sanitized filename
|
||||||
return sanitized_filename
|
return sanitized_filename
|
||||||
|
|
||||||
|
|
||||||
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
|
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
|
||||||
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
|
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
|
||||||
|
|
||||||
|
|||||||
14
poetry.lock
generated
14
poetry.lock
generated
@@ -726,13 +726,13 @@ test = ["pytest"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "composio-core"
|
name = "composio-core"
|
||||||
version = "0.6.3"
|
version = "0.6.7"
|
||||||
description = "Core package to act as a bridge between composio platform and other services."
|
description = "Core package to act as a bridge between composio platform and other services."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "<4,>=3.9"
|
python-versions = "<4,>=3.9"
|
||||||
files = [
|
files = [
|
||||||
{file = "composio_core-0.6.3-py3-none-any.whl", hash = "sha256:981a9856781b791242f947a9685a18974d8a012ac7fab2c09438e1b19610d6a2"},
|
{file = "composio_core-0.6.7-py3-none-any.whl", hash = "sha256:03cedeffe417b919d1021c1bc4751f54bd05829b52ff3285f7984e14bdf91efe"},
|
||||||
{file = "composio_core-0.6.3.tar.gz", hash = "sha256:13098b20d8832e74453ca194889305c935432156fc07be91dfddf76561ad591b"},
|
{file = "composio_core-0.6.7.tar.gz", hash = "sha256:b87f0b804d87945b4eae556468b9efc75f751d256bbf2c20fb8ae5b6a31a2818"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -762,13 +762,13 @@ tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "tra
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "composio-langchain"
|
name = "composio-langchain"
|
||||||
version = "0.6.3"
|
version = "0.6.7"
|
||||||
description = "Use Composio to get an array of tools with your LangChain agent."
|
description = "Use Composio to get an array of tools with your LangChain agent."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "<4,>=3.9"
|
python-versions = "<4,>=3.9"
|
||||||
files = [
|
files = [
|
||||||
{file = "composio_langchain-0.6.3-py3-none-any.whl", hash = "sha256:0e749a1603dc0562293412d0a6429f88b75152b01a313cca859732070d762a6b"},
|
{file = "composio_langchain-0.6.7-py3-none-any.whl", hash = "sha256:f8653b6a7e6b03a61b679a096e278744d3009ebaf3741d7e24e5120a364f212e"},
|
||||||
{file = "composio_langchain-0.6.3.tar.gz", hash = "sha256:2036f94bfe60974b31f2be0bfdb33dd75a1d43435f275141219b3376587bf49d"},
|
{file = "composio_langchain-0.6.7.tar.gz", hash = "sha256:adeab3a87b0e6eb7e96048cef6b988dbe699b6a493a82fac2d371ab940e7e54e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -6246,4 +6246,4 @@ tests = ["wikipedia"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "<4.0,>=3.10"
|
python-versions = "<4.0,>=3.10"
|
||||||
content-hash = "4a7cf176579d5dc15648979542da152ec98290f1e9f39039cfe9baf73bc1076f"
|
content-hash = "1c52219049a4470dd54a45318b22495a4cafa29e93a1c5369a0d54da71990adb"
|
||||||
|
|||||||
82
project.json
Normal file
82
project.json
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
{
|
||||||
|
"name": "core",
|
||||||
|
"$schema": "../../node_modules/nx/schemas/project-schema.json",
|
||||||
|
"projectType": "application",
|
||||||
|
"sourceRoot": "apps/core",
|
||||||
|
"targets": {
|
||||||
|
"lock": {
|
||||||
|
"executor": "@nxlv/python:run-commands",
|
||||||
|
"options": {
|
||||||
|
"command": "poetry lock --no-update",
|
||||||
|
"cwd": "apps/core"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"add": {
|
||||||
|
"executor": "@nxlv/python:add",
|
||||||
|
"options": {}
|
||||||
|
},
|
||||||
|
"update": {
|
||||||
|
"executor": "@nxlv/python:update",
|
||||||
|
"options": {}
|
||||||
|
},
|
||||||
|
"remove": {
|
||||||
|
"executor": "@nxlv/python:remove",
|
||||||
|
"options": {}
|
||||||
|
},
|
||||||
|
"dev": {
|
||||||
|
"executor": "@nxlv/python:run-commands",
|
||||||
|
"options": {
|
||||||
|
"command": "poetry run letta server",
|
||||||
|
"cwd": "apps/core"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"build": {
|
||||||
|
"executor": "@nxlv/python:build",
|
||||||
|
"outputs": ["{projectRoot}/dist"],
|
||||||
|
"options": {
|
||||||
|
"outputPath": "apps/core/dist",
|
||||||
|
"publish": false,
|
||||||
|
"lockedVersions": true,
|
||||||
|
"bundleLocalDependencies": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"install": {
|
||||||
|
"executor": "@nxlv/python:run-commands",
|
||||||
|
"options": {
|
||||||
|
"command": "poetry install --all-extras",
|
||||||
|
"cwd": "apps/core"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"lint": {
|
||||||
|
"executor": "@nxlv/python:run-commands",
|
||||||
|
"options": {
|
||||||
|
"command": "poetry run isort --profile black . && poetry run black . && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .",
|
||||||
|
"cwd": "apps/core"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"database:migrate": {
|
||||||
|
"executor": "@nxlv/python:run-commands",
|
||||||
|
"options": {
|
||||||
|
"command": "poetry run alembic upgrade head",
|
||||||
|
"cwd": "apps/core"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"test": {
|
||||||
|
"executor": "@nxlv/python:run-commands",
|
||||||
|
"outputs": [
|
||||||
|
"{workspaceRoot}/reports/apps/core/unittests",
|
||||||
|
"{workspaceRoot}/coverage/apps/core"
|
||||||
|
],
|
||||||
|
"options": {
|
||||||
|
"command": "poetry run pytest tests/",
|
||||||
|
"cwd": "apps/core"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [],
|
||||||
|
"release": {
|
||||||
|
"version": {
|
||||||
|
"generator": "@nxlv/python:release-version"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,8 +59,8 @@ nltk = "^3.8.1"
|
|||||||
jinja2 = "^3.1.4"
|
jinja2 = "^3.1.4"
|
||||||
locust = {version = "^2.31.5", optional = true}
|
locust = {version = "^2.31.5", optional = true}
|
||||||
wikipedia = {version = "^1.4.0", optional = true}
|
wikipedia = {version = "^1.4.0", optional = true}
|
||||||
composio-langchain = "^0.6.3"
|
composio-langchain = "^0.6.7"
|
||||||
composio-core = "^0.6.3"
|
composio-core = "^0.6.7"
|
||||||
alembic = "^1.13.3"
|
alembic = "^1.13.3"
|
||||||
pyhumps = "^3.8.0"
|
pyhumps = "^3.8.0"
|
||||||
psycopg2 = {version = "^2.9.10", optional = true}
|
psycopg2 = {version = "^2.9.10", optional = true}
|
||||||
@@ -85,7 +85,7 @@ qdrant = ["qdrant-client"]
|
|||||||
cloud-tool-sandbox = ["e2b-code-interpreter"]
|
cloud-tool-sandbox = ["e2b-code-interpreter"]
|
||||||
external-tools = ["docker", "langchain", "wikipedia", "langchain-community"]
|
external-tools = ["docker", "langchain", "wikipedia", "langchain-community"]
|
||||||
tests = ["wikipedia"]
|
tests = ["wikipedia"]
|
||||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "llama-index-embeddings-ollama", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
black = "^24.4.2"
|
black = "^24.4.2"
|
||||||
@@ -100,3 +100,11 @@ extend-exclude = "examples/*"
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
line_length = 140
|
||||||
|
multi_line_output = 3
|
||||||
|
include_trailing_comma = true
|
||||||
|
force_grid_wrap = 0
|
||||||
|
use_parentheses = true
|
||||||
|
|||||||
@@ -14,12 +14,7 @@ from letta.agent import Agent
|
|||||||
from letta.config import LettaConfig
|
from letta.config import LettaConfig
|
||||||
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||||
from letta.embeddings import embedding_model
|
from letta.embeddings import embedding_model
|
||||||
from letta.errors import (
|
from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
|
||||||
InvalidInnerMonologueError,
|
|
||||||
InvalidToolCallError,
|
|
||||||
MissingInnerMonologueError,
|
|
||||||
MissingToolCallError,
|
|
||||||
)
|
|
||||||
from letta.llm_api.llm_api_tools import create
|
from letta.llm_api.llm_api_tools import create
|
||||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
@@ -28,12 +23,7 @@ from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCall
|
|||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.memory import ChatMemory
|
from letta.schemas.memory import ChatMemory
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message
|
||||||
ChatCompletionResponse,
|
|
||||||
Choice,
|
|
||||||
FunctionCall,
|
|
||||||
Message,
|
|
||||||
)
|
|
||||||
from letta.utils import get_human_text, get_persona_text, json_dumps
|
from letta.utils import get_human_text, get_persona_text, json_dumps
|
||||||
from tests.helpers.utils import cleanup
|
from tests.helpers.utils import cleanup
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from letta import create_client
|
from letta import create_client
|
||||||
from letta.schemas.letta_message import ToolCallMessage
|
from letta.schemas.letta_message import ToolCallMessage
|
||||||
from letta.schemas.tool_rule import (
|
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||||
ChildToolRule,
|
|
||||||
ConditionalToolRule,
|
|
||||||
InitToolRule,
|
|
||||||
TerminalToolRule,
|
|
||||||
)
|
|
||||||
from tests.helpers.endpoints_helper import (
|
from tests.helpers.endpoints_helper import (
|
||||||
assert_invoked_function_call,
|
assert_invoked_function_call,
|
||||||
assert_invoked_send_message_with_keyword,
|
assert_invoked_send_message_with_keyword,
|
||||||
|
|||||||
28
tests/integration_test_composio.py
Normal file
28
tests/integration_test_composio.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from letta.server.rest_api.app import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_composio_apps(client):
|
||||||
|
response = client.get("/v1/tools/composio/apps")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert isinstance(response.json(), list)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_composio_actions_by_app(client):
|
||||||
|
response = client.get("/v1/tools/composio/apps/github/actions")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert isinstance(response.json(), list)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_composio_tool(client):
|
||||||
|
response = client.post("/v1/tools/composio/GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "id" in response.json()
|
||||||
|
assert "name" in response.json()
|
||||||
@@ -212,9 +212,7 @@ def clear_core_memory_tool(test_user):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def external_codebase_tool(test_user):
|
def external_codebase_tool(test_user):
|
||||||
from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import (
|
from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import adjust_menu_prices
|
||||||
adjust_menu_prices,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool = create_tool_from_func(adjust_menu_prices)
|
tool = create_tool_from_func(adjust_menu_prices)
|
||||||
tool = ToolManager().create_or_update_tool(tool, test_user)
|
tool = ToolManager().create_or_update_tool(tool, test_user)
|
||||||
@@ -353,6 +351,14 @@ def test_local_sandbox_e2e_composio_star_github(mock_e2b_api_key_none, check_com
|
|||||||
assert result.func_return["details"] == "Action executed successfully"
|
assert result.func_return["details"] == "Action executed successfully"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.local_sandbox
|
||||||
|
def test_local_sandbox_e2e_composio_star_github_without_setting_db_env_vars(
|
||||||
|
mock_e2b_api_key_none, check_composio_key_set, composio_github_star_tool, test_user
|
||||||
|
):
|
||||||
|
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run()
|
||||||
|
assert result.func_return["details"] == "Action executed successfully"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.local_sandbox
|
@pytest.mark.local_sandbox
|
||||||
def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sandbox_config, external_codebase_tool, test_user):
|
def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sandbox_config, external_codebase_tool, test_user):
|
||||||
# Set the args
|
# Set the args
|
||||||
@@ -458,7 +464,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e
|
|||||||
config = manager.create_or_update_sandbox_config(config_create, test_user)
|
config = manager.create_or_update_sandbox_config(config_create, test_user)
|
||||||
|
|
||||||
# Run the custom sandbox once, assert nothing returns because missing env variable
|
# Run the custom sandbox once, assert nothing returns because missing env variable
|
||||||
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user, force_recreate=True)
|
sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user)
|
||||||
result = sandbox.run()
|
result = sandbox.run()
|
||||||
# response should be None
|
# response should be None
|
||||||
assert result.func_return is None
|
assert result.func_return is None
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ import sys
|
|||||||
import pexpect
|
import pexpect
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from letta.local_llm.constants import (
|
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
|
||||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
|
||||||
INNER_THOUGHTS_CLI_SYMBOL,
|
|
||||||
)
|
|
||||||
|
|
||||||
original_letta_path = os.path.expanduser("~/.letta")
|
original_letta_path = os.path.expanduser("~/.letta")
|
||||||
backup_letta_path = os.path.expanduser("~/.letta_backup")
|
backup_letta_path = os.path.expanduser("~/.letta_backup")
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def run_server():
|
|||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
params=[{"server": False}, {"server": True}], # whether to use REST API server
|
params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||||
# params=[{"server": True}], # whether to use REST API server
|
# params=[{"server": False}], # whether to use REST API server
|
||||||
scope="module",
|
scope="module",
|
||||||
)
|
)
|
||||||
def client(request):
|
def client(request):
|
||||||
@@ -341,7 +341,9 @@ def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
|||||||
|
|
||||||
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
|
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||||
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
|
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
|
||||||
send_system_message_response = client.send_message(agent_id=agent.id, message="Event occurred: The user just logged off.", role="system")
|
send_system_message_response = client.send_message(
|
||||||
|
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
|
||||||
|
)
|
||||||
assert send_system_message_response, "Sending message failed"
|
assert send_system_message_response, "Sending message failed"
|
||||||
|
|
||||||
|
|
||||||
@@ -390,7 +392,7 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]):
|
|||||||
"""
|
"""
|
||||||
Always throw an error.
|
Always throw an error.
|
||||||
"""
|
"""
|
||||||
return 5/0
|
return 5 / 0
|
||||||
|
|
||||||
tool = client.create_or_update_tool(func=always_error)
|
tool = client.create_or_update_tool(func=always_error)
|
||||||
agent = client.create_agent(tool_ids=[tool.id])
|
agent = client.create_agent(tool_ids=[tool.id])
|
||||||
@@ -406,12 +408,13 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]):
|
|||||||
|
|
||||||
assert response_message, "ToolReturnMessage message not found in response"
|
assert response_message, "ToolReturnMessage message not found in response"
|
||||||
assert response_message.status == "error"
|
assert response_message.status == "error"
|
||||||
|
|
||||||
if isinstance(client, RESTClient):
|
if isinstance(client, RESTClient):
|
||||||
assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero"
|
assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero"
|
||||||
else:
|
else:
|
||||||
response_json = json.loads(response_message.tool_return)
|
response_json = json.loads(response_message.tool_return)
|
||||||
assert response_json['status'] == "Failed"
|
assert response_json["status"] == "Failed"
|
||||||
assert response_json['message'] == "Error executing function always_error: ZeroDivisionError: division by zero"
|
assert response_json["message"] == "Error executing function always_error: ZeroDivisionError: division by zero"
|
||||||
|
|
||||||
client.delete_agent(agent_id=agent.id)
|
client.delete_agent(agent_id=agent.id)
|
||||||
|
|
||||||
|
|||||||
@@ -9,14 +9,7 @@ import letta.utils as utils
|
|||||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||||
from letta.schemas.block import CreateBlock
|
from letta.schemas.block import CreateBlock
|
||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||||
LettaMessage,
|
|
||||||
ReasoningMessage,
|
|
||||||
SystemMessage,
|
|
||||||
ToolCallMessage,
|
|
||||||
ToolReturnMessage,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
|
|
||||||
utils.DEBUG = True
|
utils.DEBUG = True
|
||||||
|
|||||||
@@ -2,12 +2,7 @@ import pytest
|
|||||||
|
|
||||||
from letta.helpers import ToolRulesSolver
|
from letta.helpers import ToolRulesSolver
|
||||||
from letta.helpers.tool_rule_solver import ToolRuleValidationError
|
from letta.helpers.tool_rule_solver import ToolRuleValidationError
|
||||||
from letta.schemas.tool_rule import (
|
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||||
ChildToolRule,
|
|
||||||
ConditionalToolRule,
|
|
||||||
InitToolRule,
|
|
||||||
TerminalToolRule
|
|
||||||
)
|
|
||||||
|
|
||||||
# Constants for tool names used in the tests
|
# Constants for tool names used in the tests
|
||||||
START_TOOL = "start_tool"
|
START_TOOL = "start_tool"
|
||||||
@@ -113,11 +108,7 @@ def test_conditional_tool_rule():
|
|||||||
# Setup: Define a conditional tool rule
|
# Setup: Define a conditional tool rule
|
||||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||||
rule = ConditionalToolRule(
|
rule = ConditionalToolRule(tool_name=START_TOOL, default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL})
|
||||||
tool_name=START_TOOL,
|
|
||||||
default_child=None,
|
|
||||||
child_output_mapping={True: END_TOOL, False: START_TOOL}
|
|
||||||
)
|
|
||||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
|
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
|
||||||
|
|
||||||
# Action & Assert: Verify the rule properties
|
# Action & Assert: Verify the rule properties
|
||||||
@@ -126,8 +117,12 @@ def test_conditional_tool_rule():
|
|||||||
|
|
||||||
# Step 2: After using 'start_tool'
|
# Step 2: After using 'start_tool'
|
||||||
solver.update_tool_usage(START_TOOL)
|
solver.update_tool_usage(START_TOOL)
|
||||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'"
|
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [
|
||||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'"
|
END_TOOL
|
||||||
|
], "After 'start_tool' returns true, should allow 'end_tool'"
|
||||||
|
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [
|
||||||
|
START_TOOL
|
||||||
|
], "After 'start_tool' returns false, should allow 'start_tool'"
|
||||||
|
|
||||||
# Step 3: After using 'end_tool'
|
# Step 3: After using 'end_tool'
|
||||||
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
|
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
|
||||||
@@ -137,11 +132,7 @@ def test_invalid_conditional_tool_rule():
|
|||||||
# Setup: Define an invalid conditional tool rule
|
# Setup: Define an invalid conditional tool rule
|
||||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||||
invalid_rule_1 = ConditionalToolRule(
|
invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
|
||||||
tool_name=START_TOOL,
|
|
||||||
default_child=END_TOOL,
|
|
||||||
child_output_mapping={}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test 1: Missing child output mapping
|
# Test 1: Missing child output mapping
|
||||||
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
|
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ def adjust_menu_prices(percentage: float) -> str:
|
|||||||
str: A formatted string summarizing the price adjustments.
|
str: A formatted string summarizing the price adjustments.
|
||||||
"""
|
"""
|
||||||
import cowsay
|
import cowsay
|
||||||
|
|
||||||
from core.menu import Menu, MenuItem # Import a class from the codebase
|
from core.menu import Menu, MenuItem # Import a class from the codebase
|
||||||
from core.utils import format_currency # Use a utility function to test imports
|
from core.utils import format_currency # Use a utility function to test imports
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from letta.functions.functions import derive_openai_json_schema
|
from letta.functions.functions import derive_openai_json_schema
|
||||||
from letta.llm_api.helpers import convert_to_structured_output, make_post_request
|
from letta.llm_api.helpers import convert_to_structured_output, make_post_request
|
||||||
|
from letta.schemas.tool import ToolCreate
|
||||||
|
|
||||||
|
|
||||||
def _clean_diff(d1, d2):
|
def _clean_diff(d1, d2):
|
||||||
@@ -176,3 +177,38 @@ def test_valid_schemas_via_openai(openai_model: str, structured_output: bool):
|
|||||||
_openai_payload(openai_model, schema, structured_output)
|
_openai_payload(openai_model, schema, structured_output)
|
||||||
else:
|
else:
|
||||||
_openai_payload(openai_model, schema, structured_output)
|
_openai_payload(openai_model, schema, structured_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("openai_model", ["gpt-4o-mini"])
|
||||||
|
@pytest.mark.parametrize("structured_output", [True])
|
||||||
|
def test_composio_tool_schema_generation(openai_model: str, structured_output: bool):
|
||||||
|
"""Test that we can generate the schemas for some Composio tools."""
|
||||||
|
|
||||||
|
if not os.getenv("COMPOSIO_API_KEY"):
|
||||||
|
pytest.skip("COMPOSIO_API_KEY not set")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import composio
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Composio not installed")
|
||||||
|
|
||||||
|
for action_name in [
|
||||||
|
"CAL_GET_AVAILABLE_SLOTS_INFO", # has an array arg, needs to be converted properly
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
tool_create = ToolCreate.from_composio(action_name=action_name)
|
||||||
|
except composio.exceptions.ComposioSDKError:
|
||||||
|
# e.g. "composio.exceptions.ComposioSDKError: No connected account found for app `CAL`; Run `composio add cal` to fix this"
|
||||||
|
pytest.skip(f"Composio account not configured to use action_name {action_name}")
|
||||||
|
|
||||||
|
print(tool_create)
|
||||||
|
|
||||||
|
assert tool_create.json_schema
|
||||||
|
schema = tool_create.json_schema
|
||||||
|
|
||||||
|
try:
|
||||||
|
_openai_payload(openai_model, schema, structured_output)
|
||||||
|
print(f"Successfully called OpenAI using schema {schema} generated from {action_name}")
|
||||||
|
except:
|
||||||
|
print(f"Failed to call OpenAI using schema {schema} generated from {action_name}")
|
||||||
|
raise
|
||||||
|
|||||||
@@ -1,12 +1,7 @@
|
|||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from composio.client.collections import (
|
from composio.client.collections import ActionModel, ActionParametersModel, ActionResponseModel, AppModel
|
||||||
ActionModel,
|
|
||||||
ActionParametersModel,
|
|
||||||
ActionResponseModel,
|
|
||||||
AppModel,
|
|
||||||
)
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||||
|
|||||||
Reference in New Issue
Block a user