fix: Enable importing LangChain tools with arguments (#1807)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-09-30 18:47:48 -07:00
committed by GitHub
parent 9f9e967c8b
commit 2164aedb46
7 changed files with 254 additions and 54 deletions

View File

@@ -32,7 +32,7 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E crewai-tools"
install-args: "-E dev -E postgres -E milvus -E crewai-tools -E tests"
- name: Initialize credentials
run: poetry run letta quickstart --backend openai

175
letta/functions/helpers.py Normal file
View File

@@ -0,0 +1,175 @@
from typing import Any, Optional, Union
from pydantic import BaseModel
def generate_langchain_tool_wrapper(
    tool: "LangChainBaseTool", additional_imports_module_attr_map: Optional[dict] = None
) -> tuple[str, str]:
    """Generate Python source for a wrapper function around a LangChain tool instance.

    Args:
        tool: The LangChain tool instance to wrap; its class must be importable
            from `langchain_community.tools`.
        additional_imports_module_attr_map: Optional mapping of module path ->
            attribute name for extra classes the tool's constructor needs,
            e.g. {"langchain_community.utilities": "WikipediaAPIWrapper"}.
            The tool class itself does not need to be listed.

    Returns:
        A (function_name, function_source) tuple. The generated function strips
        any stray 'self' kwarg, performs the imports, re-instantiates the tool
        with the same constructor arguments, and forwards **kwargs to `_run`.

    Raises:
        RuntimeError: If the tool's parameters require classes not covered by
            `additional_imports_module_attr_map`.
    """
    tool_name = tool.__class__.__name__
    import_statement = f"from langchain_community.tools import {tool_name}"
    extra_module_imports = generate_import_code(additional_imports_module_attr_map)

    # Safety check that user has passed in all required imports:
    current_class_imports = {tool_name}
    if additional_imports_module_attr_map:
        current_class_imports.update(set(additional_imports_module_attr_map.values()))
    required_class_imports = set(find_required_class_names_for_import(tool))

    if not current_class_imports.issuperset(required_class_imports):
        err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}"
        print(err_msg)
        raise RuntimeError(err_msg)

    tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
    run_call = f"return tool._run(**kwargs)"
    func_name = f"run_{tool_name.lower()}"

    # Combine all parts into the wrapper function
    wrapper_function_str = f"""
def {func_name}(**kwargs):
    if 'self' in kwargs:
        del kwargs['self']
    import importlib
    {import_statement}
    {extra_module_imports}
    {tool_instantiation}
    {run_call}
"""
    return func_name, wrapper_function_str
def generate_crewai_tool_wrapper(tool: "CrewAIBaseTool") -> tuple[str, str]:
    """Generate Python source for a wrapper function around a CrewAI tool instance.

    Returns a (function_name, function_source) tuple; the generated function
    imports the tool class from `crewai_tools`, re-instantiates it with the
    same constructor arguments, and forwards **kwargs to its `_run` method.
    """
    cls_name = tool.__class__.__name__
    wrapper_name = f"run_{cls_name.lower()}"
    instantiation_expr = generate_imported_tool_instantiation_call_str(tool)

    # Assemble the wrapper source: strip any stray 'self' kwarg, import the
    # tool class, rebuild the tool, then delegate to its private _run.
    source = f"""
def {wrapper_name}(**kwargs):
    if 'self' in kwargs:
        del kwargs['self']
    from crewai_tools import {cls_name}
    tool = {instantiation_expr}
    return tool._run(**kwargs)
"""
    return wrapper_name, source
def find_required_class_names_for_import(obj: Union["LangChainBaseTool", "CrewAIBaseTool", BaseModel]) -> list[str]:
    """
    Finds all the class names for required imports when instantiating the `obj`.

    NOTE: This does not return the full import path, only the class name.

    Walks the object graph depth-first (pydantic model fields, dict values,
    list items), collecting the class name of every BaseModel encountered.
    """
    seen = {obj.__class__.__name__}
    stack = [obj]
    while stack:
        current = stack.pop()

        # Gather child values that might themselves be (or contain) BaseModels.
        if is_base_model(current):
            # dict(model) yields {field_name: value}; inspect the values,
            # e.g. for obj(b=<class A>) we want to inspect <class A>.
            children = list(dict(current).values())
        elif isinstance(current, dict):
            # e.g. for {'a': 3, 'b': <class A>} we want to inspect <class A>.
            children = list(current.values())
        elif isinstance(current, list):
            # e.g. for ['a', 3, None, <class A>] we want to inspect <class A>.
            children = list(current)
        else:
            children = []

        # Only BaseModel instances contribute class names and deeper traversal;
        # primitives like 'a', 3, or None above are skipped here.
        for child in children:
            if not is_base_model(child):
                continue
            child_name = child.__class__.__name__
            if child_name not in seen:
                seen.add(child_name)
                stack.append(child)
    return list(seen)
