refactor: Remove JSON constant for common method (#1680)

This commit is contained in:
Ethan Knox
2024-08-26 16:47:41 -04:00
committed by GitHub
parent 7f589eaf63
commit 0491a8bbe3
39 changed files with 368 additions and 320 deletions

View File

@@ -11,8 +11,6 @@ from memgpt.constants import (
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
IN_CONTEXT_MEMORY_KEYWORD,
JSON_ENSURE_ASCII,
JSON_LOADS_STRICT,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
@@ -44,6 +42,8 @@ from memgpt.utils import (
get_tool_call_id,
get_utc_time,
is_utc_datetime,
json_dumps,
json_loads,
parse_json,
printd,
united_diff,
@@ -654,12 +654,12 @@ class Agent(object):
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
"""If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
try:
user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
user_message_json = dict(json_loads(user_message_text))
# Special handling for AutoGen messages with 'name' field
# Treat 'name' as a special field
# If it exists in the input message, elevate it to the 'message' level
name = user_message_json.pop("name", None)
clean_message = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
clean_message = json_dumps(user_message_json)
except Exception as e:
print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
@@ -668,8 +668,8 @@ class Agent(object):
def validate_json(user_message_text: str, raise_on_error: bool) -> str:
try:
user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
user_message_json_val = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
user_message_json = dict(json_loads(user_message_text))
user_message_json_val = json_dumps(user_message_json)
return user_message_json_val
except Exception as e:
print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")

View File

@@ -4,7 +4,7 @@ from typing import Optional
from colorama import Fore, Style, init
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.schemas.message import Message
init(autoreset=True)
@@ -113,7 +113,7 @@ class AutoGenInterface(object):
return
else:
try:
msg_json = json.loads(msg, strict=JSON_LOADS_STRICT)
msg_json = json_loads(msg)
except:
print(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
@@ -203,7 +203,7 @@ class AutoGenInterface(object):
)
else:
try:
msg_dict = json.loads(msg, strict=JSON_LOADS_STRICT)
msg_dict = json_loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
message = (
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"

View File

@@ -114,14 +114,9 @@ REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function called using request_hea
# FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed"
FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, returning control"
FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat"
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean"
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function."
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
# GLOBAL SETTINGS FOR `json.dumps()`
JSON_ENSURE_ASCII = False
# GLOBAL SETTINGS FOR `json.loads()`
JSON_LOADS_STRICT = False
# TODO Is this config or constant?
CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000
CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000

View File

@@ -4,11 +4,7 @@ import math
from typing import Optional
from memgpt.agent import Agent
from memgpt.constants import (
JSON_ENSURE_ASCII,
MAX_PAUSE_HEARTBEATS,
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
)
from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
### Functions / tools the agent can use
# All functions should return a response string (or None)
@@ -81,7 +77,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str
@@ -111,7 +107,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page:
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str
@@ -154,5 +150,5 @@ def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) ->
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str

View File

@@ -6,8 +6,6 @@ from typing import Optional
import requests
from memgpt.constants import (
JSON_ENSURE_ASCII,
JSON_LOADS_STRICT,
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
)
@@ -123,10 +121,10 @@ def http_request(self, method: str, url: str, payload_json: Optional[str] = None
else:
# Validate and convert the payload for other types of requests
if payload_json:
payload = json.loads(payload_json, strict=JSON_LOADS_STRICT)
payload = json_loads(payload_json)
else:
payload = {}
print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}")
print(f"[HTTP] launching {method} request to {url}, payload=\n{json_dumps(payload, indent=2)}")
response = requests.request(method, url, json=payload, headers=headers)
return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text}

View File

@@ -5,14 +5,6 @@ from typing import Any, Dict, Optional, Type, get_args, get_origin
from docstring_parser import parse
from pydantic import BaseModel
from memgpt.constants import (
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
FUNCTION_PARAM_NAME_REQ_HEARTBEAT,
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT,
)
NO_HEARTBEAT_FUNCTIONS = ["send_message", "pause_heartbeats"]
def is_optional(annotation):
# Check if the annotation is a Union
@@ -136,12 +128,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
schema["parameters"]["required"].append(param.name)
# append the heartbeat
if function.__name__ not in NO_HEARTBEAT_FUNCTIONS:
schema["parameters"]["properties"][FUNCTION_PARAM_NAME_REQ_HEARTBEAT] = {
"type": FUNCTION_PARAM_TYPE_REQ_HEARTBEAT,
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
if function.__name__ not in ["send_message", "pause_heartbeats"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
}
schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)
schema["parameters"]["required"].append("request_heartbeat")
return schema

View File

@@ -5,9 +5,9 @@ from typing import List, Optional
from colorama import Fore, Style, init
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.schemas.message import Message
from memgpt.utils import printd
from memgpt.utils import json_loads, printd
init(autoreset=True)
@@ -127,7 +127,7 @@ class CLIInterface(AgentInterface):
return
else:
try:
msg_json = json.loads(msg, strict=JSON_LOADS_STRICT)
msg_json = json_loads(msg)
except:
printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
printd_user_message("🧑", msg)
@@ -228,7 +228,7 @@ class CLIInterface(AgentInterface):
printd_function_message("", msg)
else:
try:
msg_dict = json.loads(msg, strict=JSON_LOADS_STRICT)
msg_dict = json_loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
printd_function_message("", str(msg), color=Fore.GREEN)
else:
@@ -259,7 +259,7 @@ class CLIInterface(AgentInterface):
CLIInterface.internal_monologue(content)
# I think the next one is not up to date
# function_message(msg["function_call"])
args = json.loads(msg["function_call"].get("arguments"), strict=JSON_LOADS_STRICT)
args = json_loads(msg["function_call"].get("arguments"))
CLIInterface.assistant_message(args.get("message"))
# assistant_message(content)
elif msg.get("tool_calls"):
@@ -267,7 +267,7 @@ class CLIInterface(AgentInterface):
CLIInterface.internal_monologue(content)
function_obj = msg["tool_calls"][0].get("function")
if function_obj:
args = json.loads(function_obj.get("arguments"), strict=JSON_LOADS_STRICT)
args = json_loads(function_obj.get("arguments"))
CLIInterface.assistant_message(args.get("message"))
else:
CLIInterface.internal_monologue(content)

View File

@@ -5,7 +5,6 @@ from typing import List, Optional, Union
import requests
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.schemas.openai.chat_completion_response import (
@@ -257,7 +256,7 @@ def convert_anthropic_response_to_chatcompletion(
type="function",
function=FunctionCall(
name=response_json["content"][1]["name"],
arguments=json.dumps(response_json["content"][1]["input"], ensure_ascii=JSON_ENSURE_ASCII),
arguments=json_dumps(response_json["content"][1]["input"]),
),
)
]

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Union
import requests
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.local_llm.utils import count_tokens
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@@ -17,7 +16,7 @@ from memgpt.schemas.openai.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own MemGPT Message datatype
)
from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from memgpt.utils import get_tool_call_id, get_utc_time, smart_urljoin
from memgpt.utils import get_tool_call_id, get_utc_time, json_dumps, smart_urljoin
BASE_URL = "https://api.cohere.ai/v1"
@@ -155,9 +154,7 @@ def convert_cohere_response_to_chatcompletion(
completion_tokens = response_json["meta"]["billed_units"]["output_tokens"]
else:
# For some reason input_tokens not included in 'meta' 'tokens' dict?
prompt_tokens = count_tokens(
json.dumps(response_json["chat_history"], ensure_ascii=JSON_ENSURE_ASCII)
) # NOTE: this is a very rough approximation
prompt_tokens = count_tokens(json_dumps(response_json["chat_history"])) # NOTE: this is a very rough approximation
completion_tokens = response_json["meta"]["tokens"]["output_tokens"]
finish_reason = remap_finish_reason(response_json["finish_reason"])

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
import requests
from memgpt.constants import JSON_ENSURE_ASCII, NON_USER_MSG_PREFIX
from memgpt.constants import NON_USER_MSG_PREFIX
from memgpt.local_llm.json_parser import clean_json_string_extra_backslash
from memgpt.local_llm.utils import count_tokens
from memgpt.schemas.openai.chat_completion_request import Tool
@@ -310,7 +310,7 @@ def convert_google_ai_response_to_chatcompletion(
type="function",
function=FunctionCall(
name=function_name,
arguments=clean_json_string_extra_backslash(json.dumps(function_args, ensure_ascii=JSON_ENSURE_ASCII)),
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
),
)
],
@@ -374,12 +374,8 @@ def convert_google_ai_response_to_chatcompletion(
else:
# Count it ourselves
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
prompt_tokens = count_tokens(
json.dumps(input_messages, ensure_ascii=JSON_ENSURE_ASCII)
) # NOTE: this is a very rough approximation
completion_tokens = count_tokens(
json.dumps(openai_response_message.model_dump(), ensure_ascii=JSON_ENSURE_ASCII)
) # NOTE: this is also approximate
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
total_tokens = prompt_tokens + completion_tokens
usage = UsageStatistics(
prompt_tokens=prompt_tokens,

View File

@@ -9,7 +9,7 @@ from typing import List, Optional, Union
import requests
from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.credentials import MemGPTCredentials
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
from memgpt.llm_api.azure_openai import (
@@ -109,7 +109,7 @@ def unpack_inner_thoughts_from_kwargs(
# replace the kwargs
new_choice = choice.model_copy(deep=True)
new_choice.message.tool_calls[0].function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
new_choice.message.tool_calls[0].function.arguments = json_dumps(func_args)
# also replace the message content
if new_choice.message.content is not None:
warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})")

View File

@@ -1,11 +1,10 @@
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
import json
import uuid
import requests
from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.errors import LocalLLMConnectionError, LocalLLMError
from memgpt.local_llm.constants import DEFAULT_WRAPPER
from memgpt.local_llm.function_parser import patch_function
@@ -33,7 +32,7 @@ from memgpt.schemas.openai.chat_completion_response import (
ToolCall,
UsageStatistics,
)
from memgpt.utils import get_tool_call_id, get_utc_time
from memgpt.utils import get_tool_call_id, get_utc_time, json_dumps
has_shown_warning = False
grammar_supported_backends = ["koboldcpp", "llamacpp", "webui", "webui-legacy"]
@@ -189,7 +188,7 @@ def get_chat_completion(
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result, first_message=first_message)
else:
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
printd(json.dumps(chat_completion_result, indent=2, ensure_ascii=JSON_ENSURE_ASCII))
printd(json_dumps(chat_completion_result, indent=2))
except Exception as e:
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")
@@ -206,13 +205,13 @@ def get_chat_completion(
usage["prompt_tokens"] = count_tokens(prompt)
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result, ensure_ascii=JSON_ENSURE_ASCII))
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
"""
if usage["completion_tokens"] is None:
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
# chat_completion_result is dict with 'role' and 'content'
# token counter wants a string
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result, ensure_ascii=JSON_ENSURE_ASCII))
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
"""
# NOTE: this is the token count that matters most

View File

@@ -1,7 +1,7 @@
import copy
import json
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.utils import json_dumps, json_loads
NO_HEARTBEAT_FUNCS = ["send_message", "pause_heartbeats"]
@@ -13,16 +13,16 @@ def insert_heartbeat(message):
if message_copy.get("function_call"):
# function_name = message.get("function_call").get("name")
params = message_copy.get("function_call").get("arguments")
params = json.loads(params, strict=JSON_LOADS_STRICT)
params = json_loads(params)
params["request_heartbeat"] = True
message_copy["function_call"]["arguments"] = json.dumps(params, ensure_ascii=JSON_ENSURE_ASCII)
message_copy["function_call"]["arguments"] = json_dumps(params)
elif message_copy.get("tool_call"):
# function_name = message.get("tool_calls")[0].get("function").get("name")
params = message_copy.get("tool_calls")[0].get("function").get("arguments")
params = json.loads(params, strict=JSON_LOADS_STRICT)
params = json_loads(params)
params["request_heartbeat"] = True
message_copy["tools_calls"][0]["function"]["arguments"] = json.dumps(params, ensure_ascii=JSON_ENSURE_ASCII)
message_copy["tools_calls"][0]["function"]["arguments"] = json_dumps(params)
return message_copy
@@ -40,7 +40,7 @@ def heartbeat_correction(message_history, new_message):
last_message_was_user = False
if message_history[-1]["role"] == "user":
try:
content = json.loads(message_history[-1]["content"], strict=JSON_LOADS_STRICT)
content = json_loads(message_history[-1]["content"])
except json.JSONDecodeError:
return None
# Check if it's a user message or system message

View File

@@ -21,7 +21,7 @@ from typing import (
from docstring_parser import parse
from pydantic import BaseModel, create_model
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.utils import json_dumps, json_loads
class PydanticDataType(Enum):
@@ -731,7 +731,7 @@ def generate_markdown_documentation(
if hasattr(model, "Config") and hasattr(model.Config, "json_schema_extra") and "example" in model.Config.json_schema_extra:
documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n"
json_example = json.dumps(model.Config.json_schema_extra["example"], ensure_ascii=JSON_ENSURE_ASCII)
json_example = json_dumps(model.Config.json_schema_extra["example"])
documentation += format_multiline_description(json_example, 2) + "\n"
return documentation

View File

@@ -1,8 +1,8 @@
import json
import re
from memgpt.constants import JSON_LOADS_STRICT
from memgpt.errors import LLMJSONParsingError
from memgpt.utils import json_loads
def clean_json_string_extra_backslash(s):
@@ -45,7 +45,7 @@ def extract_first_json(string: str):
depth -= 1
if depth == 0 and start_index is not None:
try:
return json.loads(string[start_index : i + 1], strict=JSON_LOADS_STRICT)
return json_loads(string[start_index : i + 1])
except json.JSONDecodeError as e:
raise LLMJSONParsingError(f"Matched closing bracket, but decode failed with error: {str(e)}")
printd("No valid JSON object found.")
@@ -174,21 +174,21 @@ def clean_json(raw_llm_output, messages=None, functions=None):
from memgpt.utils import printd
strategies = [
lambda output: json.loads(output, strict=JSON_LOADS_STRICT),
lambda output: json.loads(output + "}", strict=JSON_LOADS_STRICT),
lambda output: json.loads(output + "}}", strict=JSON_LOADS_STRICT),
lambda output: json.loads(output + '"}}', strict=JSON_LOADS_STRICT),
lambda output: json_loads(output),
lambda output: json_loads(output + "}"),
lambda output: json_loads(output + "}}"),
lambda output: json_loads(output + '"}}'),
# with strip and strip comma
lambda output: json.loads(output.strip().rstrip(",") + "}", strict=JSON_LOADS_STRICT),
lambda output: json.loads(output.strip().rstrip(",") + "}}", strict=JSON_LOADS_STRICT),
lambda output: json.loads(output.strip().rstrip(",") + '"}}', strict=JSON_LOADS_STRICT),
lambda output: json_loads(output.strip().rstrip(",") + "}"),
lambda output: json_loads(output.strip().rstrip(",") + "}}"),
lambda output: json_loads(output.strip().rstrip(",") + '"}}'),
# more complex patchers
lambda output: json.loads(repair_json_string(output), strict=JSON_LOADS_STRICT),
lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT),
lambda output: json_loads(repair_json_string(output)),
lambda output: json_loads(repair_even_worse_json(output)),
lambda output: extract_first_json(output + "}}"),
lambda output: clean_and_interpret_send_message_json(output),
# replace underscores
lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT),
lambda output: json_loads(replace_escaped_underscores(output)),
lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"),
]

View File

@@ -1,6 +1,5 @@
import json
from memgpt.utils import json_dumps, json_loads
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from ...errors import LLMJSONParsingError
from ..json_parser import clean_json
from .wrapper_base import LLMChatCompletionWrapper
@@ -114,9 +113,9 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
"""
airo_func_call = {
"function": function_call["name"],
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
"params": json_loads(function_call["arguments"]),
}
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=2)
# Add a sep for the conversation
if self.include_section_separators:
@@ -129,7 +128,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
if message["role"] == "user":
if self.simplify_json_content:
try:
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
content_json = json_loads(message["content"])
content_simple = content_json["message"]
prompt += f"\nUSER: {content_simple}"
except:
@@ -206,7 +205,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
"content": None,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message
@@ -325,10 +324,10 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
"function": function_call["name"],
"params": {
"inner_thoughts": inner_thoughts,
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
**json_loads(function_call["arguments"]),
},
}
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=2)
# Add a sep for the conversation
if self.include_section_separators:
@@ -347,7 +346,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
user_prefix = "USER"
if self.simplify_json_content:
try:
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
content_json = json_loads(message["content"])
content_simple = content_json["message"]
prompt += f"\n{user_prefix}: {content_simple}"
except:
@@ -447,7 +446,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
"content": inner_thoughts,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message

