fix: Enable importing LangChain tools with arguments (#1807)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E postgres -E milvus -E crewai-tools"
|
||||
install-args: "-E dev -E postgres -E milvus -E crewai-tools -E tests"
|
||||
|
||||
- name: Initialize credentials
|
||||
run: poetry run letta quickstart --backend openai
|
||||
|
||||
175
letta/functions/helpers.py
Normal file
175
letta/functions/helpers.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def generate_langchain_tool_wrapper(tool: "LangChainBaseTool", additional_imports_module_attr_map: dict = None) -> tuple[str, str]:
|
||||
tool_name = tool.__class__.__name__
|
||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||
extra_module_imports = generate_import_code(additional_imports_module_attr_map)
|
||||
|
||||
# Safety check that user has passed in all required imports:
|
||||
current_class_imports = {tool_name}
|
||||
if additional_imports_module_attr_map:
|
||||
current_class_imports.update(set(additional_imports_module_attr_map.values()))
|
||||
required_class_imports = set(find_required_class_names_for_import(tool))
|
||||
|
||||
if not current_class_imports.issuperset(required_class_imports):
|
||||
err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}"
|
||||
print(err_msg)
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
func_name = f"run_{tool_name.lower()}"
|
||||
|
||||
# Combine all parts into the wrapper function
|
||||
wrapper_function_str = f"""
|
||||
def {func_name}(**kwargs):
|
||||
if 'self' in kwargs:
|
||||
del kwargs['self']
|
||||
import importlib
|
||||
{import_statement}
|
||||
{extra_module_imports}
|
||||
{tool_instantiation}
|
||||
{run_call}
|
||||
"""
|
||||
return func_name, wrapper_function_str
|
||||
|
||||
|
||||
def generate_crewai_tool_wrapper(tool: "CrewAIBaseTool") -> tuple[str, str]:
|
||||
tool_name = tool.__class__.__name__
|
||||
import_statement = f"from crewai_tools import {tool_name}"
|
||||
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
func_name = f"run_{tool_name.lower()}"
|
||||
|
||||
# Combine all parts into the wrapper function
|
||||
wrapper_function_str = f"""
|
||||
def {func_name}(**kwargs):
|
||||
if 'self' in kwargs:
|
||||
del kwargs['self']
|
||||
{import_statement}
|
||||
{tool_instantiation}
|
||||
{run_call}
|
||||
"""
|
||||
return func_name, wrapper_function_str
|
||||
|
||||
|
||||
def find_required_class_names_for_import(obj: Union["LangChainBaseTool", "CrewAIBaseTool", BaseModel]) -> list[str]:
|
||||
"""
|
||||
Finds all the class names for required imports when instantiating the `obj`.
|
||||
NOTE: This does not return the full import path, only the class name.
|
||||
|
||||
We accomplish this by running BFS and deep searching all the BaseModel objects in the obj parameters.
|
||||
"""
|
||||
class_names = {obj.__class__.__name__}
|
||||
queue = [obj]
|
||||
|
||||
while queue:
|
||||
# Get the current object we are inspecting
|
||||
curr_obj = queue.pop()
|
||||
|
||||
# Collect all possible candidates for BaseModel objects
|
||||
candidates = []
|
||||
if is_base_model(curr_obj):
|
||||
# If it is a base model, we get all the values of the object parameters
|
||||
# i.e., if obj('b' = <class A>), we would want to inspect <class A>
|
||||
fields = dict(curr_obj)
|
||||
# Generate code for each field, skipping empty or None values
|
||||
candidates = list(fields.values())
|
||||
elif isinstance(curr_obj, dict):
|
||||
# If it is a dictionary, we get all the values
|
||||
# i.e., if obj = {'a': 3, 'b': <class A>}, we would want to inspect <class A>
|
||||
candidates = list(curr_obj.values())
|
||||
elif isinstance(curr_obj, list):
|
||||
# If it is a list, we inspect all the items in the list
|
||||
# i.e., if obj = ['a', 3, None, <class A>], we would want to inspect <class A>
|
||||
candidates = curr_obj
|
||||
|
||||
# Filter out all candidates that are not BaseModels
|
||||
# In the list example above, ['a', 3, None, <class A>], we want to filter out 'a', 3, and None
|
||||
candidates = filter(lambda x: is_base_model(x), candidates)
|
||||
|
||||
# Classic BFS here
|
||||
for c in candidates:
|
||||
c_name = c.__class__.__name__
|
||||
if c_name not in class_names:
|
||||
class_names.add(c_name)
|
||||
queue.append(c)
|
||||
|
||||
return list(class_names)
|
||||
|
||||
|
||||
def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
# This is the base case
|
||||
# If it is a basic Python type, we trivially return the string version of that value
|
||||
# Handle basic types
|
||||
return repr(obj)
|
||||
elif is_base_model(obj):
|
||||
# Otherwise, if it is a BaseModel
|
||||
# We want to pull out all the parameters, and reformat them into strings
|
||||
# e.g. {arg}={value}
|
||||
# The reason why this is recursive, is because the value can be another BaseModel that we need to stringify
|
||||
model_name = obj.__class__.__name__
|
||||
fields = dict(obj)
|
||||
# Generate code for each field, skipping empty or None values
|
||||
field_assignments = []
|
||||
for arg, value in fields.items():
|
||||
python_string = generate_imported_tool_instantiation_call_str(value)
|
||||
if python_string:
|
||||
field_assignments.append(f"{arg}={python_string}")
|
||||
|
||||
assignments = ", ".join(field_assignments)
|
||||
return f"{model_name}({assignments})"
|
||||
elif isinstance(obj, dict):
|
||||
# Inspect each of the items in the dict and stringify them
|
||||
# This is important because the dictionary may contain other BaseModels
|
||||
dict_items = []
|
||||
for k, v in obj.items():
|
||||
python_string = generate_imported_tool_instantiation_call_str(v)
|
||||
if python_string:
|
||||
dict_items.append(f"{repr(k)}: {python_string}")
|
||||
|
||||
joined_items = ", ".join(dict_items)
|
||||
return f"{{{joined_items}}}"
|
||||
elif isinstance(obj, list):
|
||||
# Inspect each of the items in the list and stringify them
|
||||
# This is important because the list may contain other BaseModels
|
||||
list_items = [generate_imported_tool_instantiation_call_str(v) for v in obj]
|
||||
filtered_list_items = list(filter(None, list_items))
|
||||
list_items = ", ".join(filtered_list_items)
|
||||
return f"[{list_items}]"
|
||||
else:
|
||||
# Otherwise, if it is none of the above, that usually means it is a custom Python class that is NOT a BaseModel
|
||||
# Thus, we cannot get enough information about it to stringify it
|
||||
# This may cause issues, but we are making the assumption that any of these custom Python types are handled correctly by the parent library, such as LangChain or CrewAI
|
||||
# An example would be that WikipediaAPIWrapper has an argument that is a wikipedia (pip install wikipedia) object
|
||||
# We cannot stringify this easily, but WikipediaAPIWrapper handles the setting of this parameter internally
|
||||
# This assumption seems fair to me, since usually they are external imports, and LangChain and CrewAI should be bundling those as module-level imports within the tool
|
||||
# We throw a warning here anyway and provide the class name
|
||||
print(
|
||||
f"[WARNING] Skipping parsing unknown class {obj.__class__.__name__} (does not inherit from the Pydantic BaseModel and is not a basic Python type)"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_base_model(obj: Any):
|
||||
from crewai_tools.tools.base_tool import BaseModel as CrewAiBaseModel
|
||||
from langchain_core.pydantic_v1 import BaseModel as LangChainBaseModel
|
||||
|
||||
return isinstance(obj, BaseModel) or isinstance(obj, LangChainBaseModel) or isinstance(obj, CrewAiBaseModel)
|
||||
|
||||
|
||||
def generate_import_code(module_attr_map: Optional[dict]):
|
||||
if not module_attr_map:
|
||||
return ""
|
||||
|
||||
code_lines = []
|
||||
for module, attr in module_attr_map.items():
|
||||
module_name = module.split(".")[-1]
|
||||
code_lines.append(f"# Load the module\n {module_name} = importlib.import_module('{module}')")
|
||||
code_lines.append(f" # Access the {attr} from the module")
|
||||
code_lines.append(f" {attr} = getattr({module_name}, '{attr}')")
|
||||
return "\n".join(code_lines)
|
||||
@@ -1,6 +1,5 @@
|
||||
import inspect
|
||||
import typing
|
||||
from typing import Any, Dict, Optional, Type, get_args, get_origin
|
||||
from typing import Any, Dict, Optional, Type, Union, get_args, get_origin
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
@@ -8,7 +7,7 @@ from pydantic import BaseModel
|
||||
|
||||
def is_optional(annotation):
|
||||
# Check if the annotation is a Union
|
||||
if getattr(annotation, "__origin__", None) is typing.Union:
|
||||
if getattr(annotation, "__origin__", None) is Union:
|
||||
# Check if None is one of the options in the Union
|
||||
return type(None) in annotation.__args__
|
||||
return False
|
||||
@@ -164,42 +163,3 @@ def generate_schema_from_args_schema(
|
||||
}
|
||||
|
||||
return function_call_json
|
||||
|
||||
|
||||
def generate_langchain_tool_wrapper(tool_name: str) -> str:
|
||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||
|
||||
# NOTE: this will fail for tools like 'wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())' since it needs to pass an argument to the tool instantiation
|
||||
# https://python.langchain.com/v0.1/docs/integrations/tools/wikipedia/
|
||||
tool_instantiation = f"tool = {tool_name}()"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
func_name = f"run_{tool_name.lower()}"
|
||||
|
||||
# Combine all parts into the wrapper function
|
||||
wrapper_function_str = f"""
|
||||
def {func_name}(**kwargs):
|
||||
if 'self' in kwargs:
|
||||
del kwargs['self']
|
||||
{import_statement}
|
||||
{tool_instantiation}
|
||||
{run_call}
|
||||
"""
|
||||
return func_name, wrapper_function_str
|
||||
|
||||
|
||||
def generate_crewai_tool_wrapper(tool_name: str) -> str:
|
||||
import_statement = f"from crewai_tools import {tool_name}"
|
||||
tool_instantiation = f"tool = {tool_name}()"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
func_name = f"run_{tool_name.lower()}"
|
||||
|
||||
# Combine all parts into the wrapper function
|
||||
wrapper_function_str = f"""
|
||||
def {func_name}(**kwargs):
|
||||
if 'self' in kwargs:
|
||||
del kwargs['self']
|
||||
{import_statement}
|
||||
{tool_instantiation}
|
||||
{run_call}
|
||||
"""
|
||||
return func_name, wrapper_function_str
|
||||
|
||||
@@ -2,11 +2,11 @@ from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.functions.schema_generator import (
|
||||
from letta.functions.helpers import (
|
||||
generate_crewai_tool_wrapper,
|
||||
generate_langchain_tool_wrapper,
|
||||
generate_schema_from_args_schema,
|
||||
)
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
|
||||
@@ -58,12 +58,13 @@ class Tool(BaseTool):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, langchain_tool) -> "Tool":
|
||||
def from_langchain(cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool":
|
||||
"""
|
||||
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
|
||||
|
||||
Args:
|
||||
langchain_tool (LangchainTool): An instance of a crewAI BaseTool (BaseTool from crewai)
|
||||
langchain_tool (LangChainBaseTool): An instance of a crewAI BaseTool (BaseTool from crewai)
|
||||
additional_imports_module_attr_map (dict[str, str]): A mapping of module names to attribute name. This is used internally to import all the required classes for the langchain tool. For example, you would pass in `{"langchain_community.utilities": "WikipediaAPIWrapper"}` for `from langchain_community.tools import WikipediaQueryRun`. NOTE: You do NOT need to specify the tool import here, that is done automatically for you.
|
||||
|
||||
Returns:
|
||||
Tool: A Letta Tool initialized with attributes derived from the provided crewAI BaseTool object.
|
||||
@@ -72,7 +73,7 @@ class Tool(BaseTool):
|
||||
source_type = "python"
|
||||
tags = ["langchain"]
|
||||
# NOTE: langchain tools may come from different packages
|
||||
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool.__class__.__name__)
|
||||
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map)
|
||||
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
@@ -92,7 +93,7 @@ class Tool(BaseTool):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_crewai(cls, crewai_tool) -> "Tool":
|
||||
def from_crewai(cls, crewai_tool: "CrewAIBaseTool") -> "Tool":
|
||||
"""
|
||||
Class method to create an instance of Tool from a crewAI BaseTool object.
|
||||
|
||||
@@ -102,11 +103,10 @@ class Tool(BaseTool):
|
||||
Returns:
|
||||
Tool: A Letta Tool initialized with attributes derived from the provided crewAI BaseTool object.
|
||||
"""
|
||||
crewai_tool.name
|
||||
description = crewai_tool.description
|
||||
source_type = "python"
|
||||
tags = ["crew-ai"]
|
||||
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool.__class__.__name__)
|
||||
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool)
|
||||
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
|
||||
17
poetry.lock
generated
17
poetry.lock
generated
@@ -7388,6 +7388,20 @@ MarkupSafe = ">=2.1.1"
|
||||
[package.extras]
|
||||
watchdog = ["watchdog (>=2.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "wikipedia"
|
||||
version = "1.4.0"
|
||||
description = "Wikipedia API for Python"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "wikipedia-1.4.0.tar.gz", hash = "sha256:db0fad1829fdd441b1852306e9856398204dc0786d2996dd2e0c8bb8e26133b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
beautifulsoup4 = "*"
|
||||
requests = ">=2.0.0,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.16.0"
|
||||
@@ -7815,8 +7829,9 @@ ollama = ["llama-index-embeddings-ollama"]
|
||||
postgres = ["pg8000", "pgvector"]
|
||||
qdrant = ["qdrant-client"]
|
||||
server = ["fastapi", "uvicorn", "websockets"]
|
||||
tests = ["wikipedia"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "e13bb1fe8b39e5cc233b9a5fc1d62e671e9475a4e5c7961548dfc2c958f19f2f"
|
||||
content-hash = "0c2547e076e664c564e571135b9a18c4a3e1bf4d164c7312fbb291cfa9e7b780"
|
||||
|
||||
@@ -71,6 +71,7 @@ llama-index = "^0.11.9"
|
||||
llama-index-embeddings-openai = "^0.2.5"
|
||||
llama-index-embeddings-ollama = "^0.3.1"
|
||||
#llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
|
||||
wikipedia = {version = "^1.4.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
#local = ["llama-index-embeddings-huggingface"]
|
||||
@@ -82,6 +83,7 @@ autogen = ["pyautogen"]
|
||||
qdrant = ["qdrant-client"]
|
||||
ollama = ["llama-index-embeddings-ollama"]
|
||||
crewai-tools = ["crewai", "docker", "crewai-tools"]
|
||||
tests = ["wikipedia"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
|
||||
@@ -284,8 +284,6 @@ def test_tools_from_crewai(client):
|
||||
retrieved_tool = client.get_tool(tool_id)
|
||||
source_code = retrieved_tool.source_code
|
||||
|
||||
print(source_code)
|
||||
|
||||
# Parse the function and attempt to use it
|
||||
local_scope = {}
|
||||
exec(source_code, {}, local_scope)
|
||||
@@ -299,6 +297,56 @@ def test_tools_from_crewai(client):
|
||||
assert expected_content in func(website_url=simple_webpage_url)
|
||||
|
||||
|
||||
def test_tools_from_langchain(client):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
# Translate to memGPT Tool
|
||||
tool = Tool.from_langchain(langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"})
|
||||
|
||||
# Add the tool
|
||||
client.add_tool(tool)
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
assert tool.name in [t.name for t in tools]
|
||||
|
||||
# get tool
|
||||
tool_id = client.get_tool_id(name=tool.name)
|
||||
retrieved_tool = client.get_tool(tool_id)
|
||||
source_code = retrieved_tool.source_code
|
||||
|
||||
# Parse the function and attempt to use it
|
||||
local_scope = {}
|
||||
exec(source_code, {}, local_scope)
|
||||
func = local_scope[tool.name]
|
||||
|
||||
expected_content = "Albert Einstein ( EYEN-styne; German:"
|
||||
assert expected_content in func(query="Albert Einstein")
|
||||
|
||||
|
||||
def test_tool_creation_langchain_missing_imports(client):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
# Translate to memGPT Tool
|
||||
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
with pytest.raises(RuntimeError):
|
||||
Tool.from_langchain(langchain_tool)
|
||||
|
||||
|
||||
def test_sources(client, agent):
|
||||
# list sources (empty)
|
||||
sources = client.list_sources()
|
||||
|
||||
Reference in New Issue
Block a user