def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
    """Recursively render `obj` as a Python expression string that recreates it.

    Primitives become their repr; pydantic BaseModels become keyword-argument
    constructor calls (recursing into field values); dicts and lists are
    rebuilt element-wise. Anything else is a custom class we cannot inspect,
    which yields None after printing a warning.
    """
    if isinstance(obj, (int, float, str, bool, type(None))):
        # Base case: a plain Python literal round-trips through repr().
        return repr(obj)

    if is_base_model(obj):
        # Rebuild the model as ClassName(field=value, ...). Recursion handles
        # field values that are themselves BaseModels. Fields whose rendering
        # is empty/None are skipped.
        rendered_args = []
        for field_name, field_value in dict(obj).items():
            rendered = generate_imported_tool_instantiation_call_str(field_value)
            if rendered:
                rendered_args.append(f"{field_name}={rendered}")
        return f"{obj.__class__.__name__}({', '.join(rendered_args)})"

    if isinstance(obj, dict):
        # Stringify each entry; the dict may contain nested BaseModels.
        # Entries whose value cannot be rendered are dropped.
        entries = []
        for key, value in obj.items():
            rendered = generate_imported_tool_instantiation_call_str(value)
            if rendered:
                entries.append(f"{repr(key)}: {rendered}")
        return f"{{{', '.join(entries)}}}"

    if isinstance(obj, list):
        # Stringify each item; the list may contain nested BaseModels.
        # Items that cannot be rendered are filtered out.
        rendered_items = [generate_imported_tool_instantiation_call_str(item) for item in obj]
        return f"[{', '.join(filter(None, rendered_items))}]"

    # Otherwise this is a custom Python class that is NOT a BaseModel, so we
    # cannot extract enough information to stringify it. We assume the parent
    # library (LangChain/CrewAI) handles such parameters internally — e.g.
    # WikipediaAPIWrapper takes a `wikipedia` module object but sets it up
    # itself — and emit a warning naming the skipped class.
    print(
        f"[WARNING] Skipping parsing unknown class {obj.__class__.__name__} (does not inherit from the Pydantic BaseModel and is not a basic Python type)"
    )
    return None
def is_base_model(obj: Any):
    """Return True if `obj` is a pydantic BaseModel — the top-level pydantic one, LangChain's vendored pydantic v1 one, or CrewAI's."""
    from crewai_tools.tools.base_tool import BaseModel as CrewAiBaseModel
    from langchain_core.pydantic_v1 import BaseModel as LangChainBaseModel

    return isinstance(obj, (BaseModel, LangChainBaseModel, CrewAiBaseModel))
def generate_import_code(module_attr_map: Optional[dict]):
    """Emit Python source that binds each (module path -> attribute) pair via importlib.

    For every entry, the emitted code imports the module with
    `importlib.import_module` (bound under its last dotted component) and then
    pulls the named attribute out of it with `getattr`. The caller injects
    these lines into a generated wrapper body that has imported `importlib`.
    Returns an empty string when the map is missing or empty.
    """
    if not module_attr_map:
        return ""
    lines = []
    for module_path, attr_name in module_attr_map.items():
        # The binding name for the module is its last dotted component.
        alias = module_path.split(".")[-1]
        lines.append(f"# Load the module\n {alias} = importlib.import_module('{module_path}')")
        lines.append(f" # Access the {attr_name} from the module")
        lines.append(f" {attr_name} = getattr({alias}, '{attr_name}')")
    return "\n".join(lines)

View File