View File

@@ -1,11 +1,9 @@
import json
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.errors import LLMJSONParsingError
from memgpt.local_llm.json_parser import clean_json
from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import (
LLMChatCompletionWrapper,
)
from memgpt.utils import json_dumps, json_loads
PREFIX_HINT = """# Reminders:
# Important information about yourself and the user is stored in (limited) core memory
@@ -137,10 +135,10 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
"function": function_call["name"],
"params": {
"inner_thoughts": inner_thoughts,
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
**json_loads(function_call["arguments"]),
},
}
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=self.json_indent)
# NOTE: BOS/EOS chatml tokens are NOT inserted here
def _compile_assistant_message(self, message) -> str:
@@ -167,15 +165,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
if self.simplify_json_content:
# Make user messages not JSON but plaintext instead
try:
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
user_msg_json = json_loads(message["content"])
user_msg_str = user_msg_json["message"]
except:
user_msg_str = message["content"]
else:
# Otherwise just dump the full json
try:
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
user_msg_str = json.dumps(user_msg_json, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
user_msg_json = json_loads(message["content"])
user_msg_str = json_dumps(user_msg_json, indent=self.json_indent)
except:
user_msg_str = message["content"]
@@ -189,8 +187,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
prompt = ""
try:
# indent the function replies
function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT)
function_return_str = json.dumps(function_return_dict, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
function_return_dict = json_loads(message["content"])
function_return_str = json_dumps(function_return_dict, indent=self.json_indent)
except:
function_return_str = message["content"]
@@ -219,7 +217,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
if self.use_system_role_in_user:
try:
msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
msg_json = json_loads(message["content"])
if msg_json["type"] != "user_message":
role_str = "system"
except:
@@ -329,7 +327,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
"content": inner_thoughts,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message
@@ -394,10 +392,10 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
"function": function_call["name"],
"params": {
# "inner_thoughts": inner_thoughts,
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
**json_loads(function_call["arguments"]),
},
}
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=self.json_indent)
def output_to_chat_completion_response(self, raw_llm_output, first_message=False):
"""NOTE: Modified to expect "inner_thoughts" outside the function
@@ -458,7 +456,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
"content": inner_thoughts,
# "function_call": {
# "name": function_name,
# "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
# "arguments": json_dumps(function_parameters),
# },
}
@@ -466,7 +464,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
if function_name is not None:
message["function_call"] = {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
}
return message

