diff --git a/alembic/versions/1c6b6a38b713_add_pip_requirements_to_tools.py b/alembic/versions/1c6b6a38b713_add_pip_requirements_to_tools.py new file mode 100644 index 00000000..52a48198 --- /dev/null +++ b/alembic/versions/1c6b6a38b713_add_pip_requirements_to_tools.py @@ -0,0 +1,31 @@ +"""Add pip requirements to tools + +Revision ID: 1c6b6a38b713 +Revises: c96263433aef +Create Date: 2025-06-12 18:06:54.838510 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1c6b6a38b713" +down_revision: Union[str, None] = "c96263433aef" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tools", sa.Column("pip_requirements", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("tools", "pip_requirements") + # ### end Alembic commands ### diff --git a/letta/orm/tool.py b/letta/orm/tool.py index 7a7c3199..6dbce85c 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -44,6 +44,9 @@ class Tool(SqlalchemyBase, OrganizationMixin): source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.") json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.") args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.") + pip_requirements: Mapped[Optional[List]] = mapped_column( + JSON, nullable=True, doc="Optional list of pip packages required by this tool." + ) metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") diff --git a/letta/schemas/pip_requirement.py b/letta/schemas/pip_requirement.py new file mode 100644 index 00000000..acfbe88b --- /dev/null +++ b/letta/schemas/pip_requirement.py @@ -0,0 +1,25 @@ +import re +from typing import Optional + +from pydantic import BaseModel, Field, field_validator + + +class PipRequirement(BaseModel): + name: str = Field(..., min_length=1, description="Name of the pip package.") + version: Optional[str] = Field(None, description="Optional version of the package, following semantic versioning.") + + @field_validator("version") + @classmethod + def validate_version(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return None + semver_pattern = re.compile(r"^\d+(\.\d+){0,2}(-[a-zA-Z0-9.]+)?$") + if not semver_pattern.match(v): + raise ValueError(f"Invalid version format: {v}. Must follow semantic versioning (e.g., 1.2.3, 2.0, 1.5.0-alpha).") + return v + + def __str__(self) -> str: + """Return a pip-installable string format.""" + if self.version: + return f"{self.name}=={self.version}" + return self.name diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index c265fbf8..f3b90d8e 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -1,6 +1,5 @@ import hashlib import json -import re from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union @@ -9,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.schemas.agent import AgentState from letta.schemas.letta_base import LettaBase, OrmMetadataBase +from letta.schemas.pip_requirement import PipRequirement from letta.settings import tool_settings @@ -27,24 +27,6 @@ class SandboxRunResult(BaseModel): sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox") -class PipRequirement(BaseModel): - name: str = Field(..., min_length=1, description="Name of the pip package.") - version: Optional[str] = Field(None, description="Optional version of the package, following semantic versioning.") - - @classmethod - def validate_version(cls, version: Optional[str]) -> Optional[str]: - if version is None: - return None - semver_pattern = re.compile(r"^\d+(\.\d+){0,2}(-[a-zA-Z0-9.]+)?$") - if not semver_pattern.match(version): - raise ValueError(f"Invalid version format: {version}. Must follow semantic versioning (e.g., 1.2.3, 2.0, 1.5.0-alpha).") - return version - - def __init__(self, **data): - super().__init__(**data) - self.version = self.validate_version(self.version) - - class LocalSandboxConfig(BaseModel): sandbox_dir: Optional[str] = Field(None, description="Directory for the sandbox environment.") use_venv: bool = Field(False, description="Whether or not to use the venv, or run directly in the same run loop.") diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index e7b23e86..47d9d6e8 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -24,6 +24,7 @@ from letta.functions.schema_generator import ( from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.letta_base import LettaBase +from letta.schemas.pip_requirement import PipRequirement logger = get_logger(__name__) @@ -60,6 +61,7 @@ class Tool(BaseTool): # tool configuration return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.") + pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.") # metadata fields created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") @@ -145,6 +147,7 @@ class ToolCreate(LettaBase): ) args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.") return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.") + pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.") # TODO should we put the HTTP / API fetch inside from_mcp? # async def from_mcp(cls, mcp_server: str, mcp_tool_name: str) -> "ToolCreate": @@ -253,6 +256,7 @@ class ToolUpdate(LettaBase): ) args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.") return_char_limit: Optional[int] = Field(None, description="The maximum number of characters in the response.") + pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.") class Config: extra = "ignore" # Allows extra fields without validation errors @@ -269,3 +273,4 @@ class ToolRunFromSource(LettaBase): json_schema: Optional[Dict] = Field( None, description="The JSON schema of the function (auto-generated from source_code if not provided)" ) + pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.") diff --git a/letta/services/helpers/tool_execution_helper.py b/letta/services/helpers/tool_execution_helper.py index 34ea17af..1fef1e0f 100644 --- a/letta/services/helpers/tool_execution_helper.py +++ b/letta/services/helpers/tool_execution_helper.py @@ -2,7 +2,7 @@ import os import platform import subprocess import venv -from typing import Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional from datamodel_code_generator import DataModelType, PythonVersion from datamodel_code_generator.model import get_data_model_types @@ -11,6 +11,9 @@ from datamodel_code_generator.parser.jsonschema import JsonSchemaParser from letta.log import get_logger from letta.schemas.sandbox_config import LocalSandboxConfig +if TYPE_CHECKING: + from letta.schemas.tool import Tool + logger = get_logger(__name__) @@ -85,14 +88,12 @@ def install_pip_requirements_for_sandbox( upgrade: bool = True, user_install_if_no_venv: bool = False, env: Optional[Dict[str, str]] = None, + tool: Optional["Tool"] = None, ): """ Installs the specified pip requirements inside the correct environment (venv or system). + Installs both sandbox-level and tool-specific pip requirements. """ - if not local_configs.pip_requirements: - logger.debug("No pip requirements specified; skipping installation.") - return - sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde local_configs.sandbox_dir = sandbox_dir # Update the object to store the absolute path @@ -102,19 +103,48 @@ def install_pip_requirements_for_sandbox( if local_configs.use_venv: ensure_pip_is_up_to_date(python_exec, env=env) - # Construct package list - packages = [f"{req.name}=={req.version}" if req.version else req.name for req in local_configs.pip_requirements] + # Collect all pip requirements + all_packages = [] + + # Add sandbox-level pip requirements + if local_configs.pip_requirements: + packages = [f"{req.name}=={req.version}" if req.version else req.name for req in local_configs.pip_requirements] + all_packages.extend(packages) + logger.debug(f"Added sandbox pip requirements: {packages}") + + # Add tool-specific pip requirements + if tool and tool.pip_requirements: + tool_packages = [str(req) for req in tool.pip_requirements] + all_packages.extend(tool_packages) + logger.debug(f"Added tool pip requirements for {tool.name}: {tool_packages}") + + if not all_packages: + logger.debug("No pip requirements specified; skipping installation.") + return # Construct pip install command pip_cmd = [python_exec, "-m", "pip", "install"] if upgrade: pip_cmd.append("--upgrade") - pip_cmd += packages + pip_cmd += all_packages if user_install_if_no_venv and not local_configs.use_venv: pip_cmd.append("--user") - run_subprocess(pip_cmd, env=env, fail_msg=f"Failed to install packages: {', '.join(packages)}") + # Enhanced error message for better debugging + sandbox_packages = [f"{req.name}=={req.version}" if req.version else req.name for req in (local_configs.pip_requirements or [])] + tool_packages = [str(req) for req in (tool.pip_requirements if tool and tool.pip_requirements else [])] + + error_details = [] + if sandbox_packages: + error_details.append(f"sandbox requirements: {', '.join(sandbox_packages)}") + if tool_packages: + error_details.append(f"tool requirements: {', '.join(tool_packages)}") + + context = f" ({'; '.join(error_details)})" if error_details else "" + fail_msg = f"Failed to install pip packages{context}. This may be due to package version incompatibility. Consider updating package versions or removing version constraints." + + run_subprocess(pip_cmd, env=env, fail_msg=fail_msg) def create_venv_for_local_sandbox(sandbox_dir_path: str, venv_path: str, env: Dict[str, str], force_recreate: bool): diff --git a/letta/services/tool_sandbox/e2b_sandbox.py b/letta/services/tool_sandbox/e2b_sandbox.py index fb44fc94..1e232168 100644 --- a/letta/services/tool_sandbox/e2b_sandbox.py +++ b/letta/services/tool_sandbox/e2b_sandbox.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional +from e2b.sandbox.commands.command_handle import CommandExitException from e2b_code_interpreter import AsyncSandbox from letta.log import get_logger @@ -189,14 +190,63 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): "package": package, }, ) - await sbx.commands.run(f"pip install {package}") + try: + await sbx.commands.run(f"pip install {package}") + log_event( + "e2b_pip_install_finished", + { + "sandbox_id": sbx.sandbox_id, + "package": package, + }, + ) + except CommandExitException as e: + error_msg = f"Failed to install sandbox pip requirement '{package}' in E2B sandbox. This may be due to package version incompatibility with the E2B environment. Error: {e}" + logger.error(error_msg) + log_event( + "e2b_pip_install_failed", + { + "sandbox_id": sbx.sandbox_id, + "package": package, + "error": str(e), + }, + ) + raise RuntimeError(error_msg) from e + + # Install tool-specific pip requirements + if self.tool and self.tool.pip_requirements: + for pip_requirement in self.tool.pip_requirements: + package_str = str(pip_requirement) log_event( - "e2b_pip_install_finished", + "tool_pip_install_started", { "sandbox_id": sbx.sandbox_id, - "package": package, + "package": package_str, + "tool_name": self.tool.name, }, ) + try: + await sbx.commands.run(f"pip install {package_str}") + log_event( + "tool_pip_install_finished", + { + "sandbox_id": sbx.sandbox_id, + "package": package_str, + "tool_name": self.tool.name, + }, + ) + except CommandExitException as e: + error_msg = f"Failed to install tool pip requirement '{package_str}' for tool '{self.tool.name}' in E2B sandbox. This may be due to package version incompatibility with the E2B environment. Consider updating the package version or removing the version constraint. Error: {e}" + logger.error(error_msg) + log_event( + "tool_pip_install_failed", + { + "sandbox_id": sbx.sandbox_id, + "package": package_str, + "tool_name": self.tool.name, + "error": str(e), + }, + ) + raise RuntimeError(error_msg) from e return sbx diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index 3d716137..901231ea 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -175,7 +175,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): log_event(name="finish create_venv_for_local_sandbox") log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()}) - await asyncio.to_thread(install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env) + await asyncio.to_thread( + install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env, tool=self.tool + ) log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()}) @trace_method diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index e6d54207..4390fdac 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -15,7 +15,8 @@ from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate from letta.schemas.organization import Organization -from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate +from letta.schemas.pip_requirement import PipRequirement +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate from letta.schemas.user import User from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager @@ -263,6 +264,64 @@ def custom_test_sandbox_config(test_user): # Tool-specific fixtures +@pytest.fixture +def tool_with_pip_requirements(test_user): + def use_requests_and_numpy() -> str: + """ + Function that uses requests and numpy packages to test tool-specific pip requirements. + + Returns: + str: Success message if packages are available. + """ + try: + import numpy as np + import requests + + # Simple usage to verify packages work + response = requests.get("https://httpbin.org/json", timeout=5) + arr = np.array([1, 2, 3]) + return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}" + except ImportError as e: + return f"Import error: {e}" + except Exception as e: + return f"Other error: {e}" + + tool = create_tool_from_func(use_requests_and_numpy) + # Add pip requirements to the tool - using more recent versions for E2B compatibility + tool.pip_requirements = [ + PipRequirement(name="requests", version="2.31.0"), + PipRequirement(name="numpy", version="1.26.0"), + ] + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + +@pytest.fixture +def tool_with_broken_pip_requirements(test_user): + def use_broken_package() -> str: + """ + Function that requires a package with known compatibility issues. + + Returns: + str: Should not reach here due to pip install failure. + """ + try: + import some_nonexistent_package # This will fail during pip install + + return "This should not execute" + except ImportError as e: + return f"Import error: {e}" + + tool = create_tool_from_func(use_broken_package) + # Add pip requirements that will fail in E2B environment + tool.pip_requirements = [ + PipRequirement(name="numpy", version="1.24.0"), # Known to have compatibility issues + PipRequirement(name="nonexistent-package-12345"), # This package doesn't exist + ] + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + @pytest.fixture def core_memory_tools(test_user): """Create all base tools for testing.""" @@ -418,6 +477,50 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c assert long_random_string in result.stdout[0] +@pytest.mark.asyncio +@pytest.mark.local_sandbox +async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop): + """Test that local sandbox installs tool-specific pip requirements.""" + manager = SandboxConfigManager() + sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") + config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir, use_venv=True).model_dump()) + manager.create_or_update_sandbox_config(config_create, test_user) + + sandbox = AsyncToolSandboxLocal( + tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements, force_recreate_venv=True + ) + result = await sandbox.run() + + # Should succeed since tool pip requirements were installed + assert "Success!" in result.func_return + assert "Status: 200" in result.func_return + assert "Array sum: 6" in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.local_sandbox +async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop): + """Test that local sandbox installs both sandbox and tool pip requirements.""" + manager = SandboxConfigManager() + sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") + + # Add sandbox-level pip requirement + config_create = SandboxConfigCreate( + config=LocalSandboxConfig(sandbox_dir=sandbox_dir, use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump() + ) + manager.create_or_update_sandbox_config(config_create, test_user) + + sandbox = AsyncToolSandboxLocal( + tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements, force_recreate_venv=True + ) + result = await sandbox.run() + + # Should succeed since both sandbox and tool pip requirements were installed + assert "Success!" in result.func_return + assert "Status: 200" in result.func_return + assert "Array sum: 6" in result.func_return + + @pytest.mark.asyncio @pytest.mark.e2b_sandbox async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop): @@ -550,3 +653,69 @@ async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_us sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.func_return) == 5 + + +@pytest.mark.asyncio +@pytest.mark.e2b_sandbox +async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop): + """Test that E2B sandbox installs tool-specific pip requirements.""" + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) + manager.create_or_update_sandbox_config(config_create, test_user) + + sandbox = AsyncToolSandboxE2B(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements) + result = await sandbox.run() + + # Should succeed since tool pip requirements were installed + assert "Success!" in result.func_return + assert "Status: 200" in result.func_return + assert "Array sum: 6" in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.e2b_sandbox +async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop): + """Test that E2B sandbox installs both sandbox and tool pip requirements.""" + manager = SandboxConfigManager() + + # Add sandbox-level pip requirement + config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump()) + manager.create_or_update_sandbox_config(config_create, test_user) + + sandbox = AsyncToolSandboxE2B(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements) + result = await sandbox.run() + + # Should succeed since both sandbox and tool pip requirements were installed + assert "Success!" in result.func_return + assert "Status: 200" in result.func_return + assert "Array sum: 6" in result.func_return + + +@pytest.mark.asyncio +@pytest.mark.e2b_sandbox +async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling( + check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user, event_loop +): + """Test that E2B sandbox provides informative error messages for broken tool pip requirements.""" + manager = SandboxConfigManager() + config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) + manager.create_or_update_sandbox_config(config_create, test_user) + + sandbox = AsyncToolSandboxE2B(tool_with_broken_pip_requirements.name, {}, user=test_user, tool_object=tool_with_broken_pip_requirements) + + # Should raise a RuntimeError with informative message + with pytest.raises(RuntimeError) as exc_info: + await sandbox.run() + + error_message = str(exc_info.value) + print(error_message) + + # Verify the error message contains helpful information + assert "Failed to install tool pip requirement" in error_message + assert "use_broken_package" in error_message # Tool name + assert "E2B sandbox" in error_message + assert "package version incompatibility" in error_message + assert "Consider updating the package version or removing the version constraint" in error_message + + # Should mention one of the problematic packages + assert "numpy==1.24.0" in error_message or "nonexistent-package-12345" in error_message diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 1a9bd763..2cdee949 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -14,7 +14,8 @@ from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate from letta.schemas.organization import Organization -from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.pip_requirement import PipRequirement +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.user import User from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager diff --git a/tests/test_managers.py b/tests/test_managers.py index 500dcd33..147aa402 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -73,6 +73,7 @@ from letta.schemas.organization import Organization from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.organization import OrganizationUpdate from letta.schemas.passage import Passage as PydanticPassage +from letta.schemas.pip_requirement import PipRequirement from letta.schemas.run import Run as PydanticRun from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.source import Source as PydanticSource @@ -3048,6 +3049,175 @@ async def test_upsert_base_tools_with_empty_type_filter(server: SyncServer, defa assert tools == [] +@pytest.mark.asyncio +async def test_create_tool_with_pip_requirements(server: SyncServer, default_user, default_organization): + def test_tool_with_deps(): + """ + A test tool with pip dependencies. + + Returns: + str: Hello message. + """ + return "hello" + + # Create pip requirements + pip_reqs = [ + PipRequirement(name="requests", version="2.28.0"), + PipRequirement(name="numpy"), # No version specified + ] + + # Set up tool details + source_code = parse_source_code(test_tool_with_deps) + source_type = "python" + description = "A test tool with pip dependencies" + tags = ["test"] + metadata = {"test": "pip_requirements"} + + tool = PydanticTool( + description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs + ) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) + + # Assertions + assert created_tool.pip_requirements is not None + assert len(created_tool.pip_requirements) == 2 + assert created_tool.pip_requirements[0].name == "requests" + assert created_tool.pip_requirements[0].version == "2.28.0" + assert created_tool.pip_requirements[1].name == "numpy" + assert created_tool.pip_requirements[1].version is None + + +@pytest.mark.asyncio +async def test_create_tool_without_pip_requirements(server: SyncServer, print_tool): + # Verify that tools without pip_requirements have the field as None + assert print_tool.pip_requirements is None + + +@pytest.mark.asyncio +async def test_update_tool_pip_requirements(server: SyncServer, print_tool, default_user): + # Add pip requirements to existing tool + pip_reqs = [ + PipRequirement(name="pandas", version="1.5.0"), + PipRequirement(name="matplotlib"), + ] + + tool_update = ToolUpdate(pip_requirements=pip_reqs) + await server.tool_manager.update_tool_by_id_async(print_tool.id, tool_update, actor=default_user) + + # Fetch the updated tool + updated_tool = await server.tool_manager.get_tool_by_id_async(print_tool.id, actor=default_user) + + # Assertions + assert updated_tool.pip_requirements is not None + assert len(updated_tool.pip_requirements) == 2 + assert updated_tool.pip_requirements[0].name == "pandas" + assert updated_tool.pip_requirements[0].version == "1.5.0" + assert updated_tool.pip_requirements[1].name == "matplotlib" + assert updated_tool.pip_requirements[1].version is None + + +@pytest.mark.asyncio +async def test_update_tool_clear_pip_requirements(server: SyncServer, default_user, default_organization): + def test_tool_clear_deps(): + """ + A test tool to clear dependencies. + + Returns: + str: Hello message. + """ + return "hello" + + # Create a tool with pip requirements + pip_reqs = [PipRequirement(name="requests")] + + # Set up tool details + source_code = parse_source_code(test_tool_clear_deps) + source_type = "python" + description = "A test tool to clear dependencies" + tags = ["test"] + metadata = {"test": "clear_deps"} + + tool = PydanticTool( + description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs + ) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) + + # Verify it has requirements + assert created_tool.pip_requirements is not None + assert len(created_tool.pip_requirements) == 1 + + # Clear the requirements + tool_update = ToolUpdate(pip_requirements=[]) + await server.tool_manager.update_tool_by_id_async(created_tool.id, tool_update, actor=default_user) + + # Fetch the updated tool + updated_tool = await server.tool_manager.get_tool_by_id_async(created_tool.id, actor=default_user) + + # Assertions + assert updated_tool.pip_requirements == [] + + +@pytest.mark.asyncio +async def test_pip_requirements_roundtrip(server: SyncServer, default_user, default_organization): + def roundtrip_test_tool(): + """ + Test pip requirements roundtrip. + + Returns: + str: Test message. + """ + return "test" + + # Create pip requirements with various version formats + pip_reqs = [ + PipRequirement(name="requests", version="2.28.0"), + PipRequirement(name="flask", version="2.0"), + PipRequirement(name="django", version="4.1.0-beta"), + PipRequirement(name="numpy"), # No version + ] + + # Set up tool details + source_code = parse_source_code(roundtrip_test_tool) + source_type = "python" + description = "Test pip requirements roundtrip" + tags = ["test"] + metadata = {"test": "roundtrip"} + + tool = PydanticTool( + description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata, pip_requirements=pip_reqs + ) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) + + # Fetch by ID + fetched_tool = await server.tool_manager.get_tool_by_id_async(created_tool.id, actor=default_user) + + # Verify all requirements match exactly + assert fetched_tool.pip_requirements is not None + assert len(fetched_tool.pip_requirements) == 4 + + # Check each requirement + reqs_dict = {req.name: req.version for req in fetched_tool.pip_requirements} + assert reqs_dict["requests"] == "2.28.0" + assert reqs_dict["flask"] == "2.0" + assert reqs_dict["django"] == "4.1.0-beta" + assert reqs_dict["numpy"] is None + + # ====================================================================================================================== # Message Manager Tests # ======================================================================================================================