@@ -1,6 +1,5 @@
import inspect
import typing
from typing import Any, Dict, Optional, Type, get_args, get_origin
from typing import Any, Dict, Optional, Type, Union, get_args, get_origin
from docstring_parser import parse
from pydantic import BaseModel
@@ -8,7 +7,7 @@ from pydantic import BaseModel
def is_optional(annotation):
# Check if the annotation is a Union
if getattr(annotation, "__origin__", None) is typing.Union:
if getattr(annotation, "__origin__", None) is Union:
# Check if None is one of the options in the Union
return type(None) in annotation.__args__
return False
@@ -164,42 +163,3 @@ def generate_schema_from_args_schema(
}
return function_call_json
def generate_langchain_tool_wrapper(tool_name: str) -> tuple[str, str]:
    """Generate Python source for a wrapper function around a LangChain tool class.

    Args:
        tool_name: Class name of the tool, importable from `langchain_community.tools`.

    Returns:
        A (function_name, function_source) tuple. The generated function strips
        any stray 'self' kwarg, imports and instantiates the tool with no
        constructor arguments, and forwards **kwargs to its `_run` method.
    """
    import_statement = f"from langchain_community.tools import {tool_name}"
    # NOTE: this will fail for tools like 'wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())' since it needs to pass an argument to the tool instantiation
    # https://python.langchain.com/v0.1/docs/integrations/tools/wikipedia/
    tool_instantiation = f"tool = {tool_name}()"
    run_call = f"return tool._run(**kwargs)"
    func_name = f"run_{tool_name.lower()}"

    # Combine all parts into the wrapper function
    wrapper_function_str = f"""
def {func_name}(**kwargs):
    if 'self' in kwargs:
        del kwargs['self']
    {import_statement}
    {tool_instantiation}
    {run_call}
"""
    return func_name, wrapper_function_str
def generate_crewai_tool_wrapper(tool_name: str) -> tuple[str, str]:
    """Generate Python source for a wrapper function around a CrewAI tool class.

    Args:
        tool_name: Class name of the tool, importable from `crewai_tools`.

    Returns:
        A (function_name, function_source) tuple. The generated function strips
        any stray 'self' kwarg, imports and instantiates the tool with no
        constructor arguments, and forwards **kwargs to its `_run` method.
    """
    import_statement = f"from crewai_tools import {tool_name}"
    tool_instantiation = f"tool = {tool_name}()"
    run_call = f"return tool._run(**kwargs)"
    func_name = f"run_{tool_name.lower()}"

    # Combine all parts into the wrapper function
    wrapper_function_str = f"""
def {func_name}(**kwargs):
    if 'self' in kwargs:
        del kwargs['self']
    {import_statement}
    {tool_instantiation}
    {run_call}
"""
    return func_name, wrapper_function_str

View File

