diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 99da8521..53aa7a75 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,7 +32,7 @@ jobs: with: python-version: "3.12" poetry-version: "1.8.2" - install-args: "-E dev -E postgres -E milvus -E crewai-tools" + install-args: "-E dev -E postgres -E milvus -E crewai-tools -E tests" - name: Initialize credentials run: poetry run letta quickstart --backend openai diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py new file mode 100644 index 00000000..5b7cd33a --- /dev/null +++ b/letta/functions/helpers.py @@ -0,0 +1,175 @@ +from typing import Any, Optional, Union + +from pydantic import BaseModel + + +def generate_langchain_tool_wrapper(tool: "LangChainBaseTool", additional_imports_module_attr_map: dict = None) -> tuple[str, str]: + tool_name = tool.__class__.__name__ + import_statement = f"from langchain_community.tools import {tool_name}" + extra_module_imports = generate_import_code(additional_imports_module_attr_map) + + # Safety check that user has passed in all required imports: + current_class_imports = {tool_name} + if additional_imports_module_attr_map: + current_class_imports.update(set(additional_imports_module_attr_map.values())) + required_class_imports = set(find_required_class_names_for_import(tool)) + + if not current_class_imports.issuperset(required_class_imports): + err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. 
Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}" + print(err_msg) + raise RuntimeError(err_msg) + + tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}" + run_call = f"return tool._run(**kwargs)" + func_name = f"run_{tool_name.lower()}" + + # Combine all parts into the wrapper function + wrapper_function_str = f""" +def {func_name}(**kwargs): + if 'self' in kwargs: + del kwargs['self'] + import importlib + {import_statement} + {extra_module_imports} + {tool_instantiation} + {run_call} +""" + return func_name, wrapper_function_str + + +def generate_crewai_tool_wrapper(tool: "CrewAIBaseTool") -> tuple[str, str]: + tool_name = tool.__class__.__name__ + import_statement = f"from crewai_tools import {tool_name}" + tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}" + run_call = f"return tool._run(**kwargs)" + func_name = f"run_{tool_name.lower()}" + + # Combine all parts into the wrapper function + wrapper_function_str = f""" +def {func_name}(**kwargs): + if 'self' in kwargs: + del kwargs['self'] + {import_statement} + {tool_instantiation} + {run_call} +""" + return func_name, wrapper_function_str + + +def find_required_class_names_for_import(obj: Union["LangChainBaseTool", "CrewAIBaseTool", BaseModel]) -> list[str]: + """ + Finds all the class names for required imports when instantiating the `obj`. + NOTE: This does not return the full import path, only the class name. + + We accomplish this by running BFS and deep searching all the BaseModel objects in the obj parameters. 
+ """ + class_names = {obj.__class__.__name__} + queue = [obj] + + while queue: + # Get the current object we are inspecting + curr_obj = queue.pop() + + # Collect all possible candidates for BaseModel objects + candidates = [] + if is_base_model(curr_obj): + # If it is a base model, we get all the values of the object parameters + # i.e., if obj('b' = ), we would want to inspect + fields = dict(curr_obj) + # Generate code for each field, skipping empty or None values + candidates = list(fields.values()) + elif isinstance(curr_obj, dict): + # If it is a dictionary, we get all the values + # i.e., if obj = {'a': 3, 'b': }, we would want to inspect + candidates = list(curr_obj.values()) + elif isinstance(curr_obj, list): + # If it is a list, we inspect all the items in the list + # i.e., if obj = ['a', 3, None, ], we would want to inspect + candidates = curr_obj + + # Filter out all candidates that are not BaseModels + # In the list example above, ['a', 3, None, ], we want to filter out 'a', 3, and None + candidates = filter(lambda x: is_base_model(x), candidates) + + # Classic BFS here + for c in candidates: + c_name = c.__class__.__name__ + if c_name not in class_names: + class_names.add(c_name) + queue.append(c) + + return list(class_names) + + +def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]: + if isinstance(obj, (int, float, str, bool, type(None))): + # This is the base case + # If it is a basic Python type, we trivially return the string version of that value + # Handle basic types + return repr(obj) + elif is_base_model(obj): + # Otherwise, if it is a BaseModel + # We want to pull out all the parameters, and reformat them into strings + # e.g. 
{arg}={value} + # The reason why this is recursive, is because the value can be another BaseModel that we need to stringify + model_name = obj.__class__.__name__ + fields = dict(obj) + # Generate code for each field, skipping empty or None values + field_assignments = [] + for arg, value in fields.items(): + python_string = generate_imported_tool_instantiation_call_str(value) + if python_string: + field_assignments.append(f"{arg}={python_string}") + + assignments = ", ".join(field_assignments) + return f"{model_name}({assignments})" + elif isinstance(obj, dict): + # Inspect each of the items in the dict and stringify them + # This is important because the dictionary may contain other BaseModels + dict_items = [] + for k, v in obj.items(): + python_string = generate_imported_tool_instantiation_call_str(v) + if python_string: + dict_items.append(f"{repr(k)}: {python_string}") + + joined_items = ", ".join(dict_items) + return f"{{{joined_items}}}" + elif isinstance(obj, list): + # Inspect each of the items in the list and stringify them + # This is important because the list may contain other BaseModels + list_items = [generate_imported_tool_instantiation_call_str(v) for v in obj] + filtered_list_items = list(filter(None, list_items)) + list_items = ", ".join(filtered_list_items) + return f"[{list_items}]" + else: + # Otherwise, if it is none of the above, that usually means it is a custom Python class that is NOT a BaseModel + # Thus, we cannot get enough information about it to stringify it + # This may cause issues, but we are making the assumption that any of these custom Python types are handled correctly by the parent library, such as LangChain or CrewAI + # An example would be that WikipediaAPIWrapper has an argument that is a wikipedia (pip install wikipedia) object + # We cannot stringify this easily, but WikipediaAPIWrapper handles the setting of this parameter internally + # This assumption seems fair to me, since usually they are external imports, and 
LangChain and CrewAI should be bundling those as module-level imports within the tool + # We throw a warning here anyway and provide the class name + print( + f"[WARNING] Skipping parsing unknown class {obj.__class__.__name__} (does not inherit from the Pydantic BaseModel and is not a basic Python type)" + ) + return None + + +def is_base_model(obj: Any): + from crewai_tools.tools.base_tool import BaseModel as CrewAiBaseModel + from langchain_core.pydantic_v1 import BaseModel as LangChainBaseModel + + return isinstance(obj, BaseModel) or isinstance(obj, LangChainBaseModel) or isinstance(obj, CrewAiBaseModel) + + +def generate_import_code(module_attr_map: Optional[dict]): + if not module_attr_map: + return "" + + code_lines = [] + for module, attr in module_attr_map.items(): + module_name = module.split(".")[-1] + code_lines.append(f"# Load the module\n {module_name} = importlib.import_module('{module}')") + code_lines.append(f" # Access the {attr} from the module") + code_lines.append(f" {attr} = getattr({module_name}, '{attr}')") + return "\n".join(code_lines) diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 69fed2cf..5f282f90 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -1,6 +1,5 @@ import inspect -import typing -from typing import Any, Dict, Optional, Type, get_args, get_origin +from typing import Any, Dict, Optional, Type, Union, get_args, get_origin from docstring_parser import parse from pydantic import BaseModel @@ -8,7 +7,7 @@ from pydantic import BaseModel def is_optional(annotation): # Check if the annotation is a Union - if getattr(annotation, "__origin__", None) is typing.Union: + if getattr(annotation, "__origin__", None) is Union: # Check if None is one of the options in the Union return type(None) in annotation.__args__ return False @@ -164,42 +163,3 @@ def generate_schema_from_args_schema( } return function_call_json - - -def 
generate_langchain_tool_wrapper(tool_name: str) -> str: - import_statement = f"from langchain_community.tools import {tool_name}" - - # NOTE: this will fail for tools like 'wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())' since it needs to pass an argument to the tool instantiation - # https://python.langchain.com/v0.1/docs/integrations/tools/wikipedia/ - tool_instantiation = f"tool = {tool_name}()" - run_call = f"return tool._run(**kwargs)" - func_name = f"run_{tool_name.lower()}" - - # Combine all parts into the wrapper function - wrapper_function_str = f""" -def {func_name}(**kwargs): - if 'self' in kwargs: - del kwargs['self'] - {import_statement} - {tool_instantiation} - {run_call} -""" - return func_name, wrapper_function_str - - -def generate_crewai_tool_wrapper(tool_name: str) -> str: - import_statement = f"from crewai_tools import {tool_name}" - tool_instantiation = f"tool = {tool_name}()" - run_call = f"return tool._run(**kwargs)" - func_name = f"run_{tool_name.lower()}" - - # Combine all parts into the wrapper function - wrapper_function_str = f""" -def {func_name}(**kwargs): - if 'self' in kwargs: - del kwargs['self'] - {import_statement} - {tool_instantiation} - {run_call} -""" - return func_name, wrapper_function_str diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 74cd0eb6..067cee1a 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -2,11 +2,11 @@ from typing import Dict, List, Optional from pydantic import Field -from letta.functions.schema_generator import ( +from letta.functions.helpers import ( generate_crewai_tool_wrapper, generate_langchain_tool_wrapper, - generate_schema_from_args_schema, ) +from letta.functions.schema_generator import generate_schema_from_args_schema from letta.schemas.letta_base import LettaBase from letta.schemas.openai.chat_completions import ToolCall @@ -58,12 +58,13 @@ class Tool(BaseTool): ) @classmethod - def from_langchain(cls, langchain_tool) -> "Tool": + def 
from_langchain(cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> "Tool": """ Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools). Args: - langchain_tool (LangchainTool): An instance of a crewAI BaseTool (BaseTool from crewai) + langchain_tool (LangChainBaseTool): An instance of a crewAI BaseTool (BaseTool from crewai) + additional_imports_module_attr_map (dict[str, str]): A mapping of module names to attribute name. This is used internally to import all the required classes for the langchain tool. For example, you would pass in `{"langchain_community.utilities": "WikipediaAPIWrapper"}` for `from langchain_community.tools import WikipediaQueryRun`. NOTE: You do NOT need to specify the tool import here, that is done automatically for you. Returns: Tool: A Letta Tool initialized with attributes derived from the provided crewAI BaseTool object. @@ -72,7 +73,7 @@ class Tool(BaseTool): source_type = "python" tags = ["langchain"] # NOTE: langchain tools may come from different packages - wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool.__class__.__name__) + wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map) json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description) # append heartbeat (necessary for triggering another reasoning step after this tool call) @@ -92,7 +93,7 @@ class Tool(BaseTool): ) @classmethod - def from_crewai(cls, crewai_tool) -> "Tool": + def from_crewai(cls, crewai_tool: "CrewAIBaseTool") -> "Tool": """ Class method to create an instance of Tool from a crewAI BaseTool object. @@ -102,11 +103,10 @@ class Tool(BaseTool): Returns: Tool: A Letta Tool initialized with attributes derived from the provided crewAI BaseTool object. 
""" - crewai_tool.name description = crewai_tool.description source_type = "python" tags = ["crew-ai"] - wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool.__class__.__name__) + wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool) json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description) # append heartbeat (necessary for triggering another reasoning step after this tool call) diff --git a/poetry.lock b/poetry.lock index 9bc7609b..169b0be9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7388,6 +7388,20 @@ MarkupSafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] +[[package]] +name = "wikipedia" +version = "1.4.0" +description = "Wikipedia API for Python" +optional = true +python-versions = "*" +files = [ + {file = "wikipedia-1.4.0.tar.gz", hash = "sha256:db0fad1829fdd441b1852306e9856398204dc0786d2996dd2e0c8bb8e26133b2"}, +] + +[package.dependencies] +beautifulsoup4 = "*" +requests = ">=2.0.0,<3.0.0" + [[package]] name = "wrapt" version = "1.16.0" @@ -7815,8 +7829,9 @@ ollama = ["llama-index-embeddings-ollama"] postgres = ["pg8000", "pgvector"] qdrant = ["qdrant-client"] server = ["fastapi", "uvicorn", "websockets"] +tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "e13bb1fe8b39e5cc233b9a5fc1d62e671e9475a4e5c7961548dfc2c958f19f2f" +content-hash = "0c2547e076e664c564e571135b9a18c4a3e1bf4d164c7312fbb291cfa9e7b780" diff --git a/pyproject.toml b/pyproject.toml index 9f0c86ad..35c4e49a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ llama-index = "^0.11.9" llama-index-embeddings-openai = "^0.2.5" llama-index-embeddings-ollama = "^0.3.1" #llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true} +wikipedia = {version = "^1.4.0", optional = true} [tool.poetry.extras] #local = ["llama-index-embeddings-huggingface"] @@ -82,6 +83,7 @@ autogen = 
def test_tools_from_langchain(client):
    """Convert a LangChain Wikipedia tool into a Letta Tool, register it and run its stored source."""
    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper

    from letta.schemas.tool import Tool

    # Build the LangChain tool around a size-limited Wikipedia wrapper
    wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100))

    # Translate to a Letta Tool, declaring the module that provides WikipediaAPIWrapper
    letta_tool = Tool.from_langchain(
        wikipedia_tool,
        additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"},
    )

    # Register the tool and make sure it is listed
    client.add_tool(letta_tool)
    listed_names = [listed.name for listed in client.list_tools()]
    assert letta_tool.name in listed_names

    # Fetch the stored tool back by id
    stored_tool = client.get_tool(client.get_tool_id(name=letta_tool.name))

    # Execute the stored source and call the function it defines
    namespace = {}
    exec(stored_tool.source_code, {}, namespace)
    run_tool = namespace[letta_tool.name]

    assert "Albert Einstein ( EYEN-styne; German:" in run_tool(query="Albert Einstein")


def test_tool_creation_langchain_missing_imports(client):
    """Tool.from_langchain must raise RuntimeError when a required import mapping is omitted."""
    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper

    from letta.schemas.tool import Tool

    wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100))

    # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
    with pytest.raises(RuntimeError):
        Tool.from_langchain(wikipedia_tool)