feat: Add optional pip requirements to tool object (#2793)

This commit is contained in:
Matthew Zhou
2025-06-13 13:20:36 -07:00
committed by GitHub
parent dd5581a72a
commit 2e77ea6e76
11 changed files with 502 additions and 34 deletions

View File

@@ -0,0 +1,31 @@
"""Add pip requirements to tools
Revision ID: 1c6b6a38b713
Revises: c96263433aef
Create Date: 2025-06-12 18:06:54.838510
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "1c6b6a38b713"
down_revision: Union[str, None] = "c96263433aef"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("tools", sa.Column("pip_requirements", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("tools", "pip_requirements")
# ### end Alembic commands ###

View File

@@ -44,6 +44,9 @@ class Tool(SqlalchemyBase, OrganizationMixin):
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
args_json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The JSON schema of the function arguments.")
pip_requirements: Mapped[Optional[List]] = mapped_column(
JSON, nullable=True, doc="Optional list of pip packages required by this tool."
)
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")

View File

@@ -0,0 +1,25 @@
import re
from typing import Optional
from pydantic import BaseModel, Field, field_validator
class PipRequirement(BaseModel):
    """A single pip dependency of a tool, optionally pinned to a version."""

    name: str = Field(..., min_length=1, description="Name of the pip package.")
    version: Optional[str] = Field(None, description="Optional version of the package, following semantic versioning.")

    @field_validator("version")
    @classmethod
    def validate_version(cls, v: Optional[str]) -> Optional[str]:
        """Allow only loosely semver-shaped version strings (or None for unpinned)."""
        if v is None:
            return None
        # 1-3 dot-separated numeric components plus an optional "-prerelease" suffix.
        if not re.match(r"^\d+(\.\d+){0,2}(-[a-zA-Z0-9.]+)?$", v):
            raise ValueError(f"Invalid version format: {v}. Must follow semantic versioning (e.g., 1.2.3, 2.0, 1.5.0-alpha).")
        return v

    def __str__(self) -> str:
        """Return a pip-installable string format."""
        return f"{self.name}=={self.version}" if self.version else self.name

View File

@@ -1,6 +1,5 @@
import hashlib
import json
import re
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
@@ -9,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.schemas.agent import AgentState
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.schemas.pip_requirement import PipRequirement
from letta.settings import tool_settings
@@ -27,24 +27,6 @@ class SandboxRunResult(BaseModel):
sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")
class PipRequirement(BaseModel):
name: str = Field(..., min_length=1, description="Name of the pip package.")
version: Optional[str] = Field(None, description="Optional version of the package, following semantic versioning.")
@classmethod
def validate_version(cls, version: Optional[str]) -> Optional[str]:
if version is None:
return None
semver_pattern = re.compile(r"^\d+(\.\d+){0,2}(-[a-zA-Z0-9.]+)?$")
if not semver_pattern.match(version):
raise ValueError(f"Invalid version format: {version}. Must follow semantic versioning (e.g., 1.2.3, 2.0, 1.5.0-alpha).")
return version
def __init__(self, **data):
super().__init__(**data)
self.version = self.validate_version(self.version)
class LocalSandboxConfig(BaseModel):
sandbox_dir: Optional[str] = Field(None, description="Directory for the sandbox environment.")
use_venv: bool = Field(False, description="Whether or not to use the venv, or run directly in the same run loop.")

View File

@@ -24,6 +24,7 @@ from letta.functions.schema_generator import (
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.letta_base import LettaBase
from letta.schemas.pip_requirement import PipRequirement
logger = get_logger(__name__)
@@ -60,6 +61,7 @@ class Tool(BaseTool):
# tool configuration
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
@@ -145,6 +147,7 @@ class ToolCreate(LettaBase):
)
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
# TODO should we put the HTTP / API fetch inside from_mcp?
# async def from_mcp(cls, mcp_server: str, mcp_tool_name: str) -> "ToolCreate":
@@ -253,6 +256,7 @@ class ToolUpdate(LettaBase):
)
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
return_char_limit: Optional[int] = Field(None, description="The maximum number of characters in the response.")
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
class Config:
extra = "ignore" # Allows extra fields without validation errors
@@ -269,3 +273,4 @@ class ToolRunFromSource(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")

View File

@@ -2,7 +2,7 @@ import os
import platform
import subprocess
import venv
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional
from datamodel_code_generator import DataModelType, PythonVersion
from datamodel_code_generator.model import get_data_model_types
@@ -11,6 +11,9 @@ from datamodel_code_generator.parser.jsonschema import JsonSchemaParser
from letta.log import get_logger
from letta.schemas.sandbox_config import LocalSandboxConfig
if TYPE_CHECKING:
from letta.schemas.tool import Tool
logger = get_logger(__name__)
@@ -85,14 +88,12 @@ def install_pip_requirements_for_sandbox(
upgrade: bool = True,
user_install_if_no_venv: bool = False,
env: Optional[Dict[str, str]] = None,
tool: Optional["Tool"] = None,
):
"""
Installs the specified pip requirements inside the correct environment (venv or system).
Installs both sandbox-level and tool-specific pip requirements.
"""
if not local_configs.pip_requirements:
logger.debug("No pip requirements specified; skipping installation.")
return
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
local_configs.sandbox_dir = sandbox_dir # Update the object to store the absolute path
@@ -102,19 +103,48 @@ def install_pip_requirements_for_sandbox(
if local_configs.use_venv:
ensure_pip_is_up_to_date(python_exec, env=env)
# Construct package list
packages = [f"{req.name}=={req.version}" if req.version else req.name for req in local_configs.pip_requirements]
# Collect all pip requirements
all_packages = []
# Add sandbox-level pip requirements
if local_configs.pip_requirements:
packages = [f"{req.name}=={req.version}" if req.version else req.name for req in local_configs.pip_requirements]
all_packages.extend(packages)
logger.debug(f"Added sandbox pip requirements: {packages}")
# Add tool-specific pip requirements
if tool and tool.pip_requirements:
tool_packages = [str(req) for req in tool.pip_requirements]
all_packages.extend(tool_packages)
logger.debug(f"Added tool pip requirements for {tool.name}: {tool_packages}")
if not all_packages:
logger.debug("No pip requirements specified; skipping installation.")
return
# Construct pip install command
pip_cmd = [python_exec, "-m", "pip", "install"]
if upgrade:
pip_cmd.append("--upgrade")
pip_cmd += packages
pip_cmd += all_packages
if user_install_if_no_venv and not local_configs.use_venv:
pip_cmd.append("--user")
run_subprocess(pip_cmd, env=env, fail_msg=f"Failed to install packages: {', '.join(packages)}")
# Enhanced error message for better debugging
sandbox_packages = [f"{req.name}=={req.version}" if req.version else req.name for req in (local_configs.pip_requirements or [])]
tool_packages = [str(req) for req in (tool.pip_requirements if tool and tool.pip_requirements else [])]
error_details = []
if sandbox_packages:
error_details.append(f"sandbox requirements: {', '.join(sandbox_packages)}")
if tool_packages:
error_details.append(f"tool requirements: {', '.join(tool_packages)}")
context = f" ({'; '.join(error_details)})" if error_details else ""
fail_msg = f"Failed to install pip packages{context}. This may be due to package version incompatibility. Consider updating package versions or removing version constraints."
run_subprocess(pip_cmd, env=env, fail_msg=fail_msg)
def create_venv_for_local_sandbox(sandbox_dir_path: str, venv_path: str, env: Dict[str, str], force_recreate: bool):

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from e2b.sandbox.commands.command_handle import CommandExitException
from e2b_code_interpreter import AsyncSandbox
from letta.log import get_logger
@@ -189,14 +190,63 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
"package": package,
},
)
await sbx.commands.run(f"pip install {package}")
try:
await sbx.commands.run(f"pip install {package}")
log_event(
"e2b_pip_install_finished",
{
"sandbox_id": sbx.sandbox_id,
"package": package,
},
)
except CommandExitException as e:
error_msg = f"Failed to install sandbox pip requirement '{package}' in E2B sandbox. This may be due to package version incompatibility with the E2B environment. Error: {e}"
logger.error(error_msg)
log_event(
"e2b_pip_install_failed",
{
"sandbox_id": sbx.sandbox_id,
"package": package,
"error": str(e),
},
)
raise RuntimeError(error_msg) from e
# Install tool-specific pip requirements
if self.tool and self.tool.pip_requirements:
for pip_requirement in self.tool.pip_requirements:
package_str = str(pip_requirement)
log_event(
"e2b_pip_install_finished",
"tool_pip_install_started",
{
"sandbox_id": sbx.sandbox_id,
"package": package,
"package": package_str,
"tool_name": self.tool.name,
},
)
try:
await sbx.commands.run(f"pip install {package_str}")
log_event(
"tool_pip_install_finished",
{
"sandbox_id": sbx.sandbox_id,
"package": package_str,
"tool_name": self.tool.name,
},
)
except CommandExitException as e:
error_msg = f"Failed to install tool pip requirement '{package_str}' for tool '{self.tool.name}' in E2B sandbox. This may be due to package version incompatibility with the E2B environment. Consider updating the package version or removing the version constraint. Error: {e}"
logger.error(error_msg)
log_event(
"tool_pip_install_failed",
{
"sandbox_id": sbx.sandbox_id,
"package": package_str,
"tool_name": self.tool.name,
"error": str(e),
},
)
raise RuntimeError(error_msg) from e
return sbx

View File

@@ -175,7 +175,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
log_event(name="finish create_venv_for_local_sandbox")
log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
await asyncio.to_thread(install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env)
await asyncio.to_thread(
install_pip_requirements_for_sandbox, local_configs, upgrade=True, user_install_if_no_venv=False, env=env, tool=self.tool
)
log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
@trace_method

View File

@@ -15,7 +15,8 @@ from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
from letta.schemas.organization import Organization
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
@@ -263,6 +264,64 @@ def custom_test_sandbox_config(test_user):
# Tool-specific fixtures
@pytest.fixture
def tool_with_pip_requirements(test_user):
    """A registered tool that declares pinned pip requirements (requests, numpy)."""

    def use_requests_and_numpy() -> str:
        """
        Function that uses requests and numpy packages to test tool-specific pip requirements.

        Returns:
            str: Success message if packages are available.
        """
        try:
            import numpy as np
            import requests

            # Simple usage to verify packages work
            response = requests.get("https://httpbin.org/json", timeout=5)
            arr = np.array([1, 2, 3])
            return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}"
        except ImportError as e:
            return f"Import error: {e}"
        except Exception as e:
            return f"Other error: {e}"

    tool = create_tool_from_func(use_requests_and_numpy)
    # Pin versions recent enough to install cleanly in the E2B environment.
    tool.pip_requirements = [
        PipRequirement(name="requests", version="2.31.0"),
        PipRequirement(name="numpy", version="1.26.0"),
    ]
    yield ToolManager().create_or_update_tool(tool, test_user)
@pytest.fixture
def tool_with_broken_pip_requirements(test_user):
    """A registered tool whose pip requirements are guaranteed to fail installation."""

    def use_broken_package() -> str:
        """
        Function that requires a package with known compatibility issues.

        Returns:
            str: Should not reach here due to pip install failure.
        """
        try:
            import some_nonexistent_package  # This will fail during pip install

            return "This should not execute"
        except ImportError as e:
            return f"Import error: {e}"

    tool = create_tool_from_func(use_broken_package)
    # One bad version pin plus one package that does not exist on PyPI.
    tool.pip_requirements = [
        PipRequirement(name="numpy", version="1.24.0"),  # Known to have compatibility issues
        PipRequirement(name="nonexistent-package-12345"),  # This package doesn't exist
    ]
    yield ToolManager().create_or_update_tool(tool, test_user)
@pytest.fixture
def core_memory_tools(test_user):
"""Create all base tools for testing."""
@@ -418,6 +477,50 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c
assert long_random_string in result.stdout[0]
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_with_tool_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop):
    """Test that local sandbox installs tool-specific pip requirements."""
    config_manager = SandboxConfigManager()
    sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
    local_config = LocalSandboxConfig(sandbox_dir=sandbox_dir, use_venv=True)
    config_manager.create_or_update_sandbox_config(SandboxConfigCreate(config=local_config.model_dump()), test_user)

    sandbox = AsyncToolSandboxLocal(
        tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements, force_recreate_venv=True
    )
    result = await sandbox.run()

    # The tool imports requests/numpy, so success proves its requirements were installed.
    assert "Success!" in result.func_return
    assert "Status: 200" in result.func_return
    assert "Array sum: 6" in result.func_return
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_with_mixed_pip_requirements(disable_e2b_api_key, tool_with_pip_requirements, test_user, event_loop):
    """Test that local sandbox installs both sandbox and tool pip requirements."""
    config_manager = SandboxConfigManager()
    sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")

    # Sandbox-level requirement (cowsay) on top of the tool's own requirements.
    local_config = LocalSandboxConfig(sandbox_dir=sandbox_dir, use_venv=True, pip_requirements=[PipRequirement(name="cowsay")])
    config_manager.create_or_update_sandbox_config(SandboxConfigCreate(config=local_config.model_dump()), test_user)

    sandbox = AsyncToolSandboxLocal(
        tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements, force_recreate_venv=True
    )
    result = await sandbox.run()

    # Success requires both the sandbox and tool requirement sets to be installed.
    assert "Success!" in result.func_return
    assert "Status: 200" in result.func_return
    assert "Array sum: 6" in result.func_return
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop):
@@ -550,3 +653,69 @@ async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_us
sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.func_return) == 5
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_with_tool_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop):
    """Test that E2B sandbox installs tool-specific pip requirements."""
    config_manager = SandboxConfigManager()
    config_manager.create_or_update_sandbox_config(SandboxConfigCreate(config=E2BSandboxConfig().model_dump()), test_user)

    sandbox = AsyncToolSandboxE2B(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements)
    result = await sandbox.run()

    # Success implies requests/numpy were importable inside the E2B sandbox.
    assert "Success!" in result.func_return
    assert "Status: 200" in result.func_return
    assert "Array sum: 6" in result.func_return
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, tool_with_pip_requirements, test_user, event_loop):
    """Test that E2B sandbox installs both sandbox and tool pip requirements."""
    config_manager = SandboxConfigManager()

    # Sandbox-level requirement (cowsay) alongside the tool's own requirements.
    config_manager.create_or_update_sandbox_config(
        SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump()), test_user
    )

    sandbox = AsyncToolSandboxE2B(tool_with_pip_requirements.name, {}, user=test_user, tool_object=tool_with_pip_requirements)
    result = await sandbox.run()

    # Success requires both requirement sets to be installed.
    assert "Success!" in result.func_return
    assert "Status: 200" in result.func_return
    assert "Array sum: 6" in result.func_return
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_with_broken_tool_pip_requirements_error_handling(
    check_e2b_key_is_set, tool_with_broken_pip_requirements, test_user, event_loop
):
    """Test that E2B sandbox provides informative error messages for broken tool pip requirements."""
    manager = SandboxConfigManager()
    config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
    manager.create_or_update_sandbox_config(config_create, test_user)

    sandbox = AsyncToolSandboxE2B(tool_with_broken_pip_requirements.name, {}, user=test_user, tool_object=tool_with_broken_pip_requirements)

    # Should raise a RuntimeError with informative message
    with pytest.raises(RuntimeError) as exc_info:
        await sandbox.run()

    error_message = str(exc_info.value)

    # Verify the error message contains helpful information
    assert "Failed to install tool pip requirement" in error_message
    assert "use_broken_package" in error_message  # Tool name
    assert "E2B sandbox" in error_message
    assert "package version incompatibility" in error_message
    assert "Consider updating the package version or removing the version constraint" in error_message
    # Should mention one of the problematic packages
    assert "numpy==1.24.0" in error_message or "nonexistent-package-12345" in error_message

View File

@@ -14,7 +14,8 @@ from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
from letta.schemas.organization import Organization
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager

View File

@@ -73,6 +73,7 @@ from letta.schemas.organization import Organization
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.run import Run as PydanticRun
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.source import Source as PydanticSource
@@ -3048,6 +3049,175 @@ async def test_upsert_base_tools_with_empty_type_filter(server: SyncServer, defa
assert tools == []
@pytest.mark.asyncio
async def test_create_tool_with_pip_requirements(server: SyncServer, default_user, default_organization):
    """Creating a tool persists its pip requirements (pinned and unpinned)."""

    def test_tool_with_deps():
        """
        A test tool with pip dependencies.

        Returns:
            str: Hello message.
        """
        return "hello"

    # One pinned requirement and one without a version.
    requirements = [
        PipRequirement(name="requests", version="2.28.0"),
        PipRequirement(name="numpy"),  # No version specified
    ]

    tool = PydanticTool(
        description="A test tool with pip dependencies",
        tags=["test"],
        source_code=parse_source_code(test_tool_with_deps),
        source_type="python",
        metadata_={"test": "pip_requirements"},
        pip_requirements=requirements,
    )
    schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
    tool.json_schema = schema
    tool.name = schema["name"]

    created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)

    # Requirements survive creation unchanged, in order.
    assert created_tool.pip_requirements is not None
    assert len(created_tool.pip_requirements) == 2
    assert created_tool.pip_requirements[0].name == "requests"
    assert created_tool.pip_requirements[0].version == "2.28.0"
    assert created_tool.pip_requirements[1].name == "numpy"
    assert created_tool.pip_requirements[1].version is None
@pytest.mark.asyncio
async def test_create_tool_without_pip_requirements(server: SyncServer, print_tool):
    """A tool created without pip requirements leaves the field as None."""
    assert print_tool.pip_requirements is None
@pytest.mark.asyncio
async def test_update_tool_pip_requirements(server: SyncServer, print_tool, default_user):
    """Updating an existing tool can attach pip requirements after the fact."""
    new_requirements = [
        PipRequirement(name="pandas", version="1.5.0"),
        PipRequirement(name="matplotlib"),
    ]
    await server.tool_manager.update_tool_by_id_async(
        print_tool.id, ToolUpdate(pip_requirements=new_requirements), actor=default_user
    )

    # Re-fetch to verify the update was persisted.
    refreshed = await server.tool_manager.get_tool_by_id_async(print_tool.id, actor=default_user)

    assert refreshed.pip_requirements is not None
    assert len(refreshed.pip_requirements) == 2
    assert refreshed.pip_requirements[0].name == "pandas"
    assert refreshed.pip_requirements[0].version == "1.5.0"
    assert refreshed.pip_requirements[1].name == "matplotlib"
    assert refreshed.pip_requirements[1].version is None
@pytest.mark.asyncio
async def test_update_tool_clear_pip_requirements(server: SyncServer, default_user, default_organization):
    """Updating with an explicit empty list clears a tool's pip requirements."""

    def test_tool_clear_deps():
        """
        A test tool to clear dependencies.

        Returns:
            str: Hello message.
        """
        return "hello"

    tool = PydanticTool(
        description="A test tool to clear dependencies",
        tags=["test"],
        source_code=parse_source_code(test_tool_clear_deps),
        source_type="python",
        metadata_={"test": "clear_deps"},
        pip_requirements=[PipRequirement(name="requests")],
    )
    schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
    tool.json_schema = schema
    tool.name = schema["name"]

    created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)

    # Precondition: the tool starts with one requirement.
    assert created_tool.pip_requirements is not None
    assert len(created_tool.pip_requirements) == 1

    # Clearing is done with an empty list, not None (None would mean "no change").
    await server.tool_manager.update_tool_by_id_async(created_tool.id, ToolUpdate(pip_requirements=[]), actor=default_user)

    refreshed = await server.tool_manager.get_tool_by_id_async(created_tool.id, actor=default_user)
    assert refreshed.pip_requirements == []