View File

@@ -1,8 +1,7 @@
import json
import yaml
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.utils import json_dumps, json_loads
from ...errors import LLMJSONParsingError
from ..json_parser import clean_json
from .wrapper_base import LLMChatCompletionWrapper
@@ -131,10 +130,10 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
"function": function_call["name"],
"params": {
"inner_thoughts": inner_thoughts,
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
**json_loads(function_call["arguments"]),
},
}
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=self.json_indent)
# NOTE: BOS/EOS chatml tokens are NOT inserted here
def _compile_assistant_message(self, message) -> str:
@@ -161,15 +160,15 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
if self.simplify_json_content:
# Make user messages not JSON but plaintext instead
try:
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
user_msg_json = json_loads(message["content"])
user_msg_str = user_msg_json["message"]
except:
user_msg_str = message["content"]
else:
# Otherwise just dump the full json
try:
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
user_msg_str = json.dumps(user_msg_json, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
user_msg_json = json_loads(message["content"])
user_msg_str = json_dumps(user_msg_json, indent=self.json_indent)
except:
user_msg_str = message["content"]
@@ -183,8 +182,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
prompt = ""
try:
# indent the function replies
function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT)
function_return_str = json.dumps(function_return_dict, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
function_return_dict = json_loads(message["content"])
function_return_str = json_dumps(function_return_dict, indent=self.json_indent)
except:
function_return_str = message["content"]
@@ -309,7 +308,7 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
"content": inner_thoughts,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message

View File

@@ -1,6 +1,7 @@
import json
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.utils import json_dumps, json_loads
from ...errors import LLMJSONParsingError
from ..json_parser import clean_json
from .wrapper_base import LLMChatCompletionWrapper
@@ -127,9 +128,9 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
"""
airo_func_call = {
"function": function_call["name"],
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
"params": json_loads(function_call["arguments"]),
}
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=2)
# option (1): from HF README:
# <|im_start|>user
@@ -156,7 +157,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
if message["role"] == "user":
if self.simplify_json_content:
try:
content_json = (json.loads(message["content"], strict=JSON_LOADS_STRICT),)
content_json = (json_loads(message["content"]),)
content_simple = content_json["message"]
prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
# prompt += f"\nUSER: {content_simple}"
@@ -241,7 +242,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
"content": None,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message

View File

@@ -1,11 +1,11 @@
import json
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.errors import LLMJSONParsingError
from memgpt.local_llm.json_parser import clean_json
from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import (
LLMChatCompletionWrapper,
)
from memgpt.utils import json_dumps, json_loads
PREFIX_HINT = """# Reminders:
# Important information about yourself and the user is stored in (limited) core memory
@@ -137,10 +137,10 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
"function": function_call["name"],
"params": {
"inner_thoughts": inner_thoughts,
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
**json_loads(function_call["arguments"]),
},
}
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=self.json_indent)
# NOTE: BOS/EOS chatml tokens are NOT inserted here
def _compile_assistant_message(self, message) -> str:
@@ -167,18 +167,17 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
if self.simplify_json_content:
# Make user messages not JSON but plaintext instead
try:
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
user_msg_json = json_loads(message["content"])
user_msg_str = user_msg_json["message"]
except:
user_msg_str = message["content"]
else:
# Otherwise just dump the full json
try:
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
user_msg_str = json.dumps(
user_msg_json = json_loads(message["content"])
user_msg_str = json_dumps(
user_msg_json,
indent=self.json_indent,
ensure_ascii=JSON_ENSURE_ASCII,
)
except:
user_msg_str = message["content"]
@@ -193,11 +192,10 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
prompt = ""
try:
# indent the function replies
function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT)
function_return_str = json.dumps(
function_return_dict = json_loads(message["content"])
function_return_str = json_dumps(
function_return_dict,
indent=self.json_indent,
ensure_ascii=JSON_ENSURE_ASCII,
)
except:
function_return_str = message["content"]
@@ -229,7 +227,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
if self.use_system_role_in_user:
try:
msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
msg_json = json_loads(message["content"])
if msg_json["type"] != "user_message":
role_str = "system"
except:
@@ -343,7 +341,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
"content": inner_thoughts,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message

View File

@@ -1,6 +1,7 @@
import json
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.utils import json_dumps, json_loads
from .wrapper_base import LLMChatCompletionWrapper
@@ -85,9 +86,9 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
"""
airo_func_call = {
"function": function_call["name"],
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
"params": json_loads(function_call["arguments"]),
}
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=2)
# Add a sep for the conversation
if self.include_section_separators:
@@ -100,7 +101,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
if message["role"] == "user":
if self.simplify_json_content:
try:
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
content_json = json_loads(message["content"])
content_simple = content_json["message"]
prompt += f"\nUSER: {content_simple}"
except:
@@ -151,7 +152,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
"content": raw_llm_output,
# "function_call": {
# "name": function_name,
# "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
# "arguments": json_dumps(function_parameters),
# },
}
return message

View File

@@ -1,6 +1,7 @@
import json
from ...constants import JSON_ENSURE_ASCII
from memgpt.utils import json_dumps, json_loads
from ...errors import LLMJSONParsingError
from ..json_parser import clean_json
from .wrapper_base import LLMChatCompletionWrapper
@@ -76,9 +77,9 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
def create_function_call(function_call):
airo_func_call = {
"function": function_call["name"],
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
"params": json_loads(function_call["arguments"]),
}
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=2)
for message in messages[1:]:
assert message["role"] in ["user", "assistant", "function", "tool"], message
@@ -86,7 +87,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
if message["role"] == "user":
if self.simplify_json_content:
try:
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
content_json = json_loads(message["content"])
content_simple = content_json["message"]
prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}"
# prompt += f"\nUSER: {content_simple}"
@@ -171,7 +172,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
"content": None,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message
@@ -239,10 +240,10 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
"function": function_call["name"],
"params": {
"inner_thoughts": inner_thoughts,
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
**json_loads(function_call["arguments"]),
},
}
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(airo_func_call, indent=2)
# Add a sep for the conversation
if self.include_section_separators:
@@ -255,7 +256,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
if message["role"] == "user":
if self.simplify_json_content:
try:
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
content_json = json_loads(message["content"])
content_simple = content_json["message"]
prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}"
except:
@@ -340,7 +341,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
"content": inner_thoughts,
"function_call": {
"name": function_name,
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps(function_parameters),
},
}
return message

View File

@@ -1,7 +1,7 @@
import json
import os
from memgpt.constants import JSON_ENSURE_ASCII, MEMGPT_DIR
from memgpt.constants import MEMGPT_DIR
from memgpt.local_llm.settings.deterministic_mirostat import (
settings as det_miro_settings,
)
@@ -48,9 +48,7 @@ def get_completions_settings(defaults="simple") -> dict:
with open(settings_file, "r", encoding="utf-8") as file:
user_settings = json.load(file)
if len(user_settings) > 0:
printd(
f"Updating base settings with the following user settings:\n{json.dumps(user_settings,indent=2, ensure_ascii=JSON_ENSURE_ASCII)}"
)
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings,indent=2)}")
settings.update(user_settings)
else:
printd(f"'{settings_file}' was empty, ignoring...")

View File

@@ -19,12 +19,7 @@ from memgpt.cli.cli import delete_agent, open_folder, quickstart, run, server, v
from memgpt.cli.cli_config import add, add_tool, configure, delete, list, list_tools
from memgpt.cli.cli_load import app as load_app
from memgpt.config import MemGPTConfig
from memgpt.constants import (
FUNC_FAILED_HEARTBEAT_MESSAGE,
JSON_ENSURE_ASCII,
JSON_LOADS_STRICT,
REQ_HEARTBEAT_MESSAGE,
)
from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
from memgpt.metadata import MetadataStore
from memgpt.schemas.enums import OptionState
@@ -275,14 +270,14 @@ def run_agent_loop(
if args_string is None:
print("Assistant missing send_message function arguments")
break # cancel op
args_json = json.loads(args_string, strict=JSON_LOADS_STRICT)
args_json = json_loads(args_string)
if "message" not in args_json:
print("Assistant missing send_message message argument")
break # cancel op
# Once we found our target, rewrite it
args_json["message"] = text
new_args_string = json.dumps(args_json, ensure_ascii=JSON_ENSURE_ASCII)
new_args_string = json_dumps(args_json)
message_obj.tool_calls[0].function["arguments"] = new_args_string
# To persist to the database, all we need to do is "re-insert" into recall memory

View File

@@ -1,18 +1,10 @@
# https://github.com/openai/openai-python/blob/v0.27.4/openai/openai_object.py
import json
from copy import deepcopy
from enum import Enum
from typing import Optional, Tuple, Union
from memgpt.constants import JSON_ENSURE_ASCII
# import openai
# from openai import api_requestor, util
# from openai.openai_response import OpenAIResponse
# from openai.util import ApiType
from memgpt.utils import json_dumps
api_requestor = None
api_resources = None
@@ -342,7 +334,7 @@ class OpenAIObject(dict):
def __str__(self):
obj = self.to_dict_recursive()
return json.dumps(obj, sort_keys=True, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(obj, sort_keys=True, indent=2)
def to_dict(self):
return dict(self)

View File

@@ -1,4 +1,11 @@
from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, MAX_PAUSE_HEARTBEATS
from ..constants import MAX_PAUSE_HEARTBEATS
request_heartbeat = {
"request_heartbeat": {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
}
}
# FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1]
FUNCTIONS_CHAINING = {
@@ -42,12 +49,8 @@ FUNCTIONS_CHAINING = {
"message": {
"type": "string",
"description": "Message to send ChatGPT. Phrase your message as a full English sentence.",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}
}.update(request_heartbeat),
"required": ["message", "request_heartbeat"],
},
},
@@ -65,11 +68,7 @@ FUNCTIONS_CHAINING = {
"type": "string",
"description": "Content to write to the memory. All unicode (including emojis) are supported.",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["name", "content", "request_heartbeat"],
},
},
@@ -91,11 +90,7 @@ FUNCTIONS_CHAINING = {
"type": "string",
"description": "Content to write to the memory. All unicode (including emojis) are supported.",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["name", "old_content", "new_content", "request_heartbeat"],
},
},
@@ -113,11 +108,7 @@ FUNCTIONS_CHAINING = {
"type": "integer",
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["query", "page", "request_heartbeat"],
},
},
@@ -135,11 +126,7 @@ FUNCTIONS_CHAINING = {
"type": "integer",
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["query", "request_heartbeat"],
},
},
@@ -161,11 +148,7 @@ FUNCTIONS_CHAINING = {
"type": "integer",
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["start_date", "end_date", "page", "request_heartbeat"],
},
},
@@ -187,11 +170,7 @@ FUNCTIONS_CHAINING = {
"type": "integer",
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["start_date", "end_date", "request_heartbeat"],
},
},
@@ -205,11 +184,7 @@ FUNCTIONS_CHAINING = {
"type": "string",
"description": "Content to write to the memory. All unicode (including emojis) are supported.",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["content", "request_heartbeat"],
},
},
@@ -227,11 +202,7 @@ FUNCTIONS_CHAINING = {
"type": "integer",
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["query", "request_heartbeat"],
},
},
@@ -253,11 +224,7 @@ FUNCTIONS_CHAINING = {
"type": "integer",
"description": "How many lines to read (defaults to 1).",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["filename", "line_start", "request_heartbeat"],
},
},
@@ -275,11 +242,7 @@ FUNCTIONS_CHAINING = {
"type": "string",
"description": "Content to append to the file.",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["filename", "content", "request_heartbeat"],
},
},
@@ -301,11 +264,7 @@ FUNCTIONS_CHAINING = {
"type": "string",
"description": "A JSON string representing the request payload.",
},
"request_heartbeat": {
"type": "boolean",
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
},
},
}.update(request_heartbeat),
"required": ["method", "url", "request_heartbeat"],
},
},

View File

@@ -6,13 +6,13 @@ from typing import List, Optional, Union
from pydantic import Field, field_validator
from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN
from memgpt.constants import TOOL_CALL_ID_MAX_LEN
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.utils import get_utc_time, is_utc_datetime
from memgpt.utils import get_utc_time, is_utc_datetime, json_dumps
def add_inner_thoughts_to_tool_call(
@@ -29,7 +29,7 @@ def add_inner_thoughts_to_tool_call(
func_args[inner_thoughts_key] = inner_thoughts
# create the updated tool call (as a string)
updated_tool_call = copy.deepcopy(tool_call)
updated_tool_call.function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
updated_tool_call.function.arguments = json_dumps(func_args)
return updated_tool_call
except json.JSONDecodeError as e:
# TODO: change to logging
@@ -517,7 +517,7 @@ class Message(BaseMessage):
cohere_message = []
for tc in self.tool_calls:
# TODO better way to pack?
function_call_text = json.dumps(tc.to_dict(), ensure_ascii=JSON_ENSURE_ASCII)
function_call_text = json_dumps(tc.to_dict())
cohere_message.append(
{
"role": function_call_role,

View File

@@ -5,7 +5,11 @@ from typing import AsyncGenerator, Union
from pydantic import BaseModel
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.orm.user import User
from memgpt.orm.utilities import get_db_session
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.server import SyncServer
from memgpt.utils import json_dumps
SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n"
@@ -16,8 +20,28 @@ SSE_ARTIFICIAL_DELAY = 0.1
def sse_formatter(data: Union[dict, str]) -> str:
"""Prefix with 'data: ', and always include double newlines"""
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data
return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
data_str = json_dumps(data) if isinstance(data, dict) else data
return f"data: {data_str}\n\n"
async def sse_generator(generator: Generator[dict, None, None]) -> Generator[str, None, None]:
    """Generator that returns 'data: dict' formatted items, e.g.:

    data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}

    data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

    data: [DONE]
    """
    # NOTE(review): the annotations use typing.Generator, but the visible imports
    # only bring in AsyncGenerator/Union -- confirm Generator is imported; also,
    # an async generator's return annotation would conventionally be
    # AsyncGenerator[str, None].
    try:
        # Drain the (synchronous) source generator, wrapping each dict as an SSE event.
        for msg in generator:
            yield sse_formatter(msg)
            if SSE_ARTIFICIAL_DELAY:
                await asyncio.sleep(SSE_ARTIFICIAL_DELAY)  # Sleep to prevent a tight loop, adjust time as needed
    except Exception as e:
        # Surface the failure to the client as an SSE event rather than dropping the stream.
        yield sse_formatter({"error": f"{str(e)}"})
    yield sse_formatter(SSE_FINISH_MSG)  # Signal that the stream is complete
async def sse_async_generator(generator: AsyncGenerator, finish_message=True):

View File

@@ -20,7 +20,6 @@ from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_config import get_model_options
from memgpt.config import MemGPTConfig
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.credentials import MemGPTCredentials
from memgpt.data_sources.connectors import DataConnector, load_data
@@ -66,7 +65,7 @@ from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.schemas.usage import MemGPTUsageStatistics
from memgpt.schemas.user import User, UserCreate
from memgpt.utils import create_random_username
from memgpt.utils import create_random_username, json_dumps, json_loads
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
@@ -557,11 +556,9 @@ class SyncServer(LockingServer):
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = command[len("rewrite ") :].strip()
args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"), strict=JSON_LOADS_STRICT)
args = json_loads(memgpt_agent.messages[x].get("function_call").get("arguments"))
args["message"] = text
memgpt_agent.messages[x].get("function_call").update(
{"arguments": json.dumps(args, ensure_ascii=JSON_ENSURE_ASCII)}
)
memgpt_agent.messages[x].get("function_call").update({"arguments": json_dumps(args)})
break
# No skip options

View File

@@ -4,7 +4,6 @@ import json
import websockets
import memgpt.server.ws_api.protocol as protocol
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.server.constants import WS_CLIENT_TIMEOUT, WS_DEFAULT_PORT
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
@@ -27,12 +26,12 @@ async def send_message_and_print_replies(websocket, user_message, agent_id):
# Wait for messages in a loop, since the server may send a few
while True:
response = await asyncio.wait_for(websocket.recv(), WS_CLIENT_TIMEOUT)
response = json.loads(response, strict=JSON_LOADS_STRICT)
response = json_loads(response)
if CLEAN_RESPONSES:
print_server_response(response)
else:
print(f"Server response:\n{json.dumps(response, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}")
print(f"Server response:\n{json_dumps(response, indent=2)}")
# Check for a specific condition to break the loop
if condition_to_stop_receiving(response):
@@ -62,8 +61,8 @@ async def basic_cli_client():
await websocket.send(protocol.client_command_create(example_config))
# Wait for the response
response = await websocket.recv()
response = json.loads(response, strict=JSON_LOADS_STRICT)
print(f"Server response:\n{json.dumps(response, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}")
response = json_loads(response)
print(f"Server response:\n{json_dumps(response, indent=2)}")
await asyncio.sleep(1)

View File

@@ -1,89 +1,81 @@
import json
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.utils import json_dumps
# Server -> client
def server_error(msg):
"""General server error"""
return json.dumps(
return json_dumps(
{
"type": "server_error",
"message": msg,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_command_response(status):
return json.dumps(
return json_dumps(
{
"type": "command_response",
"status": status,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_agent_response_error(msg):
return json.dumps(
return json_dumps(
{
"type": "agent_response_error",
"message": msg,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_agent_response_start():
return json.dumps(
return json_dumps(
{
"type": "agent_response_start",
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_agent_response_end():
return json.dumps(
return json_dumps(
{
"type": "agent_response_end",
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_agent_internal_monologue(msg):
return json.dumps(
return json_dumps(
{
"type": "agent_response",
"message_type": "internal_monologue",
"message": msg,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_agent_assistant_message(msg):
return json.dumps(
return json_dumps(
{
"type": "agent_response",
"message_type": "assistant_message",
"message": msg,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def server_agent_function_message(msg):
return json.dumps(
return json_dumps(
{
"type": "agent_response",
"message_type": "function_message",
"message": msg,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
@@ -91,22 +83,20 @@ def server_agent_function_message(msg):
def client_user_message(msg, agent_id=None):
return json.dumps(
return json_dumps(
{
"type": "user_message",
"message": msg,
"agent_id": agent_id,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)
def client_command_create(config):
return json.dumps(
return json_dumps(
{
"type": "command",
"command": "create_agent",
"config": config,
},
ensure_ascii=JSON_ENSURE_ASCII,
}
)

View File

@@ -55,7 +55,7 @@ class WebSocketServer:
# Assuming the message is a JSON string
try:
data = json.loads(message, strict=JSON_LOADS_STRICT)
data = json_loads(message)
except:
print(f"[server] bad data from client:\n{data}")
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))

View File

@@ -6,10 +6,9 @@ from .constants import (
INITIAL_BOOT_MESSAGE,
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG,
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
JSON_ENSURE_ASCII,
MESSAGE_SUMMARY_WARNING_STR,
)
from .utils import get_local_time
from .utils import get_local_time, json_dumps
def get_initial_boot_messages(version="startup"):
@@ -98,7 +97,7 @@ def get_heartbeat(reason="Automated timer", include_location=False, location_nam
if include_location:
packaged_message["location"] = location_name
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)
def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"):
@@ -113,7 +112,7 @@ def get_login_event(last_login="Never (first login)", include_location=False, lo
if include_location:
packaged_message["location"] = location_name
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)
def package_user_message(
@@ -137,7 +136,7 @@ def package_user_message(
if name:
packaged_message["name"] = name
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)
def package_function_response(was_success, response_string, timestamp=None):
@@ -148,7 +147,7 @@ def package_function_response(was_success, response_string, timestamp=None):
"time": formatted_time,
}
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)
def package_system_message(system_message, message_type="system_alert", time=None):
@@ -175,7 +174,7 @@ def package_summarize_message(summary, summary_length, hidden_message_count, tot
"time": formatted_time,
}
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)
def package_summarize_message_no_summary(hidden_message_count, timestamp=None, message=None):
@@ -194,7 +193,7 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
"time": formatted_time,
}
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)
def get_token_limit_warning():
@@ -205,4 +204,4 @@ def get_token_limit_warning():
"time": formatted_time,
}
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
return json_dumps(packaged_message)

View File

@@ -28,12 +28,9 @@ from memgpt.constants import (
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
FUNCTION_RETURN_CHAR_LIMIT,
JSON_ENSURE_ASCII,
JSON_LOADS_STRICT,
MEMGPT_DIR,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.openai_backcompat.openai_object import OpenAIObject
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
DEBUG = False
@@ -782,6 +779,8 @@ def open_folder_in_explorer(folder_path):
class OpenAIBackcompatUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == "openai.openai_object":
from memgpt.openai_backcompat.openai_object import OpenAIObject
return OpenAIObject
return super().find_class(module, name)
@@ -873,7 +872,7 @@ def parse_json(string) -> dict:
"""Parse JSON string into JSON with both json and demjson"""
result = None
try:
result = json.loads(string, strict=JSON_LOADS_STRICT)
result = json_loads(string)
return result
except Exception as e:
print(f"Error parsing json with json package: {e}")
@@ -906,7 +905,7 @@ def validate_function_response(function_response_string: any, strict: bool = Fal
# Allow dict through since it will be cast to json.dumps()
try:
# TODO find a better way to do this that won't result in double escapes
function_response_string = json.dumps(function_response_string, ensure_ascii=JSON_ENSURE_ASCII)
function_response_string = json_dumps(function_response_string)
except:
raise ValueError(function_response_string)
@@ -1020,8 +1019,8 @@ def get_human_text(name: str):
def get_schema_diff(schema_a, schema_b):
# Assuming f_schema and linked_function['json_schema'] are your JSON schemas
f_schema_json = json.dumps(schema_a, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
linked_function_json = json.dumps(schema_b, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
f_schema_json = json_dumps(schema_a)
linked_function_json = json_dumps(schema_b)
# Compute the difference using difflib
difference = list(difflib.ndiff(f_schema_json.splitlines(keepends=True), linked_function_json.splitlines(keepends=True)))
@@ -1056,3 +1055,11 @@ def create_uuid_from_string(val: str):
"""
hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest()
return uuid.UUID(hex=hex_string)
def json_dumps(data, indent=2, **kwargs) -> str:
    """Serialize ``data`` to JSON with the project-wide defaults.

    Defaults to 2-space indentation and ``ensure_ascii=False`` so non-ASCII
    text (emojis, accented characters) is emitted verbatim rather than
    \\uXXXX-escaped.

    Args:
        data: Any object serializable by ``json.dumps``.
        indent: Indentation forwarded to ``json.dumps`` (``None`` for compact output).
        **kwargs: Additional ``json.dumps`` options (e.g. ``sort_keys=True``),
            forwarded verbatim.  Without this, call sites such as
            ``json_dumps(obj, sort_keys=True, indent=2)`` (see the
            OpenAIObject ``__str__`` in this commit) would raise TypeError.

    Returns:
        str: The serialized JSON document.
    """
    return json.dumps(data, indent=indent, ensure_ascii=False, **kwargs)