@@ -2,11 +2,11 @@ from typing import Dict, List, Optional
from pydantic import Field
from letta.functions.schema_generator import (
from letta.functions.helpers import (
generate_crewai_tool_wrapper,
generate_langchain_tool_wrapper,
generate_schema_from_args_schema,
)
from letta.functions.schema_generator import generate_schema_from_args_schema
from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall
@@ -58,12 +58,13 @@ class Tool(BaseTool):
)
@classmethod
def from_langchain(cls, langchain_tool) -> "Tool":
def from_langchain(cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool":
"""
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
Args:
langchain_tool (LangchainTool): An instance of a crewAI BaseTool (BaseTool from crewai)
langchain_tool (LangChainBaseTool): An instance of a LangChain BaseTool (BaseTool from langchain_community.tools)
additional_imports_module_attr_map (dict[str, str]): A mapping of module names to attribute name. This is used internally to import all the required classes for the langchain tool. For example, you would pass in `{"langchain_community.utilities": "WikipediaAPIWrapper"}` for `from langchain_community.tools import WikipediaQueryRun`. NOTE: You do NOT need to specify the tool import here, that is done automatically for you.
Returns:
Tool: A Letta Tool initialized with attributes derived from the provided LangChain BaseTool object.
@@ -72,7 +73,7 @@ class Tool(BaseTool):
source_type = "python"
tags = ["langchain"]
# NOTE: langchain tools may come from different packages
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool.__class__.__name__)
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map)
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
# append heartbeat (necessary for triggering another reasoning step after this tool call)
@@ -92,7 +93,7 @@ class Tool(BaseTool):
)
@classmethod
def from_crewai(cls, crewai_tool) -> "Tool":
def from_crewai(cls, crewai_tool: "CrewAIBaseTool") -> "Tool":
"""
Class method to create an instance of Tool from a crewAI BaseTool object.
@@ -102,11 +103,10 @@ class Tool(BaseTool):
Returns:
Tool: A Letta Tool initialized with attributes derived from the provided crewAI BaseTool object.
"""
crewai_tool.name
description = crewai_tool.description
source_type = "python"
tags = ["crew-ai"]
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool.__class__.__name__)
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool)
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
# append heartbeat (necessary for triggering another reasoning step after this tool call)

17
poetry.lock generated
View File

@@ -7388,6 +7388,20 @@ MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]]
name = "wikipedia"
version = "1.4.0"
description = "Wikipedia API for Python"
optional = true
python-versions = "*"
files = [
{file = "wikipedia-1.4.0.tar.gz", hash = "sha256:db0fad1829fdd441b1852306e9856398204dc0786d2996dd2e0c8bb8e26133b2"},
]
[package.dependencies]
beautifulsoup4 = "*"
requests = ">=2.0.0,<3.0.0"
[[package]]
name = "wrapt"
version = "1.16.0"
@@ -7815,8 +7829,9 @@ ollama = ["llama-index-embeddings-ollama"]
postgres = ["pg8000", "pgvector"]
qdrant = ["qdrant-client"]
server = ["fastapi", "uvicorn", "websockets"]
tests = ["wikipedia"]
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.10"
content-hash = "e13bb1fe8b39e5cc233b9a5fc1d62e671e9475a4e5c7961548dfc2c958f19f2f"
content-hash = "0c2547e076e664c564e571135b9a18c4a3e1bf4d164c7312fbb291cfa9e7b780"

View File

@@ -71,6 +71,7 @@ llama-index = "^0.11.9"
llama-index-embeddings-openai = "^0.2.5"
llama-index-embeddings-ollama = "^0.3.1"
#llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
wikipedia = {version = "^1.4.0", optional = true}
[tool.poetry.extras]
#local = ["llama-index-embeddings-huggingface"]
@@ -82,6 +83,7 @@ autogen = ["pyautogen"]
qdrant = ["qdrant-client"]
ollama = ["llama-index-embeddings-ollama"]
crewai-tools = ["crewai", "docker", "crewai-tools"]
tests = ["wikipedia"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"

View File

@@ -284,8 +284,6 @@ def test_tools_from_crewai(client):
retrieved_tool = client.get_tool(tool_id)
source_code = retrieved_tool.source_code
print(source_code)
# Parse the function and attempt to use it
local_scope = {}
exec(source_code, {}, local_scope)
@@ -299,6 +297,56 @@ def test_tools_from_crewai(client):
assert expected_content in func(website_url=simple_webpage_url)
def test_tools_from_langchain(client):
    """End-to-end: import a LangChain tool that takes constructor args, register it, retrieve its source, and execute it."""
    # create langchain tool
    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper

    from letta.schemas.tool import Tool

    api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
    langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)

    # Translate to memGPT Tool
    tool = Tool.from_langchain(
        langchain_tool,
        additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"},
    )

    # Add the tool
    client.add_tool(tool)

    # list tools
    assert tool.name in [registered.name for registered in client.list_tools()]

    # get tool
    tool_id = client.get_tool_id(name=tool.name)
    retrieved_tool = client.get_tool(tool_id)
    source_code = retrieved_tool.source_code

    # Parse the function and attempt to use it
    local_scope = {}
    exec(source_code, {}, local_scope)
    func = local_scope[tool.name]

    expected_content = "Albert Einstein ( EYEN-styne; German:"
    assert expected_content in func(query="Albert Einstein")
def test_tool_creation_langchain_missing_imports(client):
    """Converting a LangChain tool without declaring its required extra imports must raise RuntimeError."""
    # create langchain tool
    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper

    from letta.schemas.tool import Tool

    wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
    wikipedia_tool = WikipediaQueryRun(api_wrapper=wrapper)

    # Translate to memGPT Tool
    # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
    with pytest.raises(RuntimeError):
        Tool.from_langchain(wikipedia_tool)
def test_sources(client, agent):
# list sources (empty)
sources = client.list_sources()