From 0491a8bbe3f3ad700b0d38cf12239d30faaffa10 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Mon, 26 Aug 2024 16:47:41 -0400 Subject: [PATCH] refactor: Remove JSON constant for common method (#1680) --- memgpt/agent.py | 12 +- memgpt/autogen/interface.py | 6 +- memgpt/constants.py | 11 +- memgpt/functions/function_sets/base.py | 12 +- memgpt/functions/function_sets/extras.py | 6 +- memgpt/functions/schema_generator.py | 18 +-- memgpt/interface.py | 12 +- memgpt/llm_api/anthropic.py | 3 +- memgpt/llm_api/cohere.py | 7 +- memgpt/llm_api/google_ai.py | 12 +- memgpt/llm_api/llm_api_tools.py | 4 +- memgpt/local_llm/chat_completion_proxy.py | 11 +- memgpt/local_llm/function_parser.py | 12 +- .../grammars/gbnf_grammar_generator.py | 4 +- memgpt/local_llm/json_parser.py | 24 ++-- .../llm_chat_completion_wrappers/airoboros.py | 19 ++- .../llm_chat_completion_wrappers/chatml.py | 30 ++--- .../configurable_wrapper.py | 21 ++- .../llm_chat_completion_wrappers/dolphin.py | 11 +- .../llm_chat_completion_wrappers/llama3.py | 22 ++-- .../simple_summary_wrapper.py | 11 +- .../llm_chat_completion_wrappers/zephyr.py | 19 +-- memgpt/local_llm/settings/settings.py | 6 +- memgpt/main.py | 11 +- memgpt/openai_backcompat/openai_object.py | 12 +- memgpt/prompts/gpt_functions.py | 83 +++--------- memgpt/schemas/message.py | 8 +- memgpt/server/rest_api/utils.py | 30 ++++- memgpt/server/server.py | 9 +- memgpt/server/ws_api/example_client.py | 9 +- memgpt/server/ws_api/protocol.py | 52 +++----- memgpt/server/ws_api/server.py | 2 +- memgpt/system.py | 17 ++- memgpt/utils.py | 21 ++- paper_experiments/nested_kv_task/nested_kv.py | 3 +- tests/test_agent_function_update.py | 121 ++++++++++++++++++ tests/test_function_parser.py | 4 +- tests/test_json_parsers.py | 6 +- tests/test_websocket_server.py | 7 +- 39 files changed, 368 insertions(+), 320 deletions(-) create mode 100644 tests/test_agent_function_update.py diff --git a/memgpt/agent.py b/memgpt/agent.py index 60184be5..0003591f 100644 --- 
a/memgpt/agent.py +++ b/memgpt/agent.py @@ -11,8 +11,6 @@ from memgpt.constants import ( CLI_WARNING_PREFIX, FIRST_MESSAGE_ATTEMPTS, IN_CONTEXT_MEMORY_KEYWORD, - JSON_ENSURE_ASCII, - JSON_LOADS_STRICT, LLM_MAX_TOKENS, MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, @@ -44,6 +42,8 @@ from memgpt.utils import ( get_tool_call_id, get_utc_time, is_utc_datetime, + json_dumps, + json_loads, parse_json, printd, united_diff, @@ -654,12 +654,12 @@ class Agent(object): def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]: """If 'name' exists in the JSON string, remove it and return the cleaned text + name value""" try: - user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT)) + user_message_json = dict(json_loads(user_message_text)) # Special handling for AutoGen messages with 'name' field # Treat 'name' as a special field # If it exists in the input message, elevate it to the 'message' level name = user_message_json.pop("name", None) - clean_message = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII) + clean_message = json_dumps(user_message_json) except Exception as e: print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}") @@ -668,8 +668,8 @@ class Agent(object): def validate_json(user_message_text: str, raise_on_error: bool) -> str: try: - user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT)) - user_message_json_val = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII) + user_message_json = dict(json_loads(user_message_text)) + user_message_json_val = json_dumps(user_message_json) return user_message_json_val except Exception as e: print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}") diff --git a/memgpt/autogen/interface.py b/memgpt/autogen/interface.py index 66e083fe..bff04f71 100644 --- a/memgpt/autogen/interface.py +++ b/memgpt/autogen/interface.py @@ -4,7 +4,7 @@ from typing import Optional 
from colorama import Fore, Style, init -from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT +from memgpt.constants import CLI_WARNING_PREFIX from memgpt.schemas.message import Message init(autoreset=True) @@ -113,7 +113,7 @@ class AutoGenInterface(object): return else: try: - msg_json = json.loads(msg, strict=JSON_LOADS_STRICT) + msg_json = json_loads(msg) except: print(f"{CLI_WARNING_PREFIX}failed to parse user message into json") message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}" @@ -203,7 +203,7 @@ class AutoGenInterface(object): ) else: try: - msg_dict = json.loads(msg, strict=JSON_LOADS_STRICT) + msg_dict = json_loads(msg) if "status" in msg_dict and msg_dict["status"] == "OK": message = ( f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}" diff --git a/memgpt/constants.py b/memgpt/constants.py index e7b8ceaf..d3f7634f 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -114,14 +114,9 @@ REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function called using request_hea # FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed" FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, returning control" -FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat" -FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean" -FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function." RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5 -# GLOBAL SETTINGS FOR `json.dumps()` -JSON_ENSURE_ASCII = False - -# GLOBAL SETTINGS FOR `json.loads()` -JSON_LOADS_STRICT = False +# TODO Is this config or constant? 
+CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000 +CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000 diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index 9c751d2c..39c43a92 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -4,11 +4,7 @@ import math from typing import Optional from memgpt.agent import Agent -from memgpt.constants import ( - JSON_ENSURE_ASCII, - MAX_PAUSE_HEARTBEATS, - RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, -) +from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE ### Functions / tools the agent can use # All functions should return a response string (or None) @@ -81,7 +77,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt else: results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}" + results_str = f"{results_pref} {json_dumps(results_formatted)}" return results_str @@ -111,7 +107,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page: else: results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}" + results_str = f"{results_pref} {json_dumps(results_formatted)}" return results_str @@ -154,5 +150,5 @@ def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> else: results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results] - results_str = 
f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}" + results_str = f"{results_pref} {json_dumps(results_formatted)}" return results_str diff --git a/memgpt/functions/function_sets/extras.py b/memgpt/functions/function_sets/extras.py index 32565895..007f346b 100644 --- a/memgpt/functions/function_sets/extras.py +++ b/memgpt/functions/function_sets/extras.py @@ -6,8 +6,6 @@ from typing import Optional import requests from memgpt.constants import ( - JSON_ENSURE_ASCII, - JSON_LOADS_STRICT, MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, ) @@ -123,10 +121,10 @@ def http_request(self, method: str, url: str, payload_json: Optional[str] = None else: # Validate and convert the payload for other types of requests if payload_json: - payload = json.loads(payload_json, strict=JSON_LOADS_STRICT) + payload = json_loads(payload_json) else: payload = {} - print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}") + print(f"[HTTP] launching {method} request to {url}, payload=\n{json_dumps(payload, indent=2)}") response = requests.request(method, url, json=payload, headers=headers) return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text} diff --git a/memgpt/functions/schema_generator.py b/memgpt/functions/schema_generator.py index 893da002..7a3dc965 100644 --- a/memgpt/functions/schema_generator.py +++ b/memgpt/functions/schema_generator.py @@ -5,14 +5,6 @@ from typing import Any, Dict, Optional, Type, get_args, get_origin from docstring_parser import parse from pydantic import BaseModel -from memgpt.constants import ( - FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - FUNCTION_PARAM_NAME_REQ_HEARTBEAT, - FUNCTION_PARAM_TYPE_REQ_HEARTBEAT, -) - -NO_HEARTBEAT_FUNCTIONS = ["send_message", "pause_heartbeats"] - def is_optional(annotation): # Check if the annotation is a Union @@ -136,12 +128,12 @@ def 
generate_schema(function, name: Optional[str] = None, description: Optional[ schema["parameters"]["required"].append(param.name) # append the heartbeat - if function.__name__ not in NO_HEARTBEAT_FUNCTIONS: - schema["parameters"]["properties"][FUNCTION_PARAM_NAME_REQ_HEARTBEAT] = { - "type": FUNCTION_PARAM_TYPE_REQ_HEARTBEAT, - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, + if function.__name__ not in ["send_message", "pause_heartbeats"]: + schema["parameters"]["properties"]["request_heartbeat"] = { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.", } - schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT) + schema["parameters"]["required"].append("request_heartbeat") return schema diff --git a/memgpt/interface.py b/memgpt/interface.py index 2f322217..7206c53b 100644 --- a/memgpt/interface.py +++ b/memgpt/interface.py @@ -5,9 +5,9 @@ from typing import List, Optional from colorama import Fore, Style, init -from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT +from memgpt.constants import CLI_WARNING_PREFIX from memgpt.schemas.message import Message -from memgpt.utils import printd +from memgpt.utils import json_loads, printd init(autoreset=True) @@ -127,7 +127,7 @@ class CLIInterface(AgentInterface): return else: try: - msg_json = json.loads(msg, strict=JSON_LOADS_STRICT) + msg_json = json_loads(msg) except: printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json") printd_user_message("🧑", msg) @@ -228,7 +228,7 @@ class CLIInterface(AgentInterface): printd_function_message("", msg) else: try: - msg_dict = json.loads(msg, strict=JSON_LOADS_STRICT) + msg_dict = json_loads(msg) if "status" in msg_dict and msg_dict["status"] == "OK": printd_function_message("", str(msg), color=Fore.GREEN) else: @@ -259,7 +259,7 @@ class CLIInterface(AgentInterface): 
CLIInterface.internal_monologue(content) # I think the next one is not up to date # function_message(msg["function_call"]) - args = json.loads(msg["function_call"].get("arguments"), strict=JSON_LOADS_STRICT) + args = json_loads(msg["function_call"].get("arguments")) CLIInterface.assistant_message(args.get("message")) # assistant_message(content) elif msg.get("tool_calls"): @@ -267,7 +267,7 @@ class CLIInterface(AgentInterface): CLIInterface.internal_monologue(content) function_obj = msg["tool_calls"][0].get("function") if function_obj: - args = json.loads(function_obj.get("arguments"), strict=JSON_LOADS_STRICT) + args = json_loads(function_obj.get("arguments")) CLIInterface.assistant_message(args.get("message")) else: CLIInterface.internal_monologue(content) diff --git a/memgpt/llm_api/anthropic.py b/memgpt/llm_api/anthropic.py index 785105c7..cdaba1af 100644 --- a/memgpt/llm_api/anthropic.py +++ b/memgpt/llm_api/anthropic.py @@ -5,7 +5,6 @@ from typing import List, Optional, Union import requests -from memgpt.constants import JSON_ENSURE_ASCII from memgpt.schemas.message import Message from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool from memgpt.schemas.openai.chat_completion_response import ( @@ -257,7 +256,7 @@ def convert_anthropic_response_to_chatcompletion( type="function", function=FunctionCall( name=response_json["content"][1]["name"], - arguments=json.dumps(response_json["content"][1]["input"], ensure_ascii=JSON_ENSURE_ASCII), + arguments=json_dumps(response_json["content"][1]["input"]), ), ) ] diff --git a/memgpt/llm_api/cohere.py b/memgpt/llm_api/cohere.py index 8d16c326..1623cc32 100644 --- a/memgpt/llm_api/cohere.py +++ b/memgpt/llm_api/cohere.py @@ -4,7 +4,6 @@ from typing import List, Optional, Union import requests -from memgpt.constants import JSON_ENSURE_ASCII from memgpt.local_llm.utils import count_tokens from memgpt.schemas.message import Message from memgpt.schemas.openai.chat_completion_request import 
ChatCompletionRequest, Tool @@ -17,7 +16,7 @@ from memgpt.schemas.openai.chat_completion_response import ( Message as ChoiceMessage, # NOTE: avoid conflict with our own MemGPT Message datatype ) from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics -from memgpt.utils import get_tool_call_id, get_utc_time, smart_urljoin +from memgpt.utils import get_tool_call_id, get_utc_time, json_dumps, smart_urljoin BASE_URL = "https://api.cohere.ai/v1" @@ -155,9 +154,7 @@ def convert_cohere_response_to_chatcompletion( completion_tokens = response_json["meta"]["billed_units"]["output_tokens"] else: # For some reason input_tokens not included in 'meta' 'tokens' dict? - prompt_tokens = count_tokens( - json.dumps(response_json["chat_history"], ensure_ascii=JSON_ENSURE_ASCII) - ) # NOTE: this is a very rough approximation + prompt_tokens = count_tokens(json_dumps(response_json["chat_history"])) # NOTE: this is a very rough approximation completion_tokens = response_json["meta"]["tokens"]["output_tokens"] finish_reason = remap_finish_reason(response_json["finish_reason"]) diff --git a/memgpt/llm_api/google_ai.py b/memgpt/llm_api/google_ai.py index ed2e7cfc..4cf49642 100644 --- a/memgpt/llm_api/google_ai.py +++ b/memgpt/llm_api/google_ai.py @@ -4,7 +4,7 @@ from typing import List, Optional import requests -from memgpt.constants import JSON_ENSURE_ASCII, NON_USER_MSG_PREFIX +from memgpt.constants import NON_USER_MSG_PREFIX from memgpt.local_llm.json_parser import clean_json_string_extra_backslash from memgpt.local_llm.utils import count_tokens from memgpt.schemas.openai.chat_completion_request import Tool @@ -310,7 +310,7 @@ def convert_google_ai_response_to_chatcompletion( type="function", function=FunctionCall( name=function_name, - arguments=clean_json_string_extra_backslash(json.dumps(function_args, ensure_ascii=JSON_ENSURE_ASCII)), + arguments=clean_json_string_extra_backslash(json_dumps(function_args)), ), ) ], @@ -374,12 +374,8 @@ def 
convert_google_ai_response_to_chatcompletion( else: # Count it ourselves assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required" - prompt_tokens = count_tokens( - json.dumps(input_messages, ensure_ascii=JSON_ENSURE_ASCII) - ) # NOTE: this is a very rough approximation - completion_tokens = count_tokens( - json.dumps(openai_response_message.model_dump(), ensure_ascii=JSON_ENSURE_ASCII) - ) # NOTE: this is also approximate + prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation + completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate total_tokens = prompt_tokens + completion_tokens usage = UsageStatistics( prompt_tokens=prompt_tokens, diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index 3a06ab8e..310a91a8 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -9,7 +9,7 @@ from typing import List, Optional, Union import requests -from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII +from memgpt.constants import CLI_WARNING_PREFIX from memgpt.credentials import MemGPTCredentials from memgpt.llm_api.anthropic import anthropic_chat_completions_request from memgpt.llm_api.azure_openai import ( @@ -109,7 +109,7 @@ def unpack_inner_thoughts_from_kwargs( # replace the kwargs new_choice = choice.model_copy(deep=True) - new_choice.message.tool_calls[0].function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII) + new_choice.message.tool_calls[0].function.arguments = json_dumps(func_args) # also replace the message content if new_choice.message.content is not None: warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})") diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 1fa87092..ccd62ef0 100644 --- 
a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -1,11 +1,10 @@ """Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend""" -import json import uuid import requests -from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII +from memgpt.constants import CLI_WARNING_PREFIX from memgpt.errors import LocalLLMConnectionError, LocalLLMError from memgpt.local_llm.constants import DEFAULT_WRAPPER from memgpt.local_llm.function_parser import patch_function @@ -33,7 +32,7 @@ from memgpt.schemas.openai.chat_completion_response import ( ToolCall, UsageStatistics, ) -from memgpt.utils import get_tool_call_id, get_utc_time +from memgpt.utils import get_tool_call_id, get_utc_time, json_dumps has_shown_warning = False grammar_supported_backends = ["koboldcpp", "llamacpp", "webui", "webui-legacy"] @@ -189,7 +188,7 @@ def get_chat_completion( chat_completion_result = llm_wrapper.output_to_chat_completion_response(result, first_message=first_message) else: chat_completion_result = llm_wrapper.output_to_chat_completion_response(result) - printd(json.dumps(chat_completion_result, indent=2, ensure_ascii=JSON_ENSURE_ASCII)) + printd(json_dumps(chat_completion_result, indent=2)) except Exception as e: raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}") @@ -206,13 +205,13 @@ def get_chat_completion( usage["prompt_tokens"] = count_tokens(prompt) # NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing - usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result, ensure_ascii=JSON_ENSURE_ASCII)) + usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result)) """ if usage["completion_tokens"] is None: printd(f"usage dict was missing completion_tokens, computing on-the-fly...") # chat_completion_result is dict with 'role' and 'content' # token counter wants a string - 
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result, ensure_ascii=JSON_ENSURE_ASCII)) + usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result)) """ # NOTE: this is the token count that matters most diff --git a/memgpt/local_llm/function_parser.py b/memgpt/local_llm/function_parser.py index 57dc078f..5f45dd29 100644 --- a/memgpt/local_llm/function_parser.py +++ b/memgpt/local_llm/function_parser.py @@ -1,7 +1,7 @@ import copy import json -from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT +from memgpt.utils import json_dumps, json_loads NO_HEARTBEAT_FUNCS = ["send_message", "pause_heartbeats"] @@ -13,16 +13,16 @@ def insert_heartbeat(message): if message_copy.get("function_call"): # function_name = message.get("function_call").get("name") params = message_copy.get("function_call").get("arguments") - params = json.loads(params, strict=JSON_LOADS_STRICT) + params = json_loads(params) params["request_heartbeat"] = True - message_copy["function_call"]["arguments"] = json.dumps(params, ensure_ascii=JSON_ENSURE_ASCII) + message_copy["function_call"]["arguments"] = json_dumps(params) elif message_copy.get("tool_call"): # function_name = message.get("tool_calls")[0].get("function").get("name") params = message_copy.get("tool_calls")[0].get("function").get("arguments") - params = json.loads(params, strict=JSON_LOADS_STRICT) + params = json_loads(params) params["request_heartbeat"] = True - message_copy["tools_calls"][0]["function"]["arguments"] = json.dumps(params, ensure_ascii=JSON_ENSURE_ASCII) + message_copy["tools_calls"][0]["function"]["arguments"] = json_dumps(params) return message_copy @@ -40,7 +40,7 @@ def heartbeat_correction(message_history, new_message): last_message_was_user = False if message_history[-1]["role"] == "user": try: - content = json.loads(message_history[-1]["content"], strict=JSON_LOADS_STRICT) + content = json_loads(message_history[-1]["content"]) except json.JSONDecodeError: return None 
# Check if it's a user message or system message diff --git a/memgpt/local_llm/grammars/gbnf_grammar_generator.py b/memgpt/local_llm/grammars/gbnf_grammar_generator.py index c38a0abe..59da81bc 100644 --- a/memgpt/local_llm/grammars/gbnf_grammar_generator.py +++ b/memgpt/local_llm/grammars/gbnf_grammar_generator.py @@ -21,7 +21,7 @@ from typing import ( from docstring_parser import parse from pydantic import BaseModel, create_model -from memgpt.constants import JSON_ENSURE_ASCII +from memgpt.utils import json_dumps, json_loads class PydanticDataType(Enum): @@ -731,7 +731,7 @@ def generate_markdown_documentation( if hasattr(model, "Config") and hasattr(model.Config, "json_schema_extra") and "example" in model.Config.json_schema_extra: documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n" - json_example = json.dumps(model.Config.json_schema_extra["example"], ensure_ascii=JSON_ENSURE_ASCII) + json_example = json_dumps(model.Config.json_schema_extra["example"]) documentation += format_multiline_description(json_example, 2) + "\n" return documentation diff --git a/memgpt/local_llm/json_parser.py b/memgpt/local_llm/json_parser.py index 13849201..fd4bb750 100644 --- a/memgpt/local_llm/json_parser.py +++ b/memgpt/local_llm/json_parser.py @@ -1,8 +1,8 @@ import json import re -from memgpt.constants import JSON_LOADS_STRICT from memgpt.errors import LLMJSONParsingError +from memgpt.utils import json_loads def clean_json_string_extra_backslash(s): @@ -45,7 +45,7 @@ def extract_first_json(string: str): depth -= 1 if depth == 0 and start_index is not None: try: - return json.loads(string[start_index : i + 1], strict=JSON_LOADS_STRICT) + return json_loads(string[start_index : i + 1]) except json.JSONDecodeError as e: raise LLMJSONParsingError(f"Matched closing bracket, but decode failed with error: {str(e)}") printd("No valid JSON object found.") @@ -174,21 +174,21 @@ def clean_json(raw_llm_output, messages=None, functions=None): from 
memgpt.utils import printd strategies = [ - lambda output: json.loads(output, strict=JSON_LOADS_STRICT), - lambda output: json.loads(output + "}", strict=JSON_LOADS_STRICT), - lambda output: json.loads(output + "}}", strict=JSON_LOADS_STRICT), - lambda output: json.loads(output + '"}}', strict=JSON_LOADS_STRICT), + lambda output: json_loads(output), + lambda output: json_loads(output + "}"), + lambda output: json_loads(output + "}}"), + lambda output: json_loads(output + '"}}'), # with strip and strip comma - lambda output: json.loads(output.strip().rstrip(",") + "}", strict=JSON_LOADS_STRICT), - lambda output: json.loads(output.strip().rstrip(",") + "}}", strict=JSON_LOADS_STRICT), - lambda output: json.loads(output.strip().rstrip(",") + '"}}', strict=JSON_LOADS_STRICT), + lambda output: json_loads(output.strip().rstrip(",") + "}"), + lambda output: json_loads(output.strip().rstrip(",") + "}}"), + lambda output: json_loads(output.strip().rstrip(",") + '"}}'), # more complex patchers - lambda output: json.loads(repair_json_string(output), strict=JSON_LOADS_STRICT), - lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT), + lambda output: json_loads(repair_json_string(output)), + lambda output: json_loads(repair_even_worse_json(output)), lambda output: extract_first_json(output + "}}"), lambda output: clean_and_interpret_send_message_json(output), # replace underscores - lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT), + lambda output: json_loads(replace_escaped_underscores(output)), lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"), ] diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py index 2f870c3b..efbf1b31 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -1,6 +1,5 @@ -import json +from 
memgpt.utils import json_dumps, json_loads -from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT from ...errors import LLMJSONParsingError from ..json_parser import clean_json from .wrapper_base import LLMChatCompletionWrapper @@ -114,9 +113,9 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): """ airo_func_call = { "function": function_call["name"], - "params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + "params": json_loads(function_call["arguments"]), } - return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=2) # Add a sep for the conversation if self.include_section_separators: @@ -129,7 +128,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): if message["role"] == "user": if self.simplify_json_content: try: - content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + content_json = json_loads(message["content"]) content_simple = content_json["message"] prompt += f"\nUSER: {content_simple}" except: @@ -206,7 +205,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): "content": None, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message @@ -325,10 +324,10 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): "function": function_call["name"], "params": { "inner_thoughts": inner_thoughts, - **json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + **json_loads(function_call["arguments"]), }, } - return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=2) # Add a sep for the conversation if self.include_section_separators: @@ -347,7 +346,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): user_prefix = "USER" if self.simplify_json_content: try: - content_json = json.loads(message["content"], 
strict=JSON_LOADS_STRICT) + content_json = json_loads(message["content"]) content_simple = content_json["message"] prompt += f"\n{user_prefix}: {content_simple}" except: @@ -447,7 +446,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): "content": inner_thoughts, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py b/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py index 787ebc4f..532b05f7 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py @@ -1,11 +1,9 @@ -import json - -from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT from memgpt.errors import LLMJSONParsingError from memgpt.local_llm.json_parser import clean_json from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import ( LLMChatCompletionWrapper, ) +from memgpt.utils import json_dumps, json_loads PREFIX_HINT = """# Reminders: # Important information about yourself and the user is stored in (limited) core memory @@ -137,10 +135,10 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): "function": function_call["name"], "params": { "inner_thoughts": inner_thoughts, - **json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + **json_loads(function_call["arguments"]), }, } - return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=self.json_indent) # NOTE: BOS/EOS chatml tokens are NOT inserted here def _compile_assistant_message(self, message) -> str: @@ -167,15 +165,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): if self.simplify_json_content: # Make user messages not JSON but plaintext instead try: - user_msg_json = json.loads(message["content"], 
strict=JSON_LOADS_STRICT) + user_msg_json = json_loads(message["content"]) user_msg_str = user_msg_json["message"] except: user_msg_str = message["content"] else: # Otherwise just dump the full json try: - user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) - user_msg_str = json.dumps(user_msg_json, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + user_msg_json = json_loads(message["content"]) + user_msg_str = json_dumps(user_msg_json, indent=self.json_indent) except: user_msg_str = message["content"] @@ -189,8 +187,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): prompt = "" try: # indent the function replies - function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT) - function_return_str = json.dumps(function_return_dict, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + function_return_dict = json_loads(message["content"]) + function_return_str = json_dumps(function_return_dict, indent=self.json_indent) except: function_return_str = message["content"] @@ -219,7 +217,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): if self.use_system_role_in_user: try: - msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + msg_json = json_loads(message["content"]) if msg_json["type"] != "user_message": role_str = "system" except: @@ -329,7 +327,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): "content": inner_thoughts, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message @@ -394,10 +392,10 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper): "function": function_call["name"], "params": { # "inner_thoughts": inner_thoughts, - **json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + **json_loads(function_call["arguments"]), }, } - return json.dumps(airo_func_call, 
indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=self.json_indent) def output_to_chat_completion_response(self, raw_llm_output, first_message=False): """NOTE: Modified to expect "inner_thoughts" outside the function @@ -458,7 +456,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper): "content": inner_thoughts, # "function_call": { # "name": function_name, - # "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + # "arguments": json_dumps(function_parameters), # }, } @@ -466,7 +464,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper): if function_name is not None: message["function_call"] = { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), } return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py b/memgpt/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py index 9d272449..bbf4b0f5 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py @@ -1,8 +1,7 @@ -import json - import yaml -from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT +from memgpt.utils import json_dumps, json_loads + from ...errors import LLMJSONParsingError from ..json_parser import clean_json from .wrapper_base import LLMChatCompletionWrapper @@ -131,10 +130,10 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): "function": function_call["name"], "params": { "inner_thoughts": inner_thoughts, - **json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + **json_loads(function_call["arguments"]), }, } - return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=self.json_indent) # NOTE: BOS/EOS chatml tokens are NOT inserted here def 
_compile_assistant_message(self, message) -> str: @@ -161,15 +160,15 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): if self.simplify_json_content: # Make user messages not JSON but plaintext instead try: - user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + user_msg_json = json_loads(message["content"]) user_msg_str = user_msg_json["message"] except: user_msg_str = message["content"] else: # Otherwise just dump the full json try: - user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) - user_msg_str = json.dumps(user_msg_json, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + user_msg_json = json_loads(message["content"]) + user_msg_str = json_dumps(user_msg_json, indent=self.json_indent) except: user_msg_str = message["content"] @@ -183,8 +182,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): prompt = "" try: # indent the function replies - function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT) - function_return_str = json.dumps(function_return_dict, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + function_return_dict = json_loads(message["content"]) + function_return_str = json_dumps(function_return_dict, indent=self.json_indent) except: function_return_str = message["content"] @@ -309,7 +308,7 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper): "content": inner_thoughts, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py index e2056d92..00e06de6 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py @@ -1,6 +1,7 @@ import json -from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT +from 
memgpt.utils import json_dumps, json_loads + from ...errors import LLMJSONParsingError from ..json_parser import clean_json from .wrapper_base import LLMChatCompletionWrapper @@ -127,9 +128,9 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): """ airo_func_call = { "function": function_call["name"], - "params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + "params": json_loads(function_call["arguments"]), } - return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=2) # option (1): from HF README: # <|im_start|>user @@ -156,7 +157,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): if message["role"] == "user": if self.simplify_json_content: try: - content_json = (json.loads(message["content"], strict=JSON_LOADS_STRICT),) + content_json = (json_loads(message["content"]),) content_simple = content_json["message"] prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}" # prompt += f"\nUSER: {content_simple}" @@ -241,7 +242,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper): "content": None, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/llama3.py b/memgpt/local_llm/llm_chat_completion_wrappers/llama3.py index 2141bdaa..867cedfd 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/llama3.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/llama3.py @@ -1,11 +1,11 @@ import json -from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT from memgpt.errors import LLMJSONParsingError from memgpt.local_llm.json_parser import clean_json from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import ( LLMChatCompletionWrapper, ) +from memgpt.utils import json_dumps, json_loads PREFIX_HINT = """# Reminders: # Important 
information about yourself and the user is stored in (limited) core memory @@ -137,10 +137,10 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper): "function": function_call["name"], "params": { "inner_thoughts": inner_thoughts, - **json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + **json_loads(function_call["arguments"]), }, } - return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=self.json_indent) # NOTE: BOS/EOS chatml tokens are NOT inserted here def _compile_assistant_message(self, message) -> str: @@ -167,18 +167,17 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper): if self.simplify_json_content: # Make user messages not JSON but plaintext instead try: - user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + user_msg_json = json_loads(message["content"]) user_msg_str = user_msg_json["message"] except: user_msg_str = message["content"] else: # Otherwise just dump the full json try: - user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) - user_msg_str = json.dumps( + user_msg_json = json_loads(message["content"]) + user_msg_str = json_dumps( user_msg_json, indent=self.json_indent, - ensure_ascii=JSON_ENSURE_ASCII, ) except: user_msg_str = message["content"] @@ -193,11 +192,10 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper): prompt = "" try: # indent the function replies - function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT) - function_return_str = json.dumps( + function_return_dict = json_loads(message["content"]) + function_return_str = json_dumps( function_return_dict, indent=self.json_indent, - ensure_ascii=JSON_ENSURE_ASCII, ) except: function_return_str = message["content"] @@ -229,7 +227,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper): if self.use_system_role_in_user: try: - msg_json = json.loads(message["content"], 
strict=JSON_LOADS_STRICT) + msg_json = json_loads(message["content"]) if msg_json["type"] != "user_message": role_str = "system" except: @@ -343,7 +341,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper): "content": inner_thoughts, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py b/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py index e3dee6d2..600320bb 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py @@ -1,6 +1,7 @@ import json -from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT +from memgpt.utils import json_dumps, json_loads + from .wrapper_base import LLMChatCompletionWrapper @@ -85,9 +86,9 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper): """ airo_func_call = { "function": function_call["name"], - "params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + "params": json_loads(function_call["arguments"]), } - return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=2) # Add a sep for the conversation if self.include_section_separators: @@ -100,7 +101,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper): if message["role"] == "user": if self.simplify_json_content: try: - content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + content_json = json_loads(message["content"]) content_simple = content_json["message"] prompt += f"\nUSER: {content_simple}" except: @@ -151,7 +152,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper): "content": raw_llm_output, # "function_call": { # "name": function_name, - # "arguments": json.dumps(function_parameters, 
ensure_ascii=JSON_ENSURE_ASCII), + # "arguments": json_dumps(function_parameters), # }, } return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py b/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py index 4e45b052..924f2288 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/zephyr.py @@ -1,6 +1,7 @@ import json -from ...constants import JSON_ENSURE_ASCII +from memgpt.utils import json_dumps, json_loads + from ...errors import LLMJSONParsingError from ..json_parser import clean_json from .wrapper_base import LLMChatCompletionWrapper @@ -76,9 +77,9 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper): def create_function_call(function_call): airo_func_call = { "function": function_call["name"], - "params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + "params": json_loads(function_call["arguments"]), } - return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=2) for message in messages[1:]: assert message["role"] in ["user", "assistant", "function", "tool"], message @@ -86,7 +87,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper): if message["role"] == "user": if self.simplify_json_content: try: - content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + content_json = json_loads(message["content"]) content_simple = content_json["message"] prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}" # prompt += f"\nUSER: {content_simple}" @@ -171,7 +172,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper): "content": None, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message @@ -239,10 +240,10 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper): "function": function_call["name"], "params": { "inner_thoughts": 
inner_thoughts, - **json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT), + **json_loads(function_call["arguments"]), }, } - return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(airo_func_call, indent=2) # Add a sep for the conversation if self.include_section_separators: @@ -255,7 +256,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper): if message["role"] == "user": if self.simplify_json_content: try: - content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT) + content_json = json_loads(message["content"]) content_simple = content_json["message"] prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}" except: @@ -340,7 +341,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper): "content": inner_thoughts, "function_call": { "name": function_name, - "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps(function_parameters), }, } return message diff --git a/memgpt/local_llm/settings/settings.py b/memgpt/local_llm/settings/settings.py index 975a4e74..b51dad6f 100644 --- a/memgpt/local_llm/settings/settings.py +++ b/memgpt/local_llm/settings/settings.py @@ -1,7 +1,7 @@ import json import os -from memgpt.constants import JSON_ENSURE_ASCII, MEMGPT_DIR +from memgpt.constants import MEMGPT_DIR from memgpt.local_llm.settings.deterministic_mirostat import ( settings as det_miro_settings, ) @@ -48,9 +48,7 @@ def get_completions_settings(defaults="simple") -> dict: with open(settings_file, "r", encoding="utf-8") as file: user_settings = json.load(file) if len(user_settings) > 0: - printd( - f"Updating base settings with the following user settings:\n{json.dumps(user_settings,indent=2, ensure_ascii=JSON_ENSURE_ASCII)}" - ) + printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings,indent=2)}") settings.update(user_settings) else: printd(f"'{settings_file}' was empty, ignoring...") diff --git 
a/memgpt/main.py b/memgpt/main.py index e2121912..2bbc5411 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -19,12 +19,7 @@ from memgpt.cli.cli import delete_agent, open_folder, quickstart, run, server, v from memgpt.cli.cli_config import add, add_tool, configure, delete, list, list_tools from memgpt.cli.cli_load import app as load_app from memgpt.config import MemGPTConfig -from memgpt.constants import ( - FUNC_FAILED_HEARTBEAT_MESSAGE, - JSON_ENSURE_ASCII, - JSON_LOADS_STRICT, - REQ_HEARTBEAT_MESSAGE, -) +from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE from memgpt.metadata import MetadataStore from memgpt.schemas.enums import OptionState @@ -275,14 +270,14 @@ def run_agent_loop( if args_string is None: print("Assistant missing send_message function arguments") break # cancel op - args_json = json.loads(args_string, strict=JSON_LOADS_STRICT) + args_json = json_loads(args_string) if "message" not in args_json: print("Assistant missing send_message message argument") break # cancel op # Once we found our target, rewrite it args_json["message"] = text - new_args_string = json.dumps(args_json, ensure_ascii=JSON_ENSURE_ASCII) + new_args_string = json_dumps(args_json) message_obj.tool_calls[0].function["arguments"] = new_args_string # To persist to the database, all we need to do is "re-insert" into recall memory diff --git a/memgpt/openai_backcompat/openai_object.py b/memgpt/openai_backcompat/openai_object.py index ec157db5..37ffe02f 100644 --- a/memgpt/openai_backcompat/openai_object.py +++ b/memgpt/openai_backcompat/openai_object.py @@ -1,18 +1,10 @@ # https://github.com/openai/openai-python/blob/v0.27.4/openai/openai_object.py -import json from copy import deepcopy from enum import Enum from typing import Optional, Tuple, Union -from memgpt.constants import JSON_ENSURE_ASCII - -# import openai - - -# from openai import api_requestor, util -# from openai.openai_response import OpenAIResponse -# from openai.util import ApiType 
+from memgpt.utils import json_dumps api_requestor = None api_resources = None @@ -342,7 +334,7 @@ class OpenAIObject(dict): def __str__(self): obj = self.to_dict_recursive() - return json.dumps(obj, sort_keys=True, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(obj, sort_keys=True, indent=2) def to_dict(self): return dict(self) diff --git a/memgpt/prompts/gpt_functions.py b/memgpt/prompts/gpt_functions.py index 41af04a0..9c67f6b6 100644 --- a/memgpt/prompts/gpt_functions.py +++ b/memgpt/prompts/gpt_functions.py @@ -1,4 +1,11 @@ -from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, MAX_PAUSE_HEARTBEATS +from ..constants import MAX_PAUSE_HEARTBEATS + +request_heartbeat = { + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.", + } +} # FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1] FUNCTIONS_CHAINING = { @@ -42,12 +49,8 @@ FUNCTIONS_CHAINING = { "message": { "type": "string", "description": "Message to send ChatGPT. Phrase your message as a full English sentence.", - }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + } + }.update(request_heartbeat), "required": ["message", "request_heartbeat"], }, }, @@ -65,11 +68,7 @@ FUNCTIONS_CHAINING = { "type": "string", "description": "Content to write to the memory. All unicode (including emojis) are supported.", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["name", "content", "request_heartbeat"], }, }, @@ -91,11 +90,7 @@ FUNCTIONS_CHAINING = { "type": "string", "description": "Content to write to the memory. 
All unicode (including emojis) are supported.", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["name", "old_content", "new_content", "request_heartbeat"], }, }, @@ -113,11 +108,7 @@ FUNCTIONS_CHAINING = { "type": "integer", "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["query", "page", "request_heartbeat"], }, }, @@ -135,11 +126,7 @@ FUNCTIONS_CHAINING = { "type": "integer", "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["query", "request_heartbeat"], }, }, @@ -161,11 +148,7 @@ FUNCTIONS_CHAINING = { "type": "integer", "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["start_date", "end_date", "page", "request_heartbeat"], }, }, @@ -187,11 +170,7 @@ FUNCTIONS_CHAINING = { "type": "integer", "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["start_date", "end_date", "request_heartbeat"], }, }, @@ -205,11 +184,7 @@ FUNCTIONS_CHAINING = { "type": "string", "description": "Content to write to the memory. 
All unicode (including emojis) are supported.", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["content", "request_heartbeat"], }, }, @@ -227,11 +202,7 @@ FUNCTIONS_CHAINING = { "type": "integer", "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["query", "request_heartbeat"], }, }, @@ -253,11 +224,7 @@ FUNCTIONS_CHAINING = { "type": "integer", "description": "How many lines to read (defaults to 1).", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["filename", "line_start", "request_heartbeat"], }, }, @@ -275,11 +242,7 @@ FUNCTIONS_CHAINING = { "type": "string", "description": "Content to append to the file.", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["filename", "content", "request_heartbeat"], }, }, @@ -301,11 +264,7 @@ FUNCTIONS_CHAINING = { "type": "string", "description": "A JSON string representing the request payload.", }, - "request_heartbeat": { - "type": "boolean", - "description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, - }, - }, + }.update(request_heartbeat), "required": ["method", "url", "request_heartbeat"], }, }, diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index cc42250b..2ba7ceab 100644 --- a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -6,13 +6,13 @@ from typing import List, Optional, Union from pydantic import Field, field_validator -from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN +from memgpt.constants import 
TOOL_CALL_ID_MAX_LEN from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG from memgpt.schemas.enums import MessageRole from memgpt.schemas.memgpt_base import MemGPTBase from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage from memgpt.schemas.openai.chat_completions import ToolCall -from memgpt.utils import get_utc_time, is_utc_datetime +from memgpt.utils import get_utc_time, is_utc_datetime, json_dumps def add_inner_thoughts_to_tool_call( @@ -29,7 +29,7 @@ def add_inner_thoughts_to_tool_call( func_args[inner_thoughts_key] = inner_thoughts # create the updated tool call (as a string) updated_tool_call = copy.deepcopy(tool_call) - updated_tool_call.function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII) + updated_tool_call.function.arguments = json_dumps(func_args) return updated_tool_call except json.JSONDecodeError as e: # TODO: change to logging @@ -517,7 +517,7 @@ class Message(BaseMessage): cohere_message = [] for tc in self.tool_calls: # TODO better way to pack? 
- function_call_text = json.dumps(tc.to_dict(), ensure_ascii=JSON_ENSURE_ASCII) + function_call_text = json_dumps(tc.to_dict()) cohere_message.append( { "role": function_call_role, diff --git a/memgpt/server/rest_api/utils.py b/memgpt/server/rest_api/utils.py index 4f09f6fb..79bfcb16 100644 --- a/memgpt/server/rest_api/utils.py +++ b/memgpt/server/rest_api/utils.py @@ -5,7 +5,11 @@ from typing import AsyncGenerator, Union from pydantic import BaseModel -from memgpt.constants import JSON_ENSURE_ASCII +from memgpt.orm.user import User +from memgpt.orm.utilities import get_db_session +from memgpt.server.rest_api.interface import StreamingServerInterface +from memgpt.server.server import SyncServer +from memgpt.utils import json_dumps SSE_PREFIX = "data: " SSE_SUFFIX = "\n\n" @@ -16,8 +20,28 @@ SSE_ARTIFICIAL_DELAY = 0.1 def sse_formatter(data: Union[dict, str]) -> str: """Prefix with 'data: ', and always include double newlines""" assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}" - data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data - return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}" + data_str = json_dumps(data) if isinstance(data, dict) else data + return f"data: {data_str}\n\n" + + +async def sse_generator(generator: Generator[dict, None, None]) -> Generator[str, None, None]: + """Generator that returns 'data: dict' formatted items, e.g.: + + data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} + + data: [DONE] + + """ + try: + for 
msg in generator: + yield sse_formatter(msg) + if SSE_ARTIFICIAL_DELAY: + await asyncio.sleep(SSE_ARTIFICIAL_DELAY) # Sleep to prevent a tight loop, adjust time as needed + except Exception as e: + yield sse_formatter({"error": f"{str(e)}"}) + yield sse_formatter(SSE_FINISH_MSG) # Signal that the stream is complete async def sse_async_generator(generator: AsyncGenerator, finish_message=True): diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 3e15d747..c0fac7d1 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -20,7 +20,6 @@ from memgpt.agent import Agent, save_agent from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.cli.cli_config import get_model_options from memgpt.config import MemGPTConfig -from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT from memgpt.credentials import MemGPTCredentials from memgpt.data_sources.connectors import DataConnector, load_data @@ -66,7 +65,7 @@ from memgpt.schemas.source import Source, SourceCreate, SourceUpdate from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate from memgpt.schemas.usage import MemGPTUsageStatistics from memgpt.schemas.user import User, UserCreate -from memgpt.utils import create_random_username +from memgpt.utils import create_random_username, json_dumps, json_loads # from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin @@ -557,11 +556,9 @@ class SyncServer(LockingServer): for x in range(len(memgpt_agent.messages) - 1, 0, -1): if memgpt_agent.messages[x].get("role") == "assistant": text = command[len("rewrite ") :].strip() - args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"), strict=JSON_LOADS_STRICT) + args = json_loads(memgpt_agent.messages[x].get("function_call").get("arguments")) args["message"] = text - memgpt_agent.messages[x].get("function_call").update( - {"arguments": json.dumps(args, ensure_ascii=JSON_ENSURE_ASCII)} - ) + 
memgpt_agent.messages[x].get("function_call").update({"arguments": json_dumps(args)}) break # No skip options diff --git a/memgpt/server/ws_api/example_client.py b/memgpt/server/ws_api/example_client.py index eb38cab8..377dc6ff 100644 --- a/memgpt/server/ws_api/example_client.py +++ b/memgpt/server/ws_api/example_client.py @@ -4,7 +4,6 @@ import json import websockets import memgpt.server.ws_api.protocol as protocol -from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT from memgpt.server.constants import WS_CLIENT_TIMEOUT, WS_DEFAULT_PORT from memgpt.server.utils import condition_to_stop_receiving, print_server_response @@ -27,12 +26,12 @@ async def send_message_and_print_replies(websocket, user_message, agent_id): # Wait for messages in a loop, since the server may send a few while True: response = await asyncio.wait_for(websocket.recv(), WS_CLIENT_TIMEOUT) - response = json.loads(response, strict=JSON_LOADS_STRICT) + response = json_loads(response) if CLEAN_RESPONSES: print_server_response(response) else: - print(f"Server response:\n{json.dumps(response, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}") + print(f"Server response:\n{json_dumps(response, indent=2)}") # Check for a specific condition to break the loop if condition_to_stop_receiving(response): @@ -62,8 +61,8 @@ async def basic_cli_client(): await websocket.send(protocol.client_command_create(example_config)) # Wait for the response response = await websocket.recv() - response = json.loads(response, strict=JSON_LOADS_STRICT) - print(f"Server response:\n{json.dumps(response, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}") + response = json_loads(response) + print(f"Server response:\n{json_dumps(response, indent=2)}") await asyncio.sleep(1) diff --git a/memgpt/server/ws_api/protocol.py b/memgpt/server/ws_api/protocol.py index fdeea3b9..0ee0a52d 100644 --- a/memgpt/server/ws_api/protocol.py +++ b/memgpt/server/ws_api/protocol.py @@ -1,89 +1,81 @@ import json -from memgpt.constants import 
JSON_ENSURE_ASCII +from memgpt.utils import json_dumps # Server -> client def server_error(msg): """General server error""" - return json.dumps( + return json_dumps( { "type": "server_error", "message": msg, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_command_response(status): - return json.dumps( + return json_dumps( { "type": "command_response", "status": status, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_agent_response_error(msg): - return json.dumps( + return json_dumps( { "type": "agent_response_error", "message": msg, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_agent_response_start(): - return json.dumps( + return json_dumps( { "type": "agent_response_start", - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_agent_response_end(): - return json.dumps( + return json_dumps( { "type": "agent_response_end", - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_agent_internal_monologue(msg): - return json.dumps( + return json_dumps( { "type": "agent_response", "message_type": "internal_monologue", "message": msg, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_agent_assistant_message(msg): - return json.dumps( + return json_dumps( { "type": "agent_response", "message_type": "assistant_message", "message": msg, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def server_agent_function_message(msg): - return json.dumps( + return json_dumps( { "type": "agent_response", "message_type": "function_message", "message": msg, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) @@ -91,22 +83,20 @@ def server_agent_function_message(msg): def client_user_message(msg, agent_id=None): - return json.dumps( + return json_dumps( { "type": "user_message", "message": msg, "agent_id": agent_id, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) def client_command_create(config): - return json.dumps( + return json_dumps( { "type": "command", "command": "create_agent", "config": config, - }, - ensure_ascii=JSON_ENSURE_ASCII, + } ) diff --git 
a/memgpt/server/ws_api/server.py b/memgpt/server/ws_api/server.py index 6ee9fcb8..833d103f 100644 --- a/memgpt/server/ws_api/server.py +++ b/memgpt/server/ws_api/server.py @@ -55,7 +55,7 @@ class WebSocketServer: # Assuming the message is a JSON string try: - data = json.loads(message, strict=JSON_LOADS_STRICT) + data = json_loads(message) except: print(f"[server] bad data from client:\n{data}") await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}")) diff --git a/memgpt/system.py b/memgpt/system.py index dcfefdca..d903bf1f 100644 --- a/memgpt/system.py +++ b/memgpt/system.py @@ -6,10 +6,9 @@ from .constants import ( INITIAL_BOOT_MESSAGE, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG, INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT, - JSON_ENSURE_ASCII, MESSAGE_SUMMARY_WARNING_STR, ) -from .utils import get_local_time +from .utils import get_local_time, json_dumps def get_initial_boot_messages(version="startup"): @@ -98,7 +97,7 @@ def get_heartbeat(reason="Automated timer", include_location=False, location_nam if include_location: packaged_message["location"] = location_name - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"): @@ -113,7 +112,7 @@ def get_login_event(last_login="Never (first login)", include_location=False, lo if include_location: packaged_message["location"] = location_name - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) def package_user_message( @@ -137,7 +136,7 @@ def package_user_message( if name: packaged_message["name"] = name - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) def package_function_response(was_success, response_string, timestamp=None): @@ -148,7 +147,7 @@ def package_function_response(was_success, 
response_string, timestamp=None): "time": formatted_time, } - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) def package_system_message(system_message, message_type="system_alert", time=None): @@ -175,7 +174,7 @@ def package_summarize_message(summary, summary_length, hidden_message_count, tot "time": formatted_time, } - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) def package_summarize_message_no_summary(hidden_message_count, timestamp=None, message=None): @@ -194,7 +193,7 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m "time": formatted_time, } - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) def get_token_limit_warning(): @@ -205,4 +204,4 @@ def get_token_limit_warning(): "time": formatted_time, } - return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII) + return json_dumps(packaged_message) diff --git a/memgpt/utils.py b/memgpt/utils.py index 433b6b62..b1bb7fbd 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -28,12 +28,9 @@ from memgpt.constants import ( CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, FUNCTION_RETURN_CHAR_LIMIT, - JSON_ENSURE_ASCII, - JSON_LOADS_STRICT, MEMGPT_DIR, TOOL_CALL_ID_MAX_LEN, ) -from memgpt.openai_backcompat.openai_object import OpenAIObject from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse DEBUG = False @@ -782,6 +779,8 @@ def open_folder_in_explorer(folder_path): class OpenAIBackcompatUnpickler(pickle.Unpickler): def find_class(self, module, name): if module == "openai.openai_object": + from memgpt.openai_backcompat.openai_object import OpenAIObject + return OpenAIObject return super().find_class(module, name) @@ -873,7 +872,7 @@ def parse_json(string) -> dict: """Parse JSON string into JSON with both json and demjson""" result = None try: - result = 
json.loads(string, strict=JSON_LOADS_STRICT) + result = json_loads(string) return result except Exception as e: print(f"Error parsing json with json package: {e}") @@ -906,7 +905,7 @@ def validate_function_response(function_response_string: any, strict: bool = Fal # Allow dict through since it will be cast to json.dumps() try: # TODO find a better way to do this that won't result in double escapes - function_response_string = json.dumps(function_response_string, ensure_ascii=JSON_ENSURE_ASCII) + function_response_string = json_dumps(function_response_string) except: raise ValueError(function_response_string) @@ -1020,8 +1019,8 @@ def get_human_text(name: str): def get_schema_diff(schema_a, schema_b): # Assuming f_schema and linked_function['json_schema'] are your JSON schemas - f_schema_json = json.dumps(schema_a, indent=2, ensure_ascii=JSON_ENSURE_ASCII) - linked_function_json = json.dumps(schema_b, indent=2, ensure_ascii=JSON_ENSURE_ASCII) + f_schema_json = json_dumps(schema_a) + linked_function_json = json_dumps(schema_b) # Compute the difference using difflib difference = list(difflib.ndiff(f_schema_json.splitlines(keepends=True), linked_function_json.splitlines(keepends=True))) @@ -1056,3 +1055,11 @@ def create_uuid_from_string(val: str): """ hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest() return uuid.UUID(hex=hex_string) + + +def json_dumps(data, indent=2): + return json.dumps(data, indent=indent, ensure_ascii=False) + + +def json_loads(data): + return json.loads(data, strict=False) diff --git a/paper_experiments/nested_kv_task/nested_kv.py b/paper_experiments/nested_kv_task/nested_kv.py index 34d9ef8c..f2cfe6eb 100644 --- a/paper_experiments/nested_kv_task/nested_kv.py +++ b/paper_experiments/nested_kv_task/nested_kv.py @@ -34,7 +34,6 @@ from tqdm import tqdm from memgpt import MemGPT, utils from memgpt.cli.cli_config import delete from memgpt.config import MemGPTConfig -from memgpt.constants import JSON_ENSURE_ASCII # TODO: update personas 
NESTED_PERSONA = "You are MemGPT DOC-QA bot. Your job is to answer questions about documents that are stored in your archival memory. The answer to the users question will ALWAYS be in your archival memory, so remember to keep searching if you can't find the answer. DO NOT STOP SEARCHING UNTIL YOU VERIFY THAT THE VALUE IS NOT A KEY. Do not stop making nested lookups until this condition is met." # TODO decide on a good persona/human @@ -71,7 +70,7 @@ def archival_memory_text_search(self, query: str, page: Optional[int] = 0) -> Op else: results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" results_formatted = [f"memory: {d.text}" for d in results] - results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}" + results_str = f"{results_pref} {utils.json_dumps(results_formatted)}" return results_str diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py new file mode 100644 index 00000000..524966e9 --- /dev/null +++ b/tests/test_agent_function_update.py @@ -0,0 +1,121 @@ +import inspect +import os +import uuid + +import pytest + +from memgpt import constants, create_client +from memgpt.functions.functions import USER_FUNCTIONS_DIR +from memgpt.schemas.message import Message +from memgpt.settings import settings +from memgpt.utils import assistant_function_to_tool, json_dumps, json_loads +from tests.config import TEST_MEMGPT_CONFIG +from tests.utils import create_config, wipe_config + + +def hello_world(self) -> str: + """Test function for agent to gain access to + + Returns: + str: A message for the world + """ + return "hello, world!"
+ + +@pytest.fixture(scope="module") +def agent(): + """Create a test agent that we can call functions on""" + wipe_config() + global client + if os.getenv("OPENAI_API_KEY"): + create_config("openai") + else: + create_config("memgpt_hosted") + + # create memgpt client + client = create_client() + + # ensure user exists + user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) + if not client.server.get_user(user_id=user_id): + client.server.create_user({"id": user_id}) + + agent_state = client.create_agent( + preset=settings.preset, + ) + + return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) + + +@pytest.fixture(scope="module") +def hello_world_function(): + with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f: + f.write(inspect.getsource(hello_world)) + + +@pytest.fixture(scope="module") +def ai_function_call(): + return Message( + **assistant_function_to_tool( + { + "role": "assistant", + "content": "I will now call hello world", + "function_call": { + "name": "hello_world", + "arguments": json_dumps({}), + }, + } + ) + ) + + +def test_add_function_happy(agent, hello_world_function, ai_function_call): + agent.add_function("hello_world") + + assert "hello_world" in [f_schema["name"] for f_schema in agent.functions] + assert "hello_world" in agent.functions_python.keys() + + msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call) + content = json_loads(msgs[-1].to_openai_dict()["content"]) + assert content["message"] == "hello, world!" 
+ assert content["status"] == "OK" + assert not function_failed + + +def test_add_function_already_loaded(agent, hello_world_function): + agent.add_function("hello_world") + # no exception for duplicate loading + agent.add_function("hello_world") + + +def test_add_function_not_exist(agent): + # pytest assert exception + with pytest.raises(ValueError): + agent.add_function("non_existent") + + +def test_remove_function_happy(agent, hello_world_function): + agent.add_function("hello_world") + + # ensure function is loaded + assert "hello_world" in [f_schema["name"] for f_schema in agent.functions] + assert "hello_world" in agent.functions_python.keys() + + agent.remove_function("hello_world") + + assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions] + assert "hello_world" not in agent.functions_python.keys() + + +def test_remove_function_not_exist(agent): + # do not raise error + agent.remove_function("non_existent") + + +def test_remove_base_function_fails(agent): + with pytest.raises(ValueError): + agent.remove_function("send_message") + + +if __name__ == "__main__": + pytest.main(["-vv", os.path.abspath(__file__)]) diff --git a/tests/test_function_parser.py b/tests/test_function_parser.py index 146b29ad..eebc1074 100644 --- a/tests/test_function_parser.py +++ b/tests/test_function_parser.py @@ -1,8 +1,8 @@ import json import memgpt.system as system -from memgpt.constants import JSON_ENSURE_ASCII from memgpt.local_llm.function_parser import patch_function +from memgpt.utils import json_dumps EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = { "message_history": [ @@ -32,7 +32,7 @@ EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = { "content": "I'll append to memory.", "function_call": { "name": "core_memory_append", - "arguments": json.dumps({"content": "new_stuff"}, ensure_ascii=JSON_ENSURE_ASCII), + "arguments": json_dumps({"content": "new_stuff"}), }, }, } diff --git a/tests/test_json_parsers.py b/tests/test_json_parsers.py index 35ee19aa..d04aa5d6 
100644 --- a/tests/test_json_parsers.py +++ b/tests/test_json_parsers.py @@ -1,7 +1,6 @@ -import json import memgpt.local_llm.json_parser as json_parser -from memgpt.constants import JSON_LOADS_STRICT +from memgpt.utils import json_loads EXAMPLE_ESCAPED_UNDERSCORES = """{ "function":"send\_message", @@ -90,7 +90,7 @@ def test_json_parsers(): for string in test_strings: try: - json.loads(string, strict=JSON_LOADS_STRICT) + json_loads(string) assert False, f"Test JSON string should have failed basic JSON parsing:\n{string}" except: print("String failed (expectedly)") diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py index 4dfff052..ab20c259 100644 --- a/tests/test_websocket_server.py +++ b/tests/test_websocket_server.py @@ -1,12 +1,11 @@ import asyncio -import json import pytest import websockets -from memgpt.constants import JSON_ENSURE_ASCII from memgpt.server.constants import WS_DEFAULT_PORT from memgpt.server.ws_api.server import WebSocketServer +from memgpt.utils import json_dumps @pytest.mark.asyncio @@ -36,7 +35,7 @@ async def test_websocket_server(): async with websockets.connect(uri) as websocket: # Initialize the server with a test config print("Sending config to server...") - await websocket.send(json.dumps({"type": "initialize", "config": test_config}, ensure_ascii=JSON_ENSURE_ASCII)) + await websocket.send(json_dumps({"type": "initialize", "config": test_config})) # Wait for the response response = await websocket.recv() print(f"Response from the agent: {response}") @@ -45,7 +44,7 @@ # Send a message to the agent print("Sending message to server...") - await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}, ensure_ascii=JSON_ENSURE_ASCII)) + await websocket.send(json_dumps({"type": "message", "content": "Hello, Agent!"})) # Wait for the response # NOTE: we should be waiting for multiple responses response = await websocket.recv()