From 74c0d9af9d0730d1c846808e47653f76c1b79d60 Mon Sep 17 00:00:00 2001 From: Maximilian Winter Date: Thu, 18 Jan 2024 02:21:00 +0100 Subject: [PATCH] feat: Support pydantic models as parameters to custom functions (#839) Co-authored-by: cpacker --- memgpt/agent.py | 7 +++ memgpt/functions/schema_generator.py | 42 ++++++++++++--- memgpt/local_llm/chat_completion_proxy.py | 12 +++-- .../grammars/gbnf_grammar_generator.py | 45 +++++++++------- .../llm_chat_completion_wrappers/airoboros.py | 18 ++++--- .../llm_chat_completion_wrappers/chatml.py | 13 +++-- .../llm_chat_completion_wrappers/dolphin.py | 9 ++-- .../simple_summary_wrapper.py | 2 +- .../wrapper_base.py | 2 +- .../llm_chat_completion_wrappers/zephyr.py | 18 ++++--- tests/test_server.py | 53 +++++++++++++------ 11 files changed, 153 insertions(+), 68 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 2c85ab9d..c5ae5380 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -461,7 +461,14 @@ class Agent(object): # Failure case 3: function failed during execution self.interface.function_message(f"Running {function_name}({function_args})") try: + spec = inspect.getfullargspec(function_to_call).annotations + + for name, arg in function_args.items(): + if isinstance(function_args[name], dict): + function_args[name] = spec[name](**function_args[name]) + function_args["self"] = self # need to attach self to arg since it's dynamically linked + function_response = function_to_call(**function_args) if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: # with certain functions we rely on the paging mechanism to handle overflow diff --git a/memgpt/functions/schema_generator.py b/memgpt/functions/schema_generator.py index 1a111da2..aa1a6ac2 100644 --- a/memgpt/functions/schema_generator.py +++ b/memgpt/functions/schema_generator.py @@ -3,6 +3,7 @@ import typing from typing import get_args from docstring_parser import parse +from pydantic import BaseModel from memgpt.constants import FUNCTION_PARAM_NAME_REQ_HEARTBEAT, FUNCTION_PARAM_TYPE_REQ_HEARTBEAT, FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT @@ -53,6 +54,30 @@ def type_to_json_schema_type(py_type): return type_map.get(py_type, "string") # Default to "string" if type not in map +def pydantic_model_to_open_ai(model): + schema = model.model_json_schema() + docstring = parse(model.__doc__ or "") + parameters = {k: v for k, v in schema.items() if k not in ("title", "description")} + for param in docstring.params: + if (name := param.arg_name) in parameters["properties"] and (description := param.description): + if "description" not in parameters["properties"][name]: + parameters["properties"][name]["description"] = description + + parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v) + + if "description" not in schema: + if docstring.short_description: + schema["description"] = docstring.short_description + else: + raise + + return { + "name": schema["title"], + "description": schema["description"], + "parameters": parameters, + } + + def generate_schema(function): # Get the signature of the function sig = inspect.signature(function) @@ -83,13 +108,16 @@ def generate_schema(function): if not param_doc or not param_doc.description: raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring") - # Add parameter details to the schema - param_doc = next((d for d in docstring.params if d.arg_name == param.name), None) - schema["parameters"]["properties"][param.name] = { - # "type": "string" if param.annotation == str else str(param.annotation), - "type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string", - "description": param_doc.description, - } + if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel): + schema["parameters"]["properties"][param.name] = pydantic_model_to_open_ai(param.annotation) + else: + # Add parameter details to the schema + param_doc = next((d for d in docstring.params if d.arg_name == param.name), None) + schema["parameters"]["properties"][param.name] = { + # "type": "string" if param.annotation == str else str(param.annotation), + "type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string", + "description": param_doc.description, + } if param.default == inspect.Parameter.empty: schema["parameters"]["required"].append(param.name) diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 2986ab47..b4e74e56 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -88,7 +88,7 @@ def get_chat_completion( # If the wrapper uses grammar, generate the grammar using the grammar generating function # TODO move this to a flag - if "grammar" in wrapper: + if wrapper is not None and "grammar" in wrapper: # When using grammars, we don't want to do any extras output tricks like appending a response prefix setattr(llm_wrapper, "assistant_prefix_extra_first_message", "") setattr(llm_wrapper, "assistant_prefix_extra", "") @@ -125,9 +125,11 @@ def get_chat_completion( try: # if hasattr(llm_wrapper, "supports_first_message"): if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message: - prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, first_message=first_message) + prompt = llm_wrapper.chat_completion_to_prompt( + messages, functions, first_message=first_message, function_documentation=documentation + ) else: - prompt = llm_wrapper.chat_completion_to_prompt(messages, functions) + prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, function_documentation=documentation) printd(prompt) except Exception as e: @@ -248,8 +250,8 @@ def generate_grammar_and_documentation( grammar_function_models, outer_object_name="function", outer_object_content="params", - model_prefix="Function", - fields_prefix="Parameter", + model_prefix="function", + fields_prefix="params", add_inner_thoughts=add_inner_thoughts_top_level, allow_only_inner_thoughts=allow_only_inner_thoughts, ) diff --git a/memgpt/local_llm/grammars/gbnf_grammar_generator.py b/memgpt/local_llm/grammars/gbnf_grammar_generator.py index 3612cfa7..863a9ae6 100644 --- a/memgpt/local_llm/grammars/gbnf_grammar_generator.py +++ b/memgpt/local_llm/grammars/gbnf_grammar_generator.py @@ -465,8 +465,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created nested_rules = [] has_markdown_code_block = False has_triple_quoted_string = False - look_for_markdown_code_block = False - look_for_triple_quoted_string = False + for field_name, field_info in model_fields.items(): if not issubclass(model, BaseModel): field_type, default_value = field_info @@ -621,7 +620,7 @@ null ::= "null" string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) - )* "\"" ws + )* "\"" ws ::= ([ \t\n] ws)? float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws @@ -636,13 +635,13 @@ object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* - )? "}" ws + )? "}" array ::= "[" ws ( value ("," ws value)* - )? "]" ws + )? "]" number ::= integer | float""" @@ -683,7 +682,7 @@ def generate_markdown_documentation( if add_prefix: documentation += f"{model_prefix}: {model.__name__}\n" else: - documentation += f"Model: {model.__name__}\n" + documentation += f"class: {model.__name__}\n" # Handling multi-line model description with proper indentation @@ -691,18 +690,19 @@ def generate_markdown_documentation( base_class_doc = getdoc(BaseModel) class_description = class_doc if class_doc and class_doc != base_class_doc else "" if class_description != "": - documentation += " Description: " - documentation += format_multiline_description(class_description, 0) + "\n" + documentation += format_multiline_description("description: " + class_description, 1) + "\n" if add_prefix: # Indenting the fields section documentation += f" {fields_prefix}:\n" else: - documentation += f" Fields:\n" + documentation += f" attributes:\n" if isclass(model) and issubclass(model, BaseModel): for name, field_type in model.__annotations__.items(): # if name == "markdown_code_block": # continue + if isclass(field_type) and issubclass(field_type, BaseModel): + pyd_models.append((field_type, False)) if get_origin(field_type) == list: element_type = get_args(field_type)[0] if isclass(element_type) and issubclass(element_type, BaseModel): @@ -748,25 +748,33 @@ def generate_field_markdown( if get_origin(field_type) == list: element_type = get_args(field_type)[0] - field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" + field_text = f"{indent}{field_name} ({field_type.__name__} of {element_type.__name__})" if field_description != "": - field_text += ":\n" + field_text += ": " else: field_text += "\n" elif get_origin(field_type) == Union: element_types = get_args(field_type) types = [] for element_type in element_types: - types.append(format_model_and_field_name(element_type.__name__)) + types.append(element_type.__name__) field_text = f"{indent}{field_name} ({' or '.join(types)})" if field_description != "": - field_text += ":\n" + field_text += ": " + else: + field_text += "\n" + elif issubclass(field_type, Enum): + enum_values = [f"'{str(member.value)}'" for member in field_type] + + field_text = f"{indent}{field_name} ({' or '.join(enum_values)})" + if field_description != "": + field_text += ": " else: field_text += "\n" else: - field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + field_text = f"{indent}{field_name} ({field_type.__name__})" if field_description != "": - field_text += ":\n" + field_text += ": " else: field_text += "\n" @@ -774,7 +782,7 @@ def generate_field_markdown( return field_text if field_description != "": - field_text += f" Description: " + field_description + "\n" + field_text += field_description + "\n" # Check for and include field-specific examples if available if hasattr(model, "Config") and hasattr(model.Config, "json_schema_extra") and "example" in model.Config.json_schema_extra: @@ -784,7 +792,7 @@ def generate_field_markdown( field_text += f"{indent} Example: {example_text}\n" if isclass(field_type) and issubclass(field_type, BaseModel): - field_text += f"{indent} Details:\n" + field_text += f"{indent} details:\n" for name, type_ in field_type.__annotations__.items(): field_text += generate_field_markdown(name, type_, field_type, depth + 2) @@ -952,7 +960,7 @@ def format_multiline_description(description: str, indent_level: int) -> str: Returns: str: Formatted multiline description. """ - indent = " " * indent_level + indent = " " * indent_level return indent + description.replace("\n", "\n" + indent) @@ -1154,6 +1162,7 @@ def create_dynamic_model_from_function(func: Callable, add_inner_thoughts: bool default_value = ... else: default_value = param.default + dynamic_fields[param.name] = (param.annotation if param.annotation != inspect.Parameter.empty else str, default_value) # Creating the dynamic model dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py index 5fa9f63a..1ae9693a 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -26,7 +26,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): self.include_opening_brance_in_prefix = include_opening_brace_in_prefix self.include_section_separators = include_section_separators - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format A chat. @@ -87,8 +87,11 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." prompt += f"\nAvailable functions:" - for function_dict in functions: - prompt += f"\n{create_function_description(function_dict)}" + if function_documentation is not None: + prompt += f"\n{function_documentation}" + else: + for function_dict in functions: + prompt += f"\n{create_function_description(function_dict)}" def create_function_call(function_call): """Go from ChatCompletion to Airoboros style function trace (in prompt) @@ -230,7 +233,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): self.assistant_prefix_extra = assistant_prefix_extra self.include_section_separators = include_section_separators - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format A chat. @@ -293,8 +296,11 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." prompt += f"\nAvailable functions:" - for function_dict in functions: - prompt += f"\n{create_function_description(function_dict)}" + if function_documentation is not None: + prompt += f"\n{function_documentation}" + else: + for function_dict in functions: + prompt += f"\n{create_function_description(function_dict)}" def create_function_call(function_call, inner_thoughts=None): """Go from ChatCompletion to Airoboros style function trace (in prompt) diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py b/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py index 9915a5df..7dd9d583 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py @@ -95,12 +95,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): return prompt # NOTE: BOS/EOS chatml tokens are NOT inserted here - def _compile_system_message(self, system_message, functions) -> str: + def _compile_system_message(self, system_message, functions, function_documentation=None) -> str: """system prompt + memory + functions -> string""" prompt = "" prompt += system_message prompt += "\n" - prompt += self._compile_function_block(functions) + if function_documentation is not None: + prompt += function_documentation + else: + prompt += self._compile_function_block(functions) return prompt def _compile_function_call(self, function_call, inner_thoughts=None): @@ -186,13 +189,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): prompt += function_return_str return prompt - def chat_completion_to_prompt(self, messages, functions, first_message=False): + def chat_completion_to_prompt(self, messages, functions, first_message=False, function_documentation=None): """chatml-style prompt formatting, with implied support for multi-role""" prompt = "" # System insturctions go first assert messages[0]["role"] == "system" - system_block = self._compile_system_message(system_message=messages[0]["content"], functions=functions) + system_block = self._compile_system_message( + system_message=messages[0]["content"], functions=functions, function_documentation=function_documentation + ) prompt += f"<|im_start|>system\n{system_block.strip()}<|im_end|>" # Last are the user/assistant messages diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py index aa4e8f0b..454e1ae9 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py @@ -26,7 +26,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): self.include_opening_brance_in_prefix = include_opening_brace_in_prefix self.include_section_separators = include_section_separators - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format <|im_start|>system @@ -97,8 +97,11 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." prompt += f"\nAvailable functions:" - for function_dict in functions: - prompt += f"\n{create_function_description(function_dict)}" + if function_documentation is not None: + prompt += f"\n{function_documentation}" + else: + for function_dict in functions: + prompt += f"\n{create_function_description(function_dict)}" # Put functions INSIDE system message (TODO experiment with this) prompt += IM_END_TOKEN diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py b/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py index b7de8714..8e3625fa 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py @@ -17,7 +17,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper): self.include_assistant_prefix = include_assistant_prefix self.include_section_separators = include_section_separators - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format Instructions on how to summarize diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py b/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py index b1186c46..40b6ab70 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod class LLMChatCompletionWrapper(ABC): @abstractmethod - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): """Go from ChatCompletion to a single prompt string""" pass diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py b/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py index e8aaf64f..c928da71 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py @@ -28,7 +28,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper): self.include_opening_brance_in_prefix = include_opening_brace_in_prefix self.include_section_separators = include_section_separators - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): """ Zephyr prompt format: <|system|> @@ -65,8 +65,11 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper): # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." prompt += f"\nAvailable functions:" - for function_dict in functions: - prompt += f"\n{create_function_description(function_dict)}" + if function_documentation is not None: + prompt += f"\n{function_documentation}" + else: + for function_dict in functions: + prompt += f"\n{create_function_description(function_dict)}" # Put functions INSIDE system message (TODO experiment with this) prompt += IM_END_TOKEN @@ -199,7 +202,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper): self.include_opening_brance_in_prefix = include_opening_brace_in_prefix self.include_section_separators = include_section_separators - def chat_completion_to_prompt(self, messages, functions): + def chat_completion_to_prompt(self, messages, functions, function_documentation=None): prompt = "" IM_START_TOKEN = "" @@ -227,8 +230,11 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper): # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." prompt += f"\nAvailable functions:" - for function_dict in functions: - prompt += f"\n{create_function_description(function_dict)}" + if function_documentation is not None: + prompt += f"\n{function_documentation}" + else: + for function_dict in functions: + prompt += f"\n{create_function_description(function_dict)}" def create_function_call(function_call, inner_thoughts=None): airo_func_call = { diff --git a/tests/test_server.py b/tests/test_server.py index b6bf88d0..f57470cf 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,23 +13,42 @@ from .utils import wipe_config, wipe_memgpt_home def test_server(): wipe_memgpt_home() - config = MemGPTConfig( - archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), - recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), - metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), - archival_storage_type="postgres", - recall_storage_type="postgres", - metadata_storage_type="postgres", - # embeddings - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_dim=1536, - openai_key=os.getenv("OPENAI_API_KEY"), - # llms - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - model="gpt-4", - ) + if os.getenv("OPENAI_API_KEY"): + config = MemGPTConfig( + archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + openai_key=os.getenv("OPENAI_API_KEY"), + # llms + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ) + else: # hosted + config = MemGPTConfig( + archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_dim=1024, + # llms + model_endpoint_type="vllm", + model_endpoint="https://api.memgpt.ai", + model="ehartford/dolphin-2.5-mixtral-8x7b", + ) config.save() server = SyncServer()