feat: Support pydantic models as parameters to custom functions (#839)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c07746b097
commit
74c0d9af9d
@@ -461,7 +461,14 @@ class Agent(object):
|
||||
# Failure case 3: function failed during execution
|
||||
self.interface.function_message(f"Running {function_name}({function_args})")
|
||||
try:
|
||||
spec = inspect.getfullargspec(function_to_call).annotations
|
||||
|
||||
for name, arg in function_args.items():
|
||||
if isinstance(function_args[name], dict):
|
||||
function_args[name] = spec[name](**function_args[name])
|
||||
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
|
||||
function_response = function_to_call(**function_args)
|
||||
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
|
||||
# with certain functions we rely on the paging mechanism to handle overflow
|
||||
|
||||
@@ -3,6 +3,7 @@ import typing
|
||||
from typing import get_args
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memgpt.constants import FUNCTION_PARAM_NAME_REQ_HEARTBEAT, FUNCTION_PARAM_TYPE_REQ_HEARTBEAT, FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT
|
||||
|
||||
@@ -53,6 +54,30 @@ def type_to_json_schema_type(py_type):
|
||||
return type_map.get(py_type, "string") # Default to "string" if type not in map
|
||||
|
||||
|
||||
def pydantic_model_to_open_ai(model):
|
||||
schema = model.model_json_schema()
|
||||
docstring = parse(model.__doc__ or "")
|
||||
parameters = {k: v for k, v in schema.items() if k not in ("title", "description")}
|
||||
for param in docstring.params:
|
||||
if (name := param.arg_name) in parameters["properties"] and (description := param.description):
|
||||
if "description" not in parameters["properties"][name]:
|
||||
parameters["properties"][name]["description"] = description
|
||||
|
||||
parameters["required"] = sorted(k for k, v in parameters["properties"].items() if "default" not in v)
|
||||
|
||||
if "description" not in schema:
|
||||
if docstring.short_description:
|
||||
schema["description"] = docstring.short_description
|
||||
else:
|
||||
raise
|
||||
|
||||
return {
|
||||
"name": schema["title"],
|
||||
"description": schema["description"],
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
|
||||
def generate_schema(function):
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
@@ -83,13 +108,16 @@ def generate_schema(function):
|
||||
if not param_doc or not param_doc.description:
|
||||
raise ValueError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a description in the docstring")
|
||||
|
||||
# Add parameter details to the schema
|
||||
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
|
||||
schema["parameters"]["properties"][param.name] = {
|
||||
# "type": "string" if param.annotation == str else str(param.annotation),
|
||||
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
|
||||
"description": param_doc.description,
|
||||
}
|
||||
if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
|
||||
schema["parameters"]["properties"][param.name] = pydantic_model_to_open_ai(param.annotation)
|
||||
else:
|
||||
# Add parameter details to the schema
|
||||
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
|
||||
schema["parameters"]["properties"][param.name] = {
|
||||
# "type": "string" if param.annotation == str else str(param.annotation),
|
||||
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
|
||||
"description": param_doc.description,
|
||||
}
|
||||
if param.default == inspect.Parameter.empty:
|
||||
schema["parameters"]["required"].append(param.name)
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ def get_chat_completion(
|
||||
|
||||
# If the wrapper uses grammar, generate the grammar using the grammar generating function
|
||||
# TODO move this to a flag
|
||||
if "grammar" in wrapper:
|
||||
if wrapper is not None and "grammar" in wrapper:
|
||||
# When using grammars, we don't want to do any extras output tricks like appending a response prefix
|
||||
setattr(llm_wrapper, "assistant_prefix_extra_first_message", "")
|
||||
setattr(llm_wrapper, "assistant_prefix_extra", "")
|
||||
@@ -125,9 +125,11 @@ def get_chat_completion(
|
||||
try:
|
||||
# if hasattr(llm_wrapper, "supports_first_message"):
|
||||
if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
|
||||
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, first_message=first_message)
|
||||
prompt = llm_wrapper.chat_completion_to_prompt(
|
||||
messages, functions, first_message=first_message, function_documentation=documentation
|
||||
)
|
||||
else:
|
||||
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
|
||||
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions, function_documentation=documentation)
|
||||
|
||||
printd(prompt)
|
||||
except Exception as e:
|
||||
@@ -248,8 +250,8 @@ def generate_grammar_and_documentation(
|
||||
grammar_function_models,
|
||||
outer_object_name="function",
|
||||
outer_object_content="params",
|
||||
model_prefix="Function",
|
||||
fields_prefix="Parameter",
|
||||
model_prefix="function",
|
||||
fields_prefix="params",
|
||||
add_inner_thoughts=add_inner_thoughts_top_level,
|
||||
allow_only_inner_thoughts=allow_only_inner_thoughts,
|
||||
)
|
||||
|
||||
@@ -465,8 +465,7 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
|
||||
nested_rules = []
|
||||
has_markdown_code_block = False
|
||||
has_triple_quoted_string = False
|
||||
look_for_markdown_code_block = False
|
||||
look_for_triple_quoted_string = False
|
||||
|
||||
for field_name, field_info in model_fields.items():
|
||||
if not issubclass(model, BaseModel):
|
||||
field_type, default_value = field_info
|
||||
@@ -621,7 +620,7 @@ null ::= "null"
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" ws
|
||||
)* "\""
|
||||
ws ::= ([ \t\n] ws)?
|
||||
float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
@@ -636,13 +635,13 @@ object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
)? "}"
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
)? "]"
|
||||
|
||||
number ::= integer | float"""
|
||||
|
||||
@@ -683,7 +682,7 @@ def generate_markdown_documentation(
|
||||
if add_prefix:
|
||||
documentation += f"{model_prefix}: {model.__name__}\n"
|
||||
else:
|
||||
documentation += f"Model: {model.__name__}\n"
|
||||
documentation += f"class: {model.__name__}\n"
|
||||
|
||||
# Handling multi-line model description with proper indentation
|
||||
|
||||
@@ -691,18 +690,19 @@ def generate_markdown_documentation(
|
||||
base_class_doc = getdoc(BaseModel)
|
||||
class_description = class_doc if class_doc and class_doc != base_class_doc else ""
|
||||
if class_description != "":
|
||||
documentation += " Description: "
|
||||
documentation += format_multiline_description(class_description, 0) + "\n"
|
||||
documentation += format_multiline_description("description: " + class_description, 1) + "\n"
|
||||
|
||||
if add_prefix:
|
||||
# Indenting the fields section
|
||||
documentation += f" {fields_prefix}:\n"
|
||||
else:
|
||||
documentation += f" Fields:\n"
|
||||
documentation += f" attributes:\n"
|
||||
if isclass(model) and issubclass(model, BaseModel):
|
||||
for name, field_type in model.__annotations__.items():
|
||||
# if name == "markdown_code_block":
|
||||
# continue
|
||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
pyd_models.append((field_type, False))
|
||||
if get_origin(field_type) == list:
|
||||
element_type = get_args(field_type)[0]
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
@@ -748,25 +748,33 @@ def generate_field_markdown(
|
||||
|
||||
if get_origin(field_type) == list:
|
||||
element_type = get_args(field_type)[0]
|
||||
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
|
||||
field_text = f"{indent}{field_name} ({field_type.__name__} of {element_type.__name__})"
|
||||
if field_description != "":
|
||||
field_text += ":\n"
|
||||
field_text += ": "
|
||||
else:
|
||||
field_text += "\n"
|
||||
elif get_origin(field_type) == Union:
|
||||
element_types = get_args(field_type)
|
||||
types = []
|
||||
for element_type in element_types:
|
||||
types.append(format_model_and_field_name(element_type.__name__))
|
||||
types.append(element_type.__name__)
|
||||
field_text = f"{indent}{field_name} ({' or '.join(types)})"
|
||||
if field_description != "":
|
||||
field_text += ":\n"
|
||||
field_text += ": "
|
||||
else:
|
||||
field_text += "\n"
|
||||
elif issubclass(field_type, Enum):
|
||||
enum_values = [f"'{str(member.value)}'" for member in field_type]
|
||||
|
||||
field_text = f"{indent}{field_name} ({' or '.join(enum_values)})"
|
||||
if field_description != "":
|
||||
field_text += ": "
|
||||
else:
|
||||
field_text += "\n"
|
||||
else:
|
||||
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})"
|
||||
field_text = f"{indent}{field_name} ({field_type.__name__})"
|
||||
if field_description != "":
|
||||
field_text += ":\n"
|
||||
field_text += ": "
|
||||
else:
|
||||
field_text += "\n"
|
||||
|
||||
@@ -774,7 +782,7 @@ def generate_field_markdown(
|
||||
return field_text
|
||||
|
||||
if field_description != "":
|
||||
field_text += f" Description: " + field_description + "\n"
|
||||
field_text += field_description + "\n"
|
||||
|
||||
# Check for and include field-specific examples if available
|
||||
if hasattr(model, "Config") and hasattr(model.Config, "json_schema_extra") and "example" in model.Config.json_schema_extra:
|
||||
@@ -784,7 +792,7 @@ def generate_field_markdown(
|
||||
field_text += f"{indent} Example: {example_text}\n"
|
||||
|
||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
field_text += f"{indent} Details:\n"
|
||||
field_text += f"{indent} details:\n"
|
||||
for name, type_ in field_type.__annotations__.items():
|
||||
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
|
||||
|
||||
@@ -952,7 +960,7 @@ def format_multiline_description(description: str, indent_level: int) -> str:
|
||||
Returns:
|
||||
str: Formatted multiline description.
|
||||
"""
|
||||
indent = " " * indent_level
|
||||
indent = " " * indent_level
|
||||
return indent + description.replace("\n", "\n" + indent)
|
||||
|
||||
|
||||
@@ -1154,6 +1162,7 @@ def create_dynamic_model_from_function(func: Callable, add_inner_thoughts: bool
|
||||
default_value = ...
|
||||
else:
|
||||
default_value = param.default
|
||||
|
||||
dynamic_fields[param.name] = (param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
|
||||
# Creating the dynamic model
|
||||
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
|
||||
|
||||
@@ -26,7 +26,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
|
||||
|
||||
A chat.
|
||||
@@ -87,8 +87,11 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
|
||||
prompt += f"\nAvailable functions:"
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
if function_documentation is not None:
|
||||
prompt += f"\n{function_documentation}"
|
||||
else:
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
|
||||
def create_function_call(function_call):
|
||||
"""Go from ChatCompletion to Airoboros style function trace (in prompt)
|
||||
@@ -230,7 +233,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
self.assistant_prefix_extra = assistant_prefix_extra
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
|
||||
|
||||
A chat.
|
||||
@@ -293,8 +296,11 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
|
||||
prompt += f"\nAvailable functions:"
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
if function_documentation is not None:
|
||||
prompt += f"\n{function_documentation}"
|
||||
else:
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
|
||||
def create_function_call(function_call, inner_thoughts=None):
|
||||
"""Go from ChatCompletion to Airoboros style function trace (in prompt)
|
||||
|
||||
@@ -95,12 +95,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
return prompt
|
||||
|
||||
# NOTE: BOS/EOS chatml tokens are NOT inserted here
|
||||
def _compile_system_message(self, system_message, functions) -> str:
|
||||
def _compile_system_message(self, system_message, functions, function_documentation=None) -> str:
|
||||
"""system prompt + memory + functions -> string"""
|
||||
prompt = ""
|
||||
prompt += system_message
|
||||
prompt += "\n"
|
||||
prompt += self._compile_function_block(functions)
|
||||
if function_documentation is not None:
|
||||
prompt += function_documentation
|
||||
else:
|
||||
prompt += self._compile_function_block(functions)
|
||||
return prompt
|
||||
|
||||
def _compile_function_call(self, function_call, inner_thoughts=None):
|
||||
@@ -186,13 +189,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
prompt += function_return_str
|
||||
return prompt
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions, first_message=False):
|
||||
def chat_completion_to_prompt(self, messages, functions, first_message=False, function_documentation=None):
|
||||
"""chatml-style prompt formatting, with implied support for multi-role"""
|
||||
prompt = ""
|
||||
|
||||
# System insturctions go first
|
||||
assert messages[0]["role"] == "system"
|
||||
system_block = self._compile_system_message(system_message=messages[0]["content"], functions=functions)
|
||||
system_block = self._compile_system_message(
|
||||
system_message=messages[0]["content"], functions=functions, function_documentation=function_documentation
|
||||
)
|
||||
prompt += f"<|im_start|>system\n{system_block.strip()}<|im_end|>"
|
||||
|
||||
# Last are the user/assistant messages
|
||||
|
||||
@@ -26,7 +26,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
|
||||
|
||||
<|im_start|>system
|
||||
@@ -97,8 +97,11 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
|
||||
prompt += f"\nAvailable functions:"
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
if function_documentation is not None:
|
||||
prompt += f"\n{function_documentation}"
|
||||
else:
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
|
||||
# Put functions INSIDE system message (TODO experiment with this)
|
||||
prompt += IM_END_TOKEN
|
||||
|
||||
@@ -17,7 +17,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
|
||||
self.include_assistant_prefix = include_assistant_prefix
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
|
||||
|
||||
Instructions on how to summarize
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class LLMChatCompletionWrapper(ABC):
|
||||
@abstractmethod
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
"""Go from ChatCompletion to a single prompt string"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
"""
|
||||
Zephyr prompt format:
|
||||
<|system|>
|
||||
@@ -65,8 +65,11 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
|
||||
prompt += f"\nAvailable functions:"
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
if function_documentation is not None:
|
||||
prompt += f"\n{function_documentation}"
|
||||
else:
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
|
||||
# Put functions INSIDE system message (TODO experiment with this)
|
||||
prompt += IM_END_TOKEN
|
||||
@@ -199,7 +202,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
def chat_completion_to_prompt(self, messages, functions, function_documentation=None):
|
||||
prompt = ""
|
||||
|
||||
IM_START_TOKEN = "<s>"
|
||||
@@ -227,8 +230,11 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
|
||||
prompt += f"\nAvailable functions:"
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
if function_documentation is not None:
|
||||
prompt += f"\n{function_documentation}"
|
||||
else:
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
|
||||
def create_function_call(function_call, inner_thoughts=None):
|
||||
airo_func_call = {
|
||||
|
||||
@@ -13,23 +13,42 @@ from .utils import wipe_config, wipe_memgpt_home
|
||||
def test_server():
|
||||
wipe_memgpt_home()
|
||||
|
||||
config = MemGPTConfig(
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# llms
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
)
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = MemGPTConfig(
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# llms
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
)
|
||||
else: # hosted
|
||||
config = MemGPTConfig(
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
embedding_endpoint_type="hugging-face",
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_model="BAAI/bge-large-en-v1.5",
|
||||
embedding_dim=1024,
|
||||
# llms
|
||||
model_endpoint_type="vllm",
|
||||
model_endpoint="https://api.memgpt.ai",
|
||||
model="ehartford/dolphin-2.5-mixtral-8x7b",
|
||||
)
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
|
||||
Reference in New Issue
Block a user