def json_loads(data):
    """Deserialize a JSON document with the project-wide lenient settings.

    ``strict=False`` permits literal control characters (such as raw
    newlines) inside JSON string values, which LLM-generated payloads
    frequently contain.
    """
    parsed = json.loads(data, strict=False)
    return parsed

View File

@@ -34,7 +34,6 @@ from tqdm import tqdm
from memgpt import MemGPT, utils
from memgpt.cli.cli_config import delete
from memgpt.config import MemGPTConfig
from memgpt.constants import JSON_ENSURE_ASCII
# TODO: update personas
NESTED_PERSONA = "You are MemGPT DOC-QA bot. Your job is to answer questions about documents that are stored in your archival memory. The answer to the users question will ALWAYS be in your archival memory, so remember to keep searching if you can't find the answer. DO NOT STOP SEARCHING UNTIL YOU VERIFY THAT THE VALUE IS NOT A KEY. Do not stop making nested lookups until this condition is met." # TODO decide on a good persona/human
@@ -71,7 +70,7 @@ def archival_memory_text_search(self, query: str, page: Optional[int] = 0) -> Op
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"memory: {d.text}" for d in results]
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
results_str = f"{results_pref} {utils.json_dumps(results_formatted)}"
return results_str

View File

@@ -0,0 +1,121 @@
import inspect
import os
import uuid
import pytest
from memgpt import constants, create_client
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from memgpt.schemas.message import Message
from memgpt.settings import settings
from memgpt.utils import assistant_function_to_tool, json_dumps, json_loads
from tests.mock_factory.models import MockUserFactory
from tests.utils import create_config, wipe_config
def hello_world(self) -> str:
    """Trivial function the agent is given access to during tests.

    Returns:
        str: A fixed greeting message.
    """
    greeting = "hello, world!"
    return greeting