@pytest.mark.asyncio
async def test_pip_requirements_roundtrip(server: SyncServer, default_user, default_organization):
    """Pip requirements with assorted version formats survive create-then-fetch."""

    def roundtrip_test_tool():
        """
        Test pip requirements roundtrip.

        Returns:
            str: Test message.
        """
        return "test"

    # Cover full semver, two-part, pre-release, and unpinned formats.
    requirements = [
        PipRequirement(name="requests", version="2.28.0"),
        PipRequirement(name="flask", version="2.0"),
        PipRequirement(name="django", version="4.1.0-beta"),
        PipRequirement(name="numpy"),  # No version
    ]

    tool = PydanticTool(
        description="Test pip requirements roundtrip",
        tags=["test"],
        source_code=parse_source_code(roundtrip_test_tool),
        source_type="python",
        metadata_={"test": "roundtrip"},
        pip_requirements=requirements,
    )
    schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
    tool.json_schema = schema
    tool.name = schema["name"]

    created_tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
    fetched_tool = await server.tool_manager.get_tool_by_id_async(created_tool.id, actor=default_user)

    assert fetched_tool.pip_requirements is not None
    assert len(fetched_tool.pip_requirements) == 4

    # Compare name -> version pairs independent of ordering.
    versions_by_name = {req.name: req.version for req in fetched_tool.pip_requirements}
    assert versions_by_name["requests"] == "2.28.0"
    assert versions_by_name["flask"] == "2.0"
    assert versions_by_name["django"] == "4.1.0-beta"
    assert versions_by_name["numpy"] is None
# ======================================================================================================================
# Message Manager Tests
# ======================================================================================================================