feat: Support pydantic models as parameters to custom functions (#839)

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Maximilian Winter
2024-01-18 02:21:00 +01:00
committed by GitHub
parent c07746b097
commit 74c0d9af9d
11 changed files with 153 additions and 68 deletions

View File

@@ -461,7 +461,14 @@ class Agent(object):
# Failure case 3: function failed during execution
self.interface.function_message(f"Running {function_name}({function_args})")
try:
spec = inspect.getfullargspec(function_to_call).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
# with certain functions we rely on the paging mechanism to handle overflow

View File

@@ -3,6 +3,7 @@ import typing
from typing import get_args
from docstring_parser import parse
from pydantic import BaseModel
from memgpt.constants import FUNCTION_PARAM_NAME_REQ_HEARTBEAT, FUNCTION_PARAM_TYPE_REQ_HEARTBEAT, FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
@@ -53,6 +54,30 @@ def type_to_json_schema_type(py_type):
return type_map.get(py_type, "string") # Default to "string" if type not in map
def pydantic_model_to_open_ai(model):
    """Convert a pydantic model class into an OpenAI function-calling schema dict.

    Args:
        model: A pydantic ``BaseModel`` subclass (the class itself, not an instance).

    Returns:
        A dict with keys ``name``, ``description``, and ``parameters`` in the
        shape expected by the OpenAI function/tool schema format.

    Raises:
        ValueError: If neither the model's JSON schema nor its docstring
            provides a top-level description.
    """
    schema = model.model_json_schema()
    docstring = parse(model.__doc__ or "")
    # Strip the title/description wrapper keys; what remains is the JSON-schema
    # parameters object (properties, type, etc.).
    parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
    # Backfill per-property descriptions from the class docstring's param
    # entries, but never overwrite a description the schema already provides.
    for param in docstring.params:
        if (name := param.arg_name) in parameters["properties"] and (description := param.description):
            if "description" not in parameters["properties"][name]:
                parameters["properties"][name]["description"] = description
    # A property is required exactly when it declares no default value.
    parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)
    if "description" not in schema:
        if docstring.short_description:
            schema["description"] = docstring.short_description
        else:
            # A bare `raise` here (as originally written) has no active
            # exception to re-raise and surfaces as an opaque RuntimeError;
            # raise a descriptive error instead.
            raise ValueError(
                f"Model {model.__name__} must have a description: add one to its JSON schema or its docstring"
            )
    return {
        "name": schema["title"],
        "description": schema["description"],
        "parameters": parameters,
    }
def generate_schema(function):
# Get the signature of the function
sig = inspect.signature(function)
@@ -83,13 +108,16 @@ def generate_schema(function):
if not param_doc or not param_doc.description:
raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
# Add parameter details to the schema
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
schema["parameters"]["properties"][param.name] = {
# "type": "string" if param.annotation == str else str(param.annotation),
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
"description": param_doc.description,
}
if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
schema["parameters"]["properties"][param.name] = pydantic_model_to_open_ai(param.annotation)
else:
# Add parameter details to the schema
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
schema["parameters"]["properties"][param.name] = {
# "type": "string" if param.annotation == str else str(param.annotation),
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
"description": param_doc.description,
}
if param.default == inspect.Parameter.empty:
schema["parameters"]["required"].append(param.name)

View File

@@ -88,7 +88,7 @@ def get_chat_completion(
# If the wrapper uses grammar, generate the grammar using the grammar generating function
# TODO move this to a flag
if "grammar" in wrapper:
if wrapper is not None and "grammar" in wrapper:
# When using grammars, we don't want to do any extras output tricks like appending a response prefix
setattr(llm_wrapper, "assistant_prefix_extra_first_message", "")
setattr(llm_wrapper, "assistant_prefix_extra", "")
@@ -125,9 +125,11 @@ def get_chat_completion(
try:
# if hasattr(llm_wrapper, "supports_first_message"):
if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, first_message=first_message)
prompt = llm_wrapper.chat_completion_to_prompt(
messages, functions, first_message=first_message, function_documentation=documentation
)
else:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, function_documentation=documentation)
printd(prompt)
except Exception as e:
@@ -248,8 +250,8 @@ def generate_grammar_and_documentation(
grammar_function_models,
outer_object_name="function",
outer_object_content="params",
model_prefix="Function",
fields_prefix="Parameter",
model_prefix="function",
fields_prefix="params",
add_inner_thoughts=add_inner_thoughts_top_level,
allow_only_inner_thoughts=allow_only_inner_thoughts,
)