@pytest.fixture(scope="module")
def agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")
# create memgpt client
client = create_client()
# ensure user exists
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})
agent_state = client.create_agent(
preset=settings.preset,
)
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
@pytest.fixture(scope="module")
def hello_world_function():
with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f:
f.write(inspect.getsource(hello_world))
@pytest.fixture(scope="module")
def ai_function_call():
return Message(
**assistant_function_to_tool(
{
"role": "assistant",
"content": "I will now call hello world",
"function_call": {
"name": "hello_world",
"arguments": json_dumps({}),
},
}
)
)
def test_add_function_happy(agent, hello_world_function, ai_function_call):
    """Adding a user function registers it and makes it callable by the agent."""
    agent.add_function("hello_world")

    registered_names = [f_schema["name"] for f_schema in agent.functions]
    assert "hello_world" in registered_names
    assert "hello_world" in agent.functions_python

    msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)

    # The function's return value is packaged as JSON in the last message's content.
    content = json_loads(msgs[-1].to_openai_dict()["content"])
    assert content["message"] == "hello, world!"
    assert content["status"] == "OK"
    assert not function_failed
def test_add_function_already_loaded(agent, hello_world_function):
    """Registering the same function twice is idempotent (no exception)."""
    for _ in range(2):
        agent.add_function("hello_world")
def test_add_function_not_exist(agent):
    """Requesting an unknown function name raises ValueError."""
    with pytest.raises(ValueError):
        agent.add_function("non_existent")
def test_remove_function_happy(agent, hello_world_function):
    """A loaded user function can be removed again cleanly."""
    agent.add_function("hello_world")

    # Sanity-check it was actually registered before removal.
    assert "hello_world" in [schema["name"] for schema in agent.functions]
    assert "hello_world" in agent.functions_python

    agent.remove_function("hello_world")
    assert all(schema["name"] != "hello_world" for schema in agent.functions)
    assert "hello_world" not in agent.functions_python
def test_remove_function_not_exist(agent):
    """Removing an unknown function is a no-op rather than an error."""
    agent.remove_function("non_existent")
def test_remove_base_function_fails(agent):
    """Base functions such as send_message are protected from removal."""
    with pytest.raises(ValueError):
        agent.remove_function("send_message")
if __name__ == "__main__":
pytest.main(["-vv", os.path.abspath(__file__)])

View File

@@ -1,8 +1,8 @@
import json
import memgpt.system as system
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.local_llm.function_parser import patch_function
from memgpt.utils import json_dumps
EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = {
"message_history": [
@@ -32,7 +32,7 @@ EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = {
"content": "I'll append to memory.",
"function_call": {
"name": "core_memory_append",
"arguments": json.dumps({"content": "new_stuff"}, ensure_ascii=JSON_ENSURE_ASCII),
"arguments": json_dumps({"content": "new_stuff"}),
},
},
}

View File

@@ -1,7 +1,7 @@
import json
from memgpt.utils import json_loads
import memgpt.local_llm.json_parser as json_parser
from memgpt.constants import JSON_LOADS_STRICT
from memgpt.constants import json
EXAMPLE_ESCAPED_UNDERSCORES = """{
"function":"send\_message",
@@ -90,7 +90,7 @@ def test_json_parsers():
for string in test_strings:
try:
json.loads(string, strict=JSON_LOADS_STRICT)
json_loads(string)
assert False, f"Test JSON string should have failed basic JSON parsing:\n{string}"
except:
print("String failed (expectedly)")

View File

@@ -1,12 +1,11 @@
import asyncio
import json
import pytest
import websockets
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.ws_api.server import WebSocketServer
from memgpt.utils import json_dumps
@pytest.mark.asyncio
@@ -36,7 +35,7 @@ async def test_websocket_server():
async with websockets.connect(uri) as websocket:
# Initialize the server with a test config
print("Sending config to server...")
await websocket.send(json.dumps({"type": "initialize", "config": test_config}, ensure_ascii=JSON_ENSURE_ASCII))
await websocket.send(json_dumps({"type": "initialize", "config": test_config}))
# Wait for the response
response = await websocket.recv()
print(f"Response from the agent: {response}")
@@ -45,7 +44,7 @@ async def test_websocket_server():
# Send a message to the agent
print("Sending message to server...")
await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}, ensure_ascii=JSON_ENSURE_ASCII))
await websocket.send(json_dumps({"type": "message", "content": "Hello, Agent!"}))
# Wait for the response
# NOTE: we should be waiting for multiple responses
response = await websocket.recv()