feat: Add endpoints to list Composio apps and actions (#2140)

This commit is contained in:
Matthew Zhou
2024-12-02 15:36:10 -08:00
committed by GitHub
parent bf3b42a2b6
commit 1e5d74b4a7
10 changed files with 434 additions and 90 deletions

View File

@@ -30,6 +30,7 @@ jobs:
- "test_agent_tool_graph.py"
- "test_utils.py"
- "test_tool_schema_parsing.py"
- "test_v1_routes.py"
services:
qdrant:
image: qdrant/qdrant
@@ -132,4 +133,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
poetry run pytest -s -vv -k "not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests

View File

@@ -1,11 +1,8 @@
import importlib
import inspect
import os
from textwrap import dedent # remove indentation
from types import ModuleType
from typing import Dict, List, Optional
from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LettaToolCreateError
from letta.functions.schema_generator import generate_schema
@@ -90,46 +87,3 @@ def load_function_set(module: ModuleType) -> dict:
if len(function_dict) == 0:
raise ValueError(f"No functions found in module {module}")
return function_dict
def validate_function(module_name, module_full_path):
try:
file = os.path.basename(module_full_path)
spec = importlib.util.spec_from_file_location(module_name, module_full_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except ModuleNotFoundError as e:
# Handle missing module imports
missing_package = str(e).split("'")[1] # Extract the name of the missing package
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
return (
False,
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta.",
)
except SyntaxError as e:
# Handle syntax errors in the module
return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}"
except Exception as e:
# Handle other general exceptions
return False, f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}"
return True, None
def load_function_file(filepath: str) -> dict:
file = os.path.basename(filepath)
module_name = file[:-3] # Remove '.py' from filename
try:
spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except ModuleNotFoundError as e:
# Handle missing module imports
missing_package = str(e).split("'")[1] # Extract the name of the missing package
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!")
print(
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to Letta."
)
# load all functions in the module
function_dict = load_function_set(module)
return function_dict

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from composio.client.collections import ActionModel, AppModel
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.errors import LettaToolCreateError
@@ -156,3 +157,39 @@ def add_base_tools(
"""
actor = server.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)
# Specific routes for Composio
@router.get("/composio/apps", response_model=List[AppModel], operation_id="list_composio_apps")
def list_composio_apps(server: SyncServer = Depends(get_letta_server)):
"""
Get a list of all Composio apps
"""
return server.get_composio_apps()
@router.get("/composio/apps/{composio_app_name}/actions", response_model=List[ActionModel], operation_id="list_composio_actions_by_app")
def list_composio_actions_by_app(
composio_app_name: str,
server: SyncServer = Depends(get_letta_server),
):
"""
Get a list of all Composio actions for a specific app
"""
return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name)
@router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool")
def add_composio_tool(
composio_action_name: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
"""
actor = server.get_user_or_default(user_id=user_id)
tool_create = ToolCreate.from_composio(action=composio_action_name)
return server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor)

View File

@@ -7,6 +7,8 @@ from asyncio import Lock
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from composio.client import Composio
from composio.client.collections import ActionModel, AppModel
from fastapi import HTTPException
import letta.constants as constants
@@ -227,6 +229,11 @@ class SyncServer(Server):
# Locks
self.send_message_lock = Lock()
# Composio
self.composio_client = None
if tool_settings.composio_api_key:
self.composio_client = Composio(api_key=tool_settings.composio_api_key)
# Initialize the metadata store
config = LettaConfig.load()
if settings.letta_pg_uri_no_default:
@@ -1750,3 +1757,18 @@ class SyncServer(Server):
if block.label == label:
return block
return None
# Composio wrappers
def get_composio_apps(self) -> List["AppModel"]:
"""Get a list of all Composio apps with actions"""
apps = self.composio_client.apps.get()
apps_with_actions = []
for app in apps:
if app.meta["actionsCount"] > 0:
apps_with_actions.append(app)
return apps_with_actions
def get_composio_actions_from_app_name(self, composio_app_name: str) -> List["ActionModel"]:
actions = self.composio_client.actions.get(apps=[composio_app_name])
return actions

52
poetry.lock generated
View File

@@ -889,7 +889,7 @@ test = ["pytest"]
name = "composio-core"
version = "0.5.44"
description = "Core package to act as a bridge between composio platform and other services."
optional = true
optional = false
python-versions = "<4,>=3.9"
files = [
{file = "composio_core-0.5.44-py3-none-any.whl", hash = "sha256:bb125794035a3c3c98dab1e72b45024068019c5eb3f29b9cc4eafc845320774b"},
@@ -925,7 +925,7 @@ tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "tra
name = "composio-langchain"
version = "0.5.44"
description = "Use Composio to get an array of tools with your LangChain agent."
optional = true
optional = false
python-versions = "<4,>=3.9"
files = [
{file = "composio_langchain-0.5.44-py3-none-any.whl", hash = "sha256:4cb05d5b92faea32bc02c04e49b5dfae5858abe2f6469a81c673d7f754402375"},
@@ -958,7 +958,7 @@ yaml = ["PyYAML"]
name = "cryptography"
version = "43.0.3"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"},
@@ -2304,7 +2304,7 @@ type = ["pytest-mypy"]
name = "inflection"
version = "0.5.1"
description = "A port of Ruby on Rails inflector to Python"
optional = true
optional = false
python-versions = ">=3.5"
files = [
{file = "inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2"},
@@ -2565,7 +2565,7 @@ files = [
name = "jsonpatch"
version = "1.33"
description = "Apply JSON-Patches (RFC 6902)"
optional = true
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
@@ -2579,7 +2579,7 @@ jsonpointer = ">=1.9"
name = "jsonpointer"
version = "3.0.0"
description = "Identify specific nodes in a JSON document (RFC 6901)"
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"},
@@ -2590,7 +2590,7 @@ files = [
name = "jsonref"
version = "1.1.0"
description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python."
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "jsonref-1.1.0-py3-none-any.whl", hash = "sha256:590dc7773df6c21cbf948b5dac07a72a251db28b0238ceecce0a2abfa8ec30a9"},
@@ -2601,7 +2601,7 @@ files = [
name = "jsonschema"
version = "4.23.0"
description = "An implementation of JSON Schema validation for Python"
optional = true
optional = false
python-versions = ">=3.8"
files = [
{file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"},
@@ -2622,7 +2622,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-
name = "jsonschema-specifications"
version = "2024.10.1"
description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry"
optional = true
optional = false
python-versions = ">=3.9"
files = [
{file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"},
@@ -2705,7 +2705,7 @@ adal = ["adal (>=1.0.2)"]
name = "langchain"
version = "0.3.7"
description = "Building applications with LLMs through composability"
optional = true
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "langchain-0.3.7-py3-none-any.whl", hash = "sha256:cf4af1d5751dacdc278df3de1ff3cbbd8ca7eb55d39deadccdd7fb3d3ee02ac0"},
@@ -2760,7 +2760,7 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10"
name = "langchain-core"
version = "0.3.19"
description = "Building applications with LLMs through composability"
optional = true
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "langchain_core-0.3.19-py3-none-any.whl", hash = "sha256:562b7cc3c15dfaa9270cb1496990c1f3b3e0b660c4d6a3236d7f693346f2a96c"},
@@ -2783,7 +2783,7 @@ typing-extensions = ">=4.7"
name = "langchain-openai"
version = "0.2.9"
description = "An integration package connecting OpenAI and LangChain"
optional = true
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "langchain_openai-0.2.9-py3-none-any.whl", hash = "sha256:2723015e56879f9e5edfcb175fdbec6c296c1b3bf65caad28579ce9c4d1bd652"},
@@ -2799,7 +2799,7 @@ tiktoken = ">=0.7,<1"
name = "langchain-text-splitters"
version = "0.3.2"
description = "LangChain text splitting utilities"
optional = true
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "langchain_text_splitters-0.3.2-py3-none-any.whl", hash = "sha256:0db28c53f41d1bc024cdb3b1646741f6d46d5371e90f31e7e7c9fbe75d01c726"},
@@ -2813,7 +2813,7 @@ langchain-core = ">=0.3.15,<0.4.0"
name = "langchainhub"
version = "0.1.21"
description = "The LangChain Hub API client"
optional = true
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langchainhub-0.1.21-py3-none-any.whl", hash = "sha256:1cc002dc31e0d132a776afd044361e2b698743df5202618cf2bad399246b895f"},
@@ -2829,7 +2829,7 @@ types-requests = ">=2.31.0.2,<3.0.0.0"
name = "langsmith"
version = "0.1.144"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = true
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.144-py3-none-any.whl", hash = "sha256:08ffb975bff2e82fc6f5428837c64c074ea25102d08a25e256361a80812c6100"},
@@ -4243,7 +4243,7 @@ xml = ["lxml (>=4.9.2)"]
name = "paramiko"
version = "3.5.0"
description = "SSH2 protocol library"
optional = true
optional = false
python-versions = ">=3.6"
files = [
{file = "paramiko-3.5.0-py3-none-any.whl", hash = "sha256:1fedf06b085359051cd7d0d270cebe19e755a8a921cc2ddbfa647fb0cd7d68f9"},
@@ -5214,7 +5214,7 @@ model = ["milvus-model (>=0.1.0)"]
name = "pynacl"
version = "1.5.0"
description = "Python binding to the Networking and Cryptography (NaCl) library"
optional = true
optional = false
python-versions = ">=3.6"
files = [
{file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"},
@@ -5262,7 +5262,7 @@ image = ["Pillow (>=8.0.0)"]
name = "pyperclip"
version = "1.9.0"
description = "A cross-platform clipboard module for Python. (Only handles plain text for now.)"
optional = true
optional = false
python-versions = "*"
files = [
{file = "pyperclip-1.9.0.tar.gz", hash = "sha256:b7de0142ddc81bfc5c7507eea19da920b92252b548b96186caf94a5e2527d310"},
@@ -5327,7 +5327,7 @@ nodejs = ["nodejs-wheel-binaries"]
name = "pysher"
version = "1.0.8"
description = "Pusher websocket client for python, based on Erik Kulyk's PythonPusherClient"
optional = true
optional = false
python-versions = "*"
files = [
{file = "Pysher-1.0.8.tar.gz", hash = "sha256:7849c56032b208e49df67d7bd8d49029a69042ab0bb45b2ed59fa08f11ac5988"},
@@ -5734,7 +5734,7 @@ prompt_toolkit = ">=2.0,<=3.0.36"
name = "referencing"
version = "0.35.1"
description = "JSON Referencing + Python"
optional = true
optional = false
python-versions = ">=3.8"
files = [
{file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"},
@@ -5891,7 +5891,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
name = "requests-toolbelt"
version = "1.0.0"
description = "A utility belt for advanced users of python-requests"
optional = true
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
@@ -5924,7 +5924,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
name = "rpds-py"
version = "0.21.0"
description = "Python bindings to Rust's persistent data structures (rpds)"
optional = true
optional = false
python-versions = ">=3.9"
files = [
{file = "rpds_py-0.21.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a017f813f24b9df929674d0332a374d40d7f0162b326562daae8066b502d0590"},
@@ -6051,7 +6051,7 @@ asn1crypto = ">=1.5.1"
name = "semver"
version = "3.0.2"
description = "Python helper for Semantic Versioning (https://semver.org)"
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"},
@@ -6062,7 +6062,7 @@ files = [
name = "sentry-sdk"
version = "2.19.0"
description = "Python client for Sentry (https://sentry.io)"
optional = true
optional = false
python-versions = ">=3.6"
files = [
{file = "sentry_sdk-2.19.0-py2.py3-none-any.whl", hash = "sha256:7b0b3b709dee051337244a09a30dbf6e95afe0d34a1f8b430d45e0982a7c125b"},
@@ -6699,7 +6699,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.
name = "types-requests"
version = "2.32.0.20241016"
description = "Typing stubs for requests"
optional = true
optional = false
python-versions = ">=3.8"
files = [
{file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"},
@@ -7595,4 +7595,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.10"
content-hash = "28cd26c6573ca0a07173262bc0e819e19b661157fa757efca0590262f9b9f35c"
content-hash = "9e0f7eb7ed1007cfeb0227d0f1bf20c5601c23e1d363ee23195ce6f8e134f14e"

View File

@@ -66,8 +66,8 @@ llama-index = "^0.11.9"
llama-index-embeddings-openai = "^0.2.5"
llama-index-embeddings-ollama = "^0.3.1"
wikipedia = {version = "^1.4.0", optional = true}
composio-langchain = {version = "^0.5.28", optional = true}
composio-core = {version = "^0.5.34", optional = true}
composio-langchain = "^0.5.28"
composio-core = "^0.5.34"
alembic = "^1.13.3"
pyhumps = "^3.8.0"
psycopg2 = "^2.9.10"

View File

@@ -1,6 +1,9 @@
from typing import Union
from letta import LocalClient, RESTClient
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.schemas.tool import Tool
def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
@@ -9,3 +12,15 @@ def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
if agent_state.name == agent_uuid:
client.delete_agent(agent_id=agent_state.id)
print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")
# Utility functions
def create_tool_from_func(func: callable):
return Tool(
name=func.__name__,
description="",
source_type="python",
tags=[],
source_code=parse_source_code(func),
json_schema=generate_schema(func, None),
)

View File

@@ -10,8 +10,6 @@ from sqlalchemy import delete
from letta import create_client
from letta.functions.function_sets.base import core_memory_replace
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -34,6 +32,7 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
from tests.helpers.utils import create_tool_from_func
# Constants
namespace = uuid.NAMESPACE_DNS
@@ -214,18 +213,6 @@ def agent_state():
yield agent_state
# Utility functions
def create_tool_from_func(func: callable):
return Tool(
name=func.__name__,
description="",
source_type="python",
tags=[],
source_code=parse_source_code(func),
json_schema=generate_schema(func, None),
)
# Local sandbox tests
@pytest.mark.local_sandbox
def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_user):

View File

@@ -26,7 +26,6 @@ from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.memory import ChatMemory
from letta.schemas.source import Source
from letta.server.server import SyncServer
@@ -540,3 +539,15 @@ def _test_get_messages_letta_format(
def test_get_messages_letta_format(server, user_id, agent_id):
for reverse in [False, True]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
def test_composio_client_simple(server):
apps = server.get_composio_apps()
# Assert there's some amount of apps returned
assert len(apps) > 0
app = apps[0]
actions = server.get_composio_actions_from_app_name(composio_app_name=app.name)
# Assert there's some amount of actions
assert len(actions) > 0

317
tests/test_v1_routes.py Normal file
View File

@@ -0,0 +1,317 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
from composio.client.collections import (
ActionModel,
ActionParametersModel,
ActionResponseModel,
AppModel,
)
from fastapi.testclient import TestClient
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.server.rest_api.app import app
from letta.server.rest_api.utils import get_letta_server
from tests.helpers.utils import create_tool_from_func
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def mock_sync_server():
mock_server = Mock()
app.dependency_overrides[get_letta_server] = lambda: mock_server
return mock_server
@pytest.fixture
def add_integers_tool():
def add(x: int, y: int) -> int:
"""
Simple function that adds two integers.
Parameters:
x (int): The first integer to add.
y (int): The second integer to add.
Returns:
int: The result of adding x and y.
"""
return x + y
tool = create_tool_from_func(add)
yield tool
@pytest.fixture
def create_integers_tool(add_integers_tool):
tool_create = ToolCreate(
name=add_integers_tool.name,
description=add_integers_tool.description,
tags=add_integers_tool.tags,
module=add_integers_tool.module,
source_code=add_integers_tool.source_code,
source_type=add_integers_tool.source_type,
json_schema=add_integers_tool.json_schema,
)
yield tool_create
@pytest.fixture
def update_integers_tool(add_integers_tool):
tool_update = ToolUpdate(
name=add_integers_tool.name,
description=add_integers_tool.description,
tags=add_integers_tool.tags,
module=add_integers_tool.module,
source_code=add_integers_tool.source_code,
source_type=add_integers_tool.source_type,
json_schema=add_integers_tool.json_schema,
)
yield tool_update
@pytest.fixture
def composio_apps():
affinity_app = AppModel(
name="affinity",
key="affinity",
appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8",
description="Affinity helps private capital investors to find, manage, and close more deals",
categories=["CRM"],
meta={
"is_custom_app": False,
"triggersCount": 0,
"actionsCount": 20,
"documentation_doc_text": None,
"configuration_docs_text": None,
},
logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
docs=None,
group=None,
status=None,
enabled=False,
no_auth=False,
auth_schemes=None,
testConnectors=None,
documentation_doc_text=None,
configuration_docs_text=None,
)
yield [affinity_app]
@pytest.fixture
def composio_actions():
yield [
ActionModel(
name="AFFINITY_GET_ALL_COMPANIES",
display_name="Get all companies",
parameters=ActionParametersModel(
properties={
"cursor": {"default": None, "description": "Cursor for the next or previous page", "title": "Cursor", "type": "string"},
"limit": {"default": 100, "description": "Number of items to include in the page", "title": "Limit", "type": "integer"},
"ids": {"default": None, "description": "Company IDs", "items": {"type": "integer"}, "title": "Ids", "type": "array"},
"fieldIds": {
"default": None,
"description": "Field IDs for which to return field data",
"items": {"type": "string"},
"title": "Fieldids",
"type": "array",
},
"fieldTypes": {
"default": None,
"description": "Field Types for which to return field data",
"items": {"enum": ["enriched", "global", "relationship-intelligence"], "title": "FieldtypesEnm", "type": "string"},
"title": "Fieldtypes",
"type": "array",
},
},
title="GetAllCompaniesRequest",
type="object",
required=None,
),
response=ActionResponseModel(
properties={
"data": {"title": "Data", "type": "object"},
"successful": {
"description": "Whether or not the action execution was successful or not",
"title": "Successful",
"type": "boolean",
},
"error": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"description": "Error if any occurred during the execution of the action",
"title": "Error",
},
},
title="GetAllCompaniesResponse",
type="object",
required=["data", "successful"],
),
appName="affinity",
appId="affinity",
tags=["companies", "important"],
enabled=False,
logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
description="Affinity Api Allows Paginated Access To Company Info And Custom Fields. Use `Field Ids` Or `Field Types` To Specify Data In A Request. Retrieve Field I Ds/Types Via Get `/V2/Companies/Fields`. Export Permission Needed.",
)
]
# ======================================================================================================================
# Tools Routes Tests
# ======================================================================================================================
def test_delete_tool(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.delete_tool_by_id = MagicMock()
response = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
assert response.status_code == 200
mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value
)
def test_get_tool(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.get_tool_by_id.return_value = add_integers_tool
response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
assert response.json()["source_code"] == add_integers_tool.source_code
mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value
)
def test_get_tool_404(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.get_tool_by_id.return_value = None
response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
assert response.status_code == 404
assert response.json()["detail"] == f"Tool with id {add_integers_tool.id} not found."
def test_get_tool_id(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.get_tool_by_name.return_value = add_integers_tool
response = client.get(f"/v1/tools/name/{add_integers_tool.name}", headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json() == add_integers_tool.id
mock_sync_server.tool_manager.get_tool_by_name.assert_called_once_with(
tool_name=add_integers_tool.name, actor=mock_sync_server.get_user_or_default.return_value
)
def test_get_tool_id_404(client, mock_sync_server):
mock_sync_server.tool_manager.get_tool_by_name.return_value = None
response = client.get("/v1/tools/name/UnknownTool", headers={"user_id": "test_user"})
assert response.status_code == 404
assert "Tool with name UnknownTool" in response.json()["detail"]
def test_list_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.list_tools.return_value = [add_integers_tool]
response = client.get("/v1/tools", headers={"user_id": "test_user"})
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.list_tools.assert_called_once()
def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
mock_sync_server.tool_manager.create_tool.return_value = add_integers_tool
response = client.post("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.create_tool.assert_called_once()
def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool
response = client.put("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.create_or_update_tool.assert_called_once()
def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool):
mock_sync_server.tool_manager.update_tool_by_id.return_value = add_integers_tool
response = client.patch(f"/v1/tools/{add_integers_tool.id}", json=update_integers_tool.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.get_user_or_default.return_value
)
def test_add_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool]
response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(actor=mock_sync_server.get_user_or_default.return_value)
def test_list_composio_apps(client, mock_sync_server, composio_apps):
mock_sync_server.get_composio_apps.return_value = composio_apps
response = client.get("/v1/tools/composio/apps")
assert response.status_code == 200
assert len(response.json()) == 1
mock_sync_server.get_composio_apps.assert_called_once()
def test_list_composio_actions_by_app(client, mock_sync_server, composio_actions):
mock_sync_server.get_composio_actions_from_app_name.return_value = composio_actions
response = client.get("/v1/tools/composio/apps/App1/actions")
assert response.status_code == 200
assert len(response.json()) == 1
mock_sync_server.get_composio_actions_from_app_name.assert_called_once_with(composio_app_name="App1")
def test_add_composio_tool(client, mock_sync_server, add_integers_tool):
# Mock ToolCreate.from_composio to return the expected ToolCreate object
with patch("letta.schemas.tool.ToolCreate.from_composio") as mock_from_composio:
mock_from_composio.return_value = ToolCreate(
name=add_integers_tool.name,
source_code=add_integers_tool.source_code,
json_schema=add_integers_tool.json_schema,
)
# Mock server behavior
mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool
# Perform the request
response = client.post(f"/v1/tools/composio/{add_integers_tool.name}", headers={"user_id": "test_user"})
# Assertions
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.create_or_update_tool.assert_called_once()
# Verify the mocked from_composio method was called
mock_from_composio.assert_called_once_with(action=add_integers_tool.name)