feat: Add optional pip requirements to tool object (#2793)
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""Add pip requirements to tools
|
||||
|
||||
Revision ID: 1c6b6a38b713
|
||||
Revises: c96263433aef
|
||||
Create Date: 2025-06-12 18:06:54.838510
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "1c6b6a38b713"
|
||||
down_revision: Union[str, None] = "c96263433aef"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("tools", sa.Column("pip_requirements", sa.JSON(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("tools", "pip_requirements")
|
||||
# ### end Alembic commands ###
|
||||
@@ -44,6 +44,9 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
|
||||
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
||||
args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.")
|
||||
pip_requirements: Mapped[Optional[List]] = mapped_column(
|
||||
JSON, nullable=True, doc="Optional list of pip packages required by this tool."
|
||||
)
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.")
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
|
||||
25
letta/schemas/pip_requirement.py
Normal file
25
letta/schemas/pip_requirement.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class PipRequirement(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name of the pip package.")
|
||||
version: Optional[str] = Field(None, description="Optional version of the package, following semantic versioning.")
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return None
|
||||
semver_pattern = re.compile(r"^\d+(\.\d+){0,2}(-[a-zA-Z0-9.]+)?$")
|
||||
if not semver_pattern.match(v):
|
||||
raise ValueError(f"Invalid version format: {v}. Must follow semantic versioning (e.g., 1.2.3, 2.0, 1.5.0-alpha).")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a pip-installable string format."""
|
||||
if self.version:
|
||||
return f"{self.name}=={self.version}"
|
||||
return self.name
|
||||
@@ -1,6 +1,5 @@
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
@@ -9,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
@@ -27,24 +27,6 @@ class SandboxRunResult(BaseModel):
|
||||
sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")
|
||||
|
||||
|
||||
class PipRequirement(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name of the pip package.")
|
||||
version: Optional[str] = Field(None, description="Optional version of the package, following semantic versioning.")
|
||||
|
||||
@classmethod
|
||||
def validate_version(cls, version: Optional[str]) -> Optional[str]:
|
||||
if version is None:
|
||||
return None
|
||||
semver_pattern = re.compile(r"^\d+(\.\d+){0,2}(-[a-zA-Z0-9.]+)?$")
|
||||
if not semver_pattern.match(version):
|
||||
raise ValueError(f"Invalid version format: {version}. Must follow semantic versioning (e.g., 1.2.3, 2.0, 1.5.0-alpha).")
|
||||
return version
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.version = self.validate_version(self.version)
|
||||
|
||||
|
||||
class LocalSandboxConfig(BaseModel):
|
||||
sandbox_dir: Optional[str] = Field(None, description="Directory for the sandbox environment.")
|
||||
use_venv: bool = Field(False, description="Whether or not to use the venv, or run directly in the same run loop.")
|
||||
|
||||
@@ -24,6 +24,7 @@ from letta.functions.schema_generator import (
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -60,6 +61,7 @@ class Tool(BaseTool):
|
||||
|
||||
# tool configuration
|
||||
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
|
||||
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
|
||||
# metadata fields
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
@@ -145,6 +147,7 @@ class ToolCreate(LettaBase):
|
||||
)
|
||||
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
|
||||
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
|
||||
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
|
||||
# TODO should we put the HTTP / API fetch inside from_mcp?
|
||||
# async def from_mcp(cls, mcp_server: str, mcp_tool_name: str) -> "ToolCreate":
|
||||
@@ -253,6 +256,7 @@ class ToolUpdate(LettaBase):
|
||||
)
|
||||
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
|
||||
return_char_limit: Optional[int] = Field(None, description="The maximum number of characters in the response.")
|
||||
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # Allows extra fields without validation errors
|
||||
@@ -269,3 +273,4 @@ class ToolRunFromSource(LettaBase):
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import platform
|
||||
import subprocess
|
||||
import venv
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from datamodel_code_generator import DataModelType, PythonVersion
|
||||
from datamodel_code_generator.model import get_data_model_types
|
||||
@@ -11,6 +11,9 @@ from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -85,14 +88,12 @@ def install_pip_requirements_for_sandbox(
|
||||
upgrade: bool = True,
|
||||
user_install_if_no_venv: bool = False,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
tool: Optional["Tool"] = None,
|
||||
):
|
||||
"""
|
||||
Installs the specified pip requirements inside the correct environment (venv or system).
|
||||
Installs both sandbox-level and tool-specific pip requirements.
|
||||
"""
|
||||
if not local_configs.pip_requirements:
|
||||
logger.debug("No pip requirements specified; skipping installation.")
|
||||
return
|
||||
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
|
||||
local_configs.sandbox_dir = sandbox_dir # Update the object to store the absolute path
|
||||
|
||||
@@ -102,19 +103,48 @@ def install_pip_requirements_for_sandbox(
|
||||
if local_configs.use_venv:
|
||||
ensure_pip_is_up_to_date(python_exec, env=env)
|
||||
|
||||
# Construct package list
|
||||
packages = [f"{req.name}=={req.version}" if req.version else req.name for req in local_configs.pip_requirements]
|
||||
# Collect all pip requirements
|
||||
all_packages = []
|
||||
|
||||
# Add sandbox-level pip requirements
|
||||
if local_configs.pip_requirements:
|
||||
packages = [f"{req.name}=={req.version}" if req.version else req.name for req in local_configs.pip_requirements]
|
||||
all_packages.extend(packages)
|
||||
logger.debug(f"Added sandbox pip requirements: {packages}")
|
||||
|
||||
# Add tool-specific pip requirements
|
||||
if tool and tool.pip_requirements:
|
||||
tool_packages = [str(req) for req in tool.pip_requirements]
|
||||
all_packages.extend(tool_packages)
|
||||
logger.debug(f"Added tool pip requirements for {tool.name}: {tool_packages}")
|
||||
|
||||
if not all_packages:
|
||||
logger.debug("No pip requirements specified; skipping installation.")
|
||||
return
|
||||
|
||||
# Construct pip install command
|
||||
pip_cmd = [python_exec, "-m", "pip", "install"]
|
||||
if upgrade:
|
||||
pip_cmd.append("--upgrade")
|
||||
pip_cmd += packages
|
||||
pip_cmd += all_packages
|
||||
|
||||
if user_install_if_no_venv and not local_configs.use_venv:
|
||||
pip_cmd.append("--user")
|
||||
|
||||
run_subprocess(pip_cmd, env=env, fail_msg=f"Failed to install packages: {', '.join(packages)}")
|
||||
# Enhanced error message for better debugging
|
||||
sandbox_packages = [f"{req.name}=={req.version}" if req.version else req.name for req in (local_configs.pip_requirements or [])]
|
||||
tool_packages = [str(req) for req in (tool.pip_requirements if tool and tool.pip_requirements else [])]
|
||||
|
||||
error_details = []
|
||||
if sandbox_packages:
|
||||
error_details.append(f"sandbox requirements: {', '.join(sandbox_packages)}")
|
||||
if tool_packages:
|
||||
error_details.append(f"tool requirements: {', '.join(tool_packages)}")
|
||||
|
||||
context = f" ({'; '.join(error_details)})" if error_details else ""
|
||||
fail_msg = f"Failed to install pip packages{context}. This may be due to package version incompatibility. Consider updating package versions or removing version constraints."
|
||||
|
||||
run_subprocess(pip_cmd, env=env, fail_msg=fail_msg)
|
||||
|
||||
|
||||
def create_venv_for_local_sandbox(sandbox_dir_path: str, venv_path: str, env: Dict[str, str], force_recreate: bool):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from e2b.sandbox.commands.command_handle import CommandExitException
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
|
||||
from letta.log import get_logger
|
||||
@@ -189,14 +190,63 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
await sbx.commands.run(f"pip install {package}")
|
||||
try:
|
||||
await sbx.commands.run(f"pip install {package}")
|
||||
log_event(
|
||||
"e2b_pip_install_finished",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
except CommandExitException as e:
|
||||
error_msg = f"Failed to install sandbox pip requirement '{package}' in E2B sandbox. This may be due to package version incompatibility with the E2B environment. Error: {e}"
|
||||
logger.error(error_msg)
|
||||
log_event(
|
||||
"e2b_pip_install_failed",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package,
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
# Install tool-specific pip requirements
|
||||
if self.tool and self.tool.pip_requirements:
|
||||
for pip_requirement in self.tool.pip_requirements:
|
||||
package_str = str(pip_requirement)
|
||||
log_event(
|
||||
"e2b_pip_install_finished",
|
||||
"tool_pip_install_started",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package,
|
||||
"package": package_str,
|
||||
"tool_name": self.tool.name,
|
||||
},
|
||||
)
|
||||
try:
|
||||
await sbx.commands.run(f"pip install {package_str}")
|
||||
log_event(
|
||||
"tool_pip_install_finished",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package_str,
|
||||
"tool_name": self.tool.name,
|
||||
},
|
||||
)
|
||||
except CommandExitException as e:
|
||||
error_msg = f"Failed to install tool pip requirement '{package_str}' for tool '{self.tool.name}' in E2B sandbox. This may be due to package version incompatibility with the E2B environment. Consider updating the package version or removing the version constraint. Error: {e}"
|
||||
logger.error(error_msg)
|
||||
log_event(
|
||||
"tool_pip_install_failed",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package_str,
|
||||
"tool_name": self.tool.name,
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
return sbx
|
||||
|
||||
|
||||
@@ -175,7 +175,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
log_event(name="finish create_venv_for_local_sandbox")
|
||||
|
||||
log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
|
||||
await asyncio.to_thread(install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env)
|
||||
await asyncio.to_thread(
|
||||
install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env, tool=self.tool
|
||||
)
|
||||
log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
|
||||
|
||||
@trace_method
|
||||
|
||||
@@ -15,7 +15,8 @@ from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
@@ -263,6 +264,64 @@ def custom_test_sandbox_config(test_user):
|
||||
|
||||
|
||||
# Tool-specific fixtures
|
||||
@pytest.fixture
|
||||
def tool_with_pip_requirements(test_user):
|
||||
def use_requests_and_numpy() -> str:
|
||||
"""
|
||||
Function that uses requests and numpy packages to test tool-specific pip requirements.
|
||||
|
||||
Returns:
|
||||
str: Success message if packages are available.
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
# Simple usage to verify packages work
|
||||
response = requests.get("https://httpbin.org/json", timeout=5)
|
||||
arr = np.array([1, 2, 3])
|
||||
return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}"
|
||||
except ImportError as e:
|
||||
return f"Import error: {e}"
|
||||
except Exception as e:
|
||||
return f"Other error: {e}"
|
||||
|
||||
tool = create_tool_from_func(use_requests_and_numpy)
|
||||
# Add pip requirements to the tool - using more recent versions for E2B compatibility
|
||||
tool.pip_requirements = [
|
||||
PipRequirement(name="requests", version="2.31.0"),
|
||||
PipRequirement(name="numpy", version="1.26.0"),
|
||||
]
|
||||
tool = ToolManager().create_or_update_tool(tool, test_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_with_broken_pip_requirements(test_user):
|
||||
def use_broken_package() -> str:
|
||||
"""
|
||||
Function that requires a package with known compatibility issues.
|
||||
|
||||
Returns:
|
||||
str: Should not reach here due to pip install failure.
|
||||
"""
|
||||
try:
|
||||
import some_nonexistent_package # This will fail during pip install
|
||||
|
||||
return "This should not execute"
|
||||
except ImportError as e:
|
||||
return f"Import error: {e}"
|
||||
|
||||
tool = create_tool_from_func(use_broken_package)
|
||||
# Add pip requirements that will fail in E2B environment
|
||||
tool.pip_requirements = [
|
||||
PipRequirement(name="numpy", version="1.24.0"), # Known to have compatibility issues
|
||||
PipRequirement(name="nonexistent-package-12345"), # This package doesn't exist
|
||||
]
|
||||
tool = ToolManager().create_or_update_tool(tool, test_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def core_memory_tools(test_user):
|
||||
"""Create all base tools for testing."""
|
||||
@@ -418,6 +477,50 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c
|
||||
assert long_random_string in result.stdout[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.local_sandbox
|
||||
async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop):
|
||||
"""Test that local sandbox installs tool-specific pip requirements."""
|
||||
manager = SandboxConfigManager()
|
||||
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
|
||||
config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir, use_venv=True).model_dump())
|
||||
manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
sandbox = AsyncToolSandboxLocal(
|
||||
tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements, force_recreate_venv=True
|
||||
)
|
||||
result = await sandbox.run()
|
||||
|
||||
# Should succeed since tool pip requirements were installed
|
||||
assert "Success!" in result.func_return
|
||||
assert "Status: 200" in result.func_return
|
||||
assert "Array sum: 6" in result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.local_sandbox
|
||||
async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop):
|
||||
"""Test that local sandbox installs both sandbox and tool pip requirements."""
|
||||
manager = SandboxConfigManager()
|
||||
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
|
||||
|
||||
# Add sandbox-level pip requirement
|
||||
config_create = SandboxConfigCreate(
|
||||
config=LocalSandboxConfig(sandbox_dir=sandbox_dir, use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump()
|
||||
)
|
||||
manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
sandbox = AsyncToolSandboxLocal(
|
||||
tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements, force_recreate_venv=True
|
||||
)
|
||||
result = await sandbox.run()
|
||||
|
||||
# Should succeed since both sandbox and tool pip requirements were installed
|
||||
assert "Success!" in result.func_return
|
||||
assert "Status: 200" in result.func_return
|
||||
assert "Array sum: 6" in result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.e2b_sandbox
|
||||
async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop):
|
||||
@@ -550,3 +653,69 @@ async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_us
|
||||
sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user)
|
||||
result = await sandbox.run()
|
||||
assert len(result.func_return) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.e2b_sandbox
|
||||
async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop):
|
||||
"""Test that E2B sandbox installs tool-specific pip requirements."""
|
||||
manager = SandboxConfigManager()
|
||||
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
|
||||
manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
sandbox = AsyncToolSandboxE2B(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements)
|
||||
result = await sandbox.run()
|
||||
|
||||
# Should succeed since tool pip requirements were installed
|
||||
assert "Success!" in result.func_return
|
||||
assert "Status: 200" in result.func_return
|
||||
assert "Array sum: 6" in result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.e2b_sandbox
|
||||
async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop):
|
||||
"""Test that E2B sandbox installs both sandbox and tool pip requirements."""
|
||||
manager = SandboxConfigManager()
|
||||
|
||||
# Add sandbox-level pip requirement
|
||||
config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump())
|
||||
manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
sandbox = AsyncToolSandboxE2B(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements)
|
||||
result = await sandbox.run()
|
||||
|
||||
# Should succeed since both sandbox and tool pip requirements were installed
|
||||
assert "Success!" in result.func_return
|
||||
assert "Status: 200" in result.func_return
|
||||
assert "Array sum: 6" in result.func_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.e2b_sandbox
|
||||
async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling(
|
||||
check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user, event_loop
|
||||
):
|
||||
"""Test that E2B sandbox provides informative error messages for broken tool pip requirements."""
|
||||
manager = SandboxConfigManager()
|
||||
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
|
||||
manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
sandbox = AsyncToolSandboxE2B(tool_with_broken_pip_requirements.name, {}, user=test_user, tool_object=tool_with_broken_pip_requirements)
|
||||
|
||||
# Should raise a RuntimeError with informative message
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await sandbox.run()
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
print(error_message)
|
||||
|
||||
# Verify the error message contains helpful information
|
||||
assert "Failed to install tool pip requirement" in error_message
|
||||
assert "use_broken_package" in error_message # Tool name
|
||||
assert "E2B sandbox" in error_message
|
||||
assert "package version incompatibility" in error_message
|
||||
assert "Consider updating the package version or removing the version constraint" in error_message
|
||||
|
||||
# Should mention one of the problematic packages
|
||||
assert "numpy==1.24.0" in error_message or "nonexistent-package-12345" in error_message
|
||||
|
||||
@@ -14,7 +14,8 @@ from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
@@ -73,6 +73,7 @@ from letta.schemas.organization import Organization
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.organization import OrganizationUpdate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
@@ -3048,6 +3049,175 @@ async def test_upsert_base_tools_with_empty_type_filter(server: SyncServer, defa
|
||||
assert tools == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_with_pip_requirements(server: SyncServer, default_user, default_organization):
|
||||
def test_tool_with_deps():
|
||||
"""
|
||||
A test tool with pip dependencies.
|
||||
|
||||
Returns:
|
||||
str: Hello message.
|
||||
"""
|
||||
return "hello"
|
||||
|
||||
# Create pip requirements
|
||||
pip_reqs = [
|
||||
PipRequirement(name="requests", version="2.28.0"),
|
||||
PipRequirement(name="numpy"), # No version specified
|
||||
]
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(test_tool_with_deps)
|
||||
source_type = "python"
|
||||
description = "A test tool with pip dependencies"
|
||||
tags = ["test"]
|
||||
metadata = {"test": "pip_requirements"}
|
||||
|
||||
tool = PydanticTool(
|
||||
description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert created_tool.pip_requirements is not None
|
||||
assert len(created_tool.pip_requirements) == 2
|
||||
assert created_tool.pip_requirements[0].name == "requests"
|
||||
assert created_tool.pip_requirements[0].version == "2.28.0"
|
||||
assert created_tool.pip_requirements[1].name == "numpy"
|
||||
assert created_tool.pip_requirements[1].version is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_without_pip_requirements(server: SyncServer, print_tool):
|
||||
# Verify that tools without pip_requirements have the field as None
|
||||
assert print_tool.pip_requirements is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tool_pip_requirements(server: SyncServer, print_tool, default_user):
|
||||
# Add pip requirements to existing tool
|
||||
pip_reqs = [
|
||||
PipRequirement(name="pandas", version="1.5.0"),
|
||||
PipRequirement(name="matplotlib"),
|
||||
]
|
||||
|
||||
tool_update = ToolUpdate(pip_requirements=pip_reqs)
|
||||
await server.tool_manager.update_tool_by_id_async(print_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool
|
||||
updated_tool = await server.tool_manager.get_tool_by_id_async(print_tool.id, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert updated_tool.pip_requirements is not None
|
||||
assert len(updated_tool.pip_requirements) == 2
|
||||
assert updated_tool.pip_requirements[0].name == "pandas"
|
||||
assert updated_tool.pip_requirements[0].version == "1.5.0"
|
||||
assert updated_tool.pip_requirements[1].name == "matplotlib"
|
||||
assert updated_tool.pip_requirements[1].version is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tool_clear_pip_requirements(server: SyncServer, default_user, default_organization):
|
||||
def test_tool_clear_deps():
|
||||
"""
|
||||
A test tool to clear dependencies.
|
||||
|
||||
Returns:
|
||||
str: Hello message.
|
||||
"""
|
||||
return "hello"
|
||||
|
||||
# Create a tool with pip requirements
|
||||
pip_reqs = [PipRequirement(name="requests")]
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(test_tool_clear_deps)
|
||||
source_type = "python"
|
||||
description = "A test tool to clear dependencies"
|
||||
tags = ["test"]
|
||||
metadata = {"test": "clear_deps"}
|
||||
|
||||
tool = PydanticTool(
|
||||
description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
|
||||
|
||||
# Verify it has requirements
|
||||
assert created_tool.pip_requirements is not None
|
||||
assert len(created_tool.pip_requirements) == 1
|
||||
|
||||
# Clear the requirements
|
||||
tool_update = ToolUpdate(pip_requirements=[])
|
||||
await server.tool_manager.update_tool_by_id_async(created_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool
|
||||
updated_tool = await server.tool_manager.get_tool_by_id_async(created_tool.id, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert updated_tool.pip_requirements == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pip_requirements_roundtrip(server: SyncServer, default_user, default_organization):
|
||||
def roundtrip_test_tool():
|
||||
"""
|
||||
Test pip requirements roundtrip.
|
||||
|
||||
Returns:
|
||||
str: Test message.
|
||||
"""
|
||||
return "test"
|
||||
|
||||
# Create pip requirements with various version formats
|
||||
pip_reqs = [
|
||||
PipRequirement(name="requests", version="2.28.0"),
|
||||
PipRequirement(name="flask", version="2.0"),
|
||||
PipRequirement(name="django", version="4.1.0-beta"),
|
||||
PipRequirement(name="numpy"), # No version
|
||||
]
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(roundtrip_test_tool)
|
||||
source_type = "python"
|
||||
description = "Test pip requirements roundtrip"
|
||||
tags = ["test"]
|
||||
metadata = {"test": "roundtrip"}
|
||||
|
||||
tool = PydanticTool(
|
||||
description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs
|
||||
)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
|
||||
|
||||
# Fetch by ID
|
||||
fetched_tool = await server.tool_manager.get_tool_by_id_async(created_tool.id, actor=default_user)
|
||||
|
||||
# Verify all requirements match exactly
|
||||
assert fetched_tool.pip_requirements is not None
|
||||
assert len(fetched_tool.pip_requirements) == 4
|
||||
|
||||
# Check each requirement
|
||||
reqs_dict = {req.name: req.version for req in fetched_tool.pip_requirements}
|
||||
assert reqs_dict["requests"] == "2.28.0"
|
||||
assert reqs_dict["flask"] == "2.0"
|
||||
assert reqs_dict["django"] == "4.1.0-beta"
|
||||
assert reqs_dict["numpy"] is None
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Message Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user