View File

@@ -465,8 +465,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
nested_rules = []
has_markdown_code_block = False
has_triple_quoted_string = False
look_for_markdown_code_block = False
look_for_triple_quoted_string = False
for field_name, field_info in model_fields.items():
if not issubclass(model, BaseModel):
field_type, default_value = field_info
@@ -621,7 +620,7 @@ null ::= "null"
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" ws
)* "\""
ws ::= ([ \t\n] ws)?
float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
@@ -636,13 +635,13 @@ object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
)? "}"
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
)? "]"
number ::= integer | float"""
@@ -683,7 +682,7 @@ def generate_markdown_documentation(
if add_prefix:
documentation += f"{model_prefix}: {model.__name__}\n"
else:
documentation += f"Model: {model.__name__}\n"
documentation += f"class: {model.__name__}\n"
# Handling multi-line model description with proper indentation
@@ -691,18 +690,19 @@ def generate_markdown_documentation(
base_class_doc = getdoc(BaseModel)
class_description = class_doc if class_doc and class_doc != base_class_doc else ""
if class_description != "":
documentation += " Description: "
documentation += format_multiline_description(class_description, 0) + "\n"
documentation += format_multiline_description("description: " + class_description, 1) + "\n"
if add_prefix:
# Indenting the fields section
documentation += f" {fields_prefix}:\n"
else:
documentation += f" Fields:\n"
documentation += f" attributes:\n"
if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items():
# if name == "markdown_code_block":
# continue
if isclass(field_type) and issubclass(field_type, BaseModel):
pyd_models.append((field_type, False))
if get_origin(field_type) == list:
element_type = get_args(field_type)[0]
if isclass(element_type) and issubclass(element_type, BaseModel):
@@ -748,25 +748,33 @@ def generate_field_markdown(
if get_origin(field_type) == list:
element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
field_text = f"{indent}{field_name} ({field_type.__name__} of {element_type.__name__})"
if field_description != "":
field_text += ":\n"
field_text += ": "
else:
field_text += "\n"
elif get_origin(field_type) == Union:
element_types = get_args(field_type)
types = []
for element_type in element_types:
types.append(format_model_and_field_name(element_type.__name__))
types.append(element_type.__name__)
field_text = f"{indent}{field_name} ({' or '.join(types)})"
if field_description != "":
field_text += ":\n"
field_text += ": "
else:
field_text += "\n"
elif issubclass(field_type, Enum):
enum_values = [f"'{str(member.value)}'" for member in field_type]
field_text = f"{indent}{field_name} ({' or '.join(enum_values)})"
if field_description != "":
field_text += ": "
else:
field_text += "\n"
else:
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})"
field_text = f"{indent}{field_name} ({field_type.__name__})"
if field_description != "":
field_text += ":\n"
field_text += ": "
else:
field_text += "\n"
@@ -774,7 +782,7 @@ def generate_field_markdown(
return field_text
if field_description != "":
field_text += f" Description: " + field_description + "\n"
field_text += field_description + "\n"
# Check for and include field-specific examples if available
if hasattr(model, "Config") and hasattr(model.Config, "json_schema_extra") and "example" in model.Config.json_schema_extra:
@@ -784,7 +792,7 @@ def generate_field_markdown(
field_text += f"{indent} Example: {example_text}\n"
if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n"
field_text += f"{indent} details:\n"
for name, type_ in field_type.__annotations__.items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
@@ -952,7 +960,7 @@ def format_multiline_description(description: str, indent_level: int) -> str:
Returns:
str: Formatted multiline description.
"""
indent = " " * indent_level
indent = " " * indent_level
return indent + description.replace("\n", "\n" + indent)
@@ -1154,6 +1162,7 @@ def create_dynamic_model_from_function(func: Callable, add_inner_thoughts: bool
default_value = ...
else:
default_value = param.default
dynamic_fields[param.name] = (param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)

View File

@@ -26,7 +26,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
A chat.
@@ -87,8 +87,11 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
def create_function_call(function_call):
"""Go from ChatCompletion to Airoboros style function trace (in prompt)
@@ -230,7 +233,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
self.assistant_prefix_extra = assistant_prefix_extra
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
A chat.
@@ -293,8 +296,11 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
def create_function_call(function_call, inner_thoughts=None):
"""Go from ChatCompletion to Airoboros style function trace (in prompt)

View File

@@ -95,12 +95,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
return prompt
# NOTE: BOS/EOS chatml tokens are NOT inserted here
def _compile_system_message(self, system_message, functions) -> str:
def _compile_system_message(self, system_message, functions, function_documentation=None) -> str:
"""system prompt + memory + functions -> string"""
prompt = ""
prompt += system_message
prompt += "\n"
prompt += self._compile_function_block(functions)
if function_documentation is not None:
prompt += function_documentation
else:
prompt += self._compile_function_block(functions)
return prompt
def _compile_function_call(self, function_call, inner_thoughts=None):
@@ -186,13 +189,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
prompt += function_return_str
return prompt
def chat_completion_to_prompt(self, messages, functions, first_message=False):
def chat_completion_to_prompt(self, messages, functions, first_message=False, function_documentation=None):
"""chatml-style prompt formatting, with implied support for multi-role"""
prompt = ""
    # System instructions go first
assert messages[0]["role"] == "system"
system_block = self._compile_system_message(system_message=messages[0]["content"], functions=functions)
system_block = self._compile_system_message(
system_message=messages[0]["content"], functions=functions, function_documentation=function_documentation
)
prompt += f"<|im_start|>system\n{system_block.strip()}<|im_end|>"
# Last are the user/assistant messages

View File

@@ -26,7 +26,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
<|im_start|>system
@@ -97,8 +97,11 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
# Put functions INSIDE system message (TODO experiment with this)
prompt += IM_END_TOKEN

View File

@@ -17,7 +17,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
self.include_assistant_prefix = include_assistant_prefix
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
Instructions on how to summarize

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
class LLMChatCompletionWrapper(ABC):
@abstractmethod
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
"""Go from ChatCompletion to a single prompt string"""
pass

View File

@@ -28,7 +28,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
"""
Zephyr prompt format:
<|system|>
@@ -65,8 +65,11 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
# Put functions INSIDE system message (TODO experiment with this)
prompt += IM_END_TOKEN
@@ -199,7 +202,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
prompt = ""
IM_START_TOKEN = "<s>"
@@ -227,8 +230,11 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
if function_documentation is not None:
prompt += f"\n{function_documentation}"
else:
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
def create_function_call(function_call, inner_thoughts=None):
airo_func_call = {

View File

@@ -13,23 +13,42 @@ from .utils import wipe_config, wipe_memgpt_home
def test_server():
wipe_memgpt_home()
config = MemGPTConfig(
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
openai_key=os.getenv("OPENAI_API_KEY"),
# llms
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt-4",
)
if os.getenv("OPENAI_API_KEY"):
config = MemGPTConfig(
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
openai_key=os.getenv("OPENAI_API_KEY"),
# llms
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt-4",
)
else: # hosted
config = MemGPTConfig(
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_model="BAAI/bge-large-en-v1.5",
embedding_dim=1024,
# llms
model_endpoint_type="vllm",
model_endpoint="https://api.memgpt.ai",
model="ehartford/dolphin-2.5-mixtral-8x7b",
)
config.save()
server = SyncServer()