refactor: Remove JSON constant for common method (#1680)
This commit is contained in:
@@ -11,8 +11,6 @@ from memgpt.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
IN_CONTEXT_MEMORY_KEYWORD,
|
||||
JSON_ENSURE_ASCII,
|
||||
JSON_LOADS_STRICT,
|
||||
LLM_MAX_TOKENS,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
@@ -44,6 +42,8 @@ from memgpt.utils import (
|
||||
get_tool_call_id,
|
||||
get_utc_time,
|
||||
is_utc_datetime,
|
||||
json_dumps,
|
||||
json_loads,
|
||||
parse_json,
|
||||
printd,
|
||||
united_diff,
|
||||
@@ -654,12 +654,12 @@ class Agent(object):
|
||||
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
|
||||
"""If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
|
||||
try:
|
||||
user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
|
||||
user_message_json = dict(json_loads(user_message_text))
|
||||
# Special handling for AutoGen messages with 'name' field
|
||||
# Treat 'name' as a special field
|
||||
# If it exists in the input message, elevate it to the 'message' level
|
||||
name = user_message_json.pop("name", None)
|
||||
clean_message = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
clean_message = json_dumps(user_message_json)
|
||||
|
||||
except Exception as e:
|
||||
print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
|
||||
@@ -668,8 +668,8 @@ class Agent(object):
|
||||
|
||||
def validate_json(user_message_text: str, raise_on_error: bool) -> str:
|
||||
try:
|
||||
user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
|
||||
user_message_json_val = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
user_message_json = dict(json_loads(user_message_text))
|
||||
user_message_json_val = json_dumps(user_message_json)
|
||||
return user_message_json_val
|
||||
except Exception as e:
|
||||
print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
|
||||
from memgpt.constants import CLI_WARNING_PREFIX
|
||||
from memgpt.schemas.message import Message
|
||||
|
||||
init(autoreset=True)
|
||||
@@ -113,7 +113,7 @@ class AutoGenInterface(object):
|
||||
return
|
||||
else:
|
||||
try:
|
||||
msg_json = json.loads(msg, strict=JSON_LOADS_STRICT)
|
||||
msg_json = json_loads(msg)
|
||||
except:
|
||||
print(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
|
||||
message = f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[user] {msg}"
|
||||
@@ -203,7 +203,7 @@ class AutoGenInterface(object):
|
||||
)
|
||||
else:
|
||||
try:
|
||||
msg_dict = json.loads(msg, strict=JSON_LOADS_STRICT)
|
||||
msg_dict = json_loads(msg)
|
||||
if "status" in msg_dict and msg_dict["status"] == "OK":
|
||||
message = (
|
||||
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}" if self.fancy else f"[function] {msg}"
|
||||
|
||||
@@ -114,14 +114,9 @@ REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function called using request_hea
|
||||
# FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed"
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, returning control"
|
||||
|
||||
FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat"
|
||||
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean"
|
||||
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function."
|
||||
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
|
||||
|
||||
# GLOBAL SETTINGS FOR `json.dumps()`
|
||||
JSON_ENSURE_ASCII = False
|
||||
|
||||
# GLOBAL SETTINGS FOR `json.loads()`
|
||||
JSON_LOADS_STRICT = False
|
||||
# TODO Is this config or constant?
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000
|
||||
|
||||
@@ -4,11 +4,7 @@ import math
|
||||
from typing import Optional
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.constants import (
|
||||
JSON_ENSURE_ASCII,
|
||||
MAX_PAUSE_HEARTBEATS,
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||||
)
|
||||
from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
### Functions / tools the agent can use
|
||||
# All functions should return a response string (or None)
|
||||
@@ -81,7 +77,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
results_str = f"{results_pref} {json_dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
|
||||
@@ -111,7 +107,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page:
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
results_str = f"{results_pref} {json_dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
|
||||
@@ -154,5 +150,5 @@ def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) ->
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
results_str = f"{results_pref} {json_dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
@@ -6,8 +6,6 @@ from typing import Optional
|
||||
import requests
|
||||
|
||||
from memgpt.constants import (
|
||||
JSON_ENSURE_ASCII,
|
||||
JSON_LOADS_STRICT,
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL,
|
||||
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
|
||||
)
|
||||
@@ -123,10 +121,10 @@ def http_request(self, method: str, url: str, payload_json: Optional[str] = None
|
||||
else:
|
||||
# Validate and convert the payload for other types of requests
|
||||
if payload_json:
|
||||
payload = json.loads(payload_json, strict=JSON_LOADS_STRICT)
|
||||
payload = json_loads(payload_json)
|
||||
else:
|
||||
payload = {}
|
||||
print(f"[HTTP] launching {method} request to {url}, payload=\n{json.dumps(payload, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}")
|
||||
print(f"[HTTP] launching {method} request to {url}, payload=\n{json_dumps(payload, indent=2)}")
|
||||
response = requests.request(method, url, json=payload, headers=headers)
|
||||
|
||||
return {"status_code": response.status_code, "headers": dict(response.headers), "body": response.text}
|
||||
|
||||
@@ -5,14 +5,6 @@ from typing import Any, Dict, Optional, Type, get_args, get_origin
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memgpt.constants import (
|
||||
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
FUNCTION_PARAM_NAME_REQ_HEARTBEAT,
|
||||
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT,
|
||||
)
|
||||
|
||||
NO_HEARTBEAT_FUNCTIONS = ["send_message", "pause_heartbeats"]
|
||||
|
||||
|
||||
def is_optional(annotation):
|
||||
# Check if the annotation is a Union
|
||||
@@ -136,12 +128,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
schema["parameters"]["required"].append(param.name)
|
||||
|
||||
# append the heartbeat
|
||||
if function.__name__ not in NO_HEARTBEAT_FUNCTIONS:
|
||||
schema["parameters"]["properties"][FUNCTION_PARAM_NAME_REQ_HEARTBEAT] = {
|
||||
"type": FUNCTION_PARAM_TYPE_REQ_HEARTBEAT,
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
if function.__name__ not in ["send_message", "pause_heartbeats"]:
|
||||
schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)
|
||||
schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import List, Optional
|
||||
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
|
||||
from memgpt.constants import CLI_WARNING_PREFIX
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.utils import printd
|
||||
from memgpt.utils import json_loads, printd
|
||||
|
||||
init(autoreset=True)
|
||||
|
||||
@@ -127,7 +127,7 @@ class CLIInterface(AgentInterface):
|
||||
return
|
||||
else:
|
||||
try:
|
||||
msg_json = json.loads(msg, strict=JSON_LOADS_STRICT)
|
||||
msg_json = json_loads(msg)
|
||||
except:
|
||||
printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
|
||||
printd_user_message("🧑", msg)
|
||||
@@ -228,7 +228,7 @@ class CLIInterface(AgentInterface):
|
||||
printd_function_message("", msg)
|
||||
else:
|
||||
try:
|
||||
msg_dict = json.loads(msg, strict=JSON_LOADS_STRICT)
|
||||
msg_dict = json_loads(msg)
|
||||
if "status" in msg_dict and msg_dict["status"] == "OK":
|
||||
printd_function_message("", str(msg), color=Fore.GREEN)
|
||||
else:
|
||||
@@ -259,7 +259,7 @@ class CLIInterface(AgentInterface):
|
||||
CLIInterface.internal_monologue(content)
|
||||
# I think the next one is not up to date
|
||||
# function_message(msg["function_call"])
|
||||
args = json.loads(msg["function_call"].get("arguments"), strict=JSON_LOADS_STRICT)
|
||||
args = json_loads(msg["function_call"].get("arguments"))
|
||||
CLIInterface.assistant_message(args.get("message"))
|
||||
# assistant_message(content)
|
||||
elif msg.get("tool_calls"):
|
||||
@@ -267,7 +267,7 @@ class CLIInterface(AgentInterface):
|
||||
CLIInterface.internal_monologue(content)
|
||||
function_obj = msg["tool_calls"][0].get("function")
|
||||
if function_obj:
|
||||
args = json.loads(function_obj.get("arguments"), strict=JSON_LOADS_STRICT)
|
||||
args = json_loads(function_obj.get("arguments"))
|
||||
CLIInterface.assistant_message(args.get("message"))
|
||||
else:
|
||||
CLIInterface.internal_monologue(content)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
from memgpt.schemas.openai.chat_completion_response import (
|
||||
@@ -257,7 +256,7 @@ def convert_anthropic_response_to_chatcompletion(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=response_json["content"][1]["name"],
|
||||
arguments=json.dumps(response_json["content"][1]["input"], ensure_ascii=JSON_ENSURE_ASCII),
|
||||
arguments=json_dumps(response_json["content"][1]["input"]),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.local_llm.utils import count_tokens
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
@@ -17,7 +16,7 @@ from memgpt.schemas.openai.chat_completion_response import (
|
||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own MemGPT Message datatype
|
||||
)
|
||||
from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from memgpt.utils import get_tool_call_id, get_utc_time, smart_urljoin
|
||||
from memgpt.utils import get_tool_call_id, get_utc_time, json_dumps, smart_urljoin
|
||||
|
||||
BASE_URL = "https://api.cohere.ai/v1"
|
||||
|
||||
@@ -155,9 +154,7 @@ def convert_cohere_response_to_chatcompletion(
|
||||
completion_tokens = response_json["meta"]["billed_units"]["output_tokens"]
|
||||
else:
|
||||
# For some reason input_tokens not included in 'meta' 'tokens' dict?
|
||||
prompt_tokens = count_tokens(
|
||||
json.dumps(response_json["chat_history"], ensure_ascii=JSON_ENSURE_ASCII)
|
||||
) # NOTE: this is a very rough approximation
|
||||
prompt_tokens = count_tokens(json_dumps(response_json["chat_history"])) # NOTE: this is a very rough approximation
|
||||
completion_tokens = response_json["meta"]["tokens"]["output_tokens"]
|
||||
|
||||
finish_reason = remap_finish_reason(response_json["finish_reason"])
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, NON_USER_MSG_PREFIX
|
||||
from memgpt.constants import NON_USER_MSG_PREFIX
|
||||
from memgpt.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from memgpt.local_llm.utils import count_tokens
|
||||
from memgpt.schemas.openai.chat_completion_request import Tool
|
||||
@@ -310,7 +310,7 @@ def convert_google_ai_response_to_chatcompletion(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=clean_json_string_extra_backslash(json.dumps(function_args, ensure_ascii=JSON_ENSURE_ASCII)),
|
||||
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
|
||||
),
|
||||
)
|
||||
],
|
||||
@@ -374,12 +374,8 @@ def convert_google_ai_response_to_chatcompletion(
|
||||
else:
|
||||
# Count it ourselves
|
||||
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||
prompt_tokens = count_tokens(
|
||||
json.dumps(input_messages, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
) # NOTE: this is a very rough approximation
|
||||
completion_tokens = count_tokens(
|
||||
json.dumps(openai_response_message.model_dump(), ensure_ascii=JSON_ENSURE_ASCII)
|
||||
) # NOTE: this is also approximate
|
||||
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
||||
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
|
||||
from memgpt.constants import CLI_WARNING_PREFIX
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
|
||||
from memgpt.llm_api.azure_openai import (
|
||||
@@ -109,7 +109,7 @@ def unpack_inner_thoughts_from_kwargs(
|
||||
|
||||
# replace the kwargs
|
||||
new_choice = choice.model_copy(deep=True)
|
||||
new_choice.message.tool_calls[0].function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
new_choice.message.tool_calls[0].function.arguments = json_dumps(func_args)
|
||||
# also replace the message content
|
||||
if new_choice.message.content is not None:
|
||||
warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})")
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
|
||||
from memgpt.constants import CLI_WARNING_PREFIX
|
||||
from memgpt.errors import LocalLLMConnectionError, LocalLLMError
|
||||
from memgpt.local_llm.constants import DEFAULT_WRAPPER
|
||||
from memgpt.local_llm.function_parser import patch_function
|
||||
@@ -33,7 +32,7 @@ from memgpt.schemas.openai.chat_completion_response import (
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from memgpt.utils import get_tool_call_id, get_utc_time
|
||||
from memgpt.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||
|
||||
has_shown_warning = False
|
||||
grammar_supported_backends = ["koboldcpp", "llamacpp", "webui", "webui-legacy"]
|
||||
@@ -189,7 +188,7 @@ def get_chat_completion(
|
||||
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result, first_message=first_message)
|
||||
else:
|
||||
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
|
||||
printd(json.dumps(chat_completion_result, indent=2, ensure_ascii=JSON_ENSURE_ASCII))
|
||||
printd(json_dumps(chat_completion_result, indent=2))
|
||||
except Exception as e:
|
||||
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")
|
||||
|
||||
@@ -206,13 +205,13 @@ def get_chat_completion(
|
||||
usage["prompt_tokens"] = count_tokens(prompt)
|
||||
|
||||
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
|
||||
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result, ensure_ascii=JSON_ENSURE_ASCII))
|
||||
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
|
||||
"""
|
||||
if usage["completion_tokens"] is None:
|
||||
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
|
||||
# chat_completion_result is dict with 'role' and 'content'
|
||||
# token counter wants a string
|
||||
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result, ensure_ascii=JSON_ENSURE_ASCII))
|
||||
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
|
||||
"""
|
||||
|
||||
# NOTE: this is the token count that matters most
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
NO_HEARTBEAT_FUNCS = ["send_message", "pause_heartbeats"]
|
||||
|
||||
@@ -13,16 +13,16 @@ def insert_heartbeat(message):
|
||||
if message_copy.get("function_call"):
|
||||
# function_name = message.get("function_call").get("name")
|
||||
params = message_copy.get("function_call").get("arguments")
|
||||
params = json.loads(params, strict=JSON_LOADS_STRICT)
|
||||
params = json_loads(params)
|
||||
params["request_heartbeat"] = True
|
||||
message_copy["function_call"]["arguments"] = json.dumps(params, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
message_copy["function_call"]["arguments"] = json_dumps(params)
|
||||
|
||||
elif message_copy.get("tool_call"):
|
||||
# function_name = message.get("tool_calls")[0].get("function").get("name")
|
||||
params = message_copy.get("tool_calls")[0].get("function").get("arguments")
|
||||
params = json.loads(params, strict=JSON_LOADS_STRICT)
|
||||
params = json_loads(params)
|
||||
params["request_heartbeat"] = True
|
||||
message_copy["tools_calls"][0]["function"]["arguments"] = json.dumps(params, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
message_copy["tools_calls"][0]["function"]["arguments"] = json_dumps(params)
|
||||
|
||||
return message_copy
|
||||
|
||||
@@ -40,7 +40,7 @@ def heartbeat_correction(message_history, new_message):
|
||||
last_message_was_user = False
|
||||
if message_history[-1]["role"] == "user":
|
||||
try:
|
||||
content = json.loads(message_history[-1]["content"], strict=JSON_LOADS_STRICT)
|
||||
content = json_loads(message_history[-1]["content"])
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
# Check if it's a user message or system message
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import (
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
|
||||
class PydanticDataType(Enum):
|
||||
@@ -731,7 +731,7 @@ def generate_markdown_documentation(
|
||||
|
||||
if hasattr(model, "Config") and hasattr(model.Config, "json_schema_extra") and "example" in model.Config.json_schema_extra:
|
||||
documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n"
|
||||
json_example = json.dumps(model.Config.json_schema_extra["example"], ensure_ascii=JSON_ENSURE_ASCII)
|
||||
json_example = json_dumps(model.Config.json_schema_extra["example"])
|
||||
documentation += format_multiline_description(json_example, 2) + "\n"
|
||||
|
||||
return documentation
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from memgpt.constants import JSON_LOADS_STRICT
|
||||
from memgpt.errors import LLMJSONParsingError
|
||||
from memgpt.utils import json_loads
|
||||
|
||||
|
||||
def clean_json_string_extra_backslash(s):
|
||||
@@ -45,7 +45,7 @@ def extract_first_json(string: str):
|
||||
depth -= 1
|
||||
if depth == 0 and start_index is not None:
|
||||
try:
|
||||
return json.loads(string[start_index : i + 1], strict=JSON_LOADS_STRICT)
|
||||
return json_loads(string[start_index : i + 1])
|
||||
except json.JSONDecodeError as e:
|
||||
raise LLMJSONParsingError(f"Matched closing bracket, but decode failed with error: {str(e)}")
|
||||
printd("No valid JSON object found.")
|
||||
@@ -174,21 +174,21 @@ def clean_json(raw_llm_output, messages=None, functions=None):
|
||||
from memgpt.utils import printd
|
||||
|
||||
strategies = [
|
||||
lambda output: json.loads(output, strict=JSON_LOADS_STRICT),
|
||||
lambda output: json.loads(output + "}", strict=JSON_LOADS_STRICT),
|
||||
lambda output: json.loads(output + "}}", strict=JSON_LOADS_STRICT),
|
||||
lambda output: json.loads(output + '"}}', strict=JSON_LOADS_STRICT),
|
||||
lambda output: json_loads(output),
|
||||
lambda output: json_loads(output + "}"),
|
||||
lambda output: json_loads(output + "}}"),
|
||||
lambda output: json_loads(output + '"}}'),
|
||||
# with strip and strip comma
|
||||
lambda output: json.loads(output.strip().rstrip(",") + "}", strict=JSON_LOADS_STRICT),
|
||||
lambda output: json.loads(output.strip().rstrip(",") + "}}", strict=JSON_LOADS_STRICT),
|
||||
lambda output: json.loads(output.strip().rstrip(",") + '"}}', strict=JSON_LOADS_STRICT),
|
||||
lambda output: json_loads(output.strip().rstrip(",") + "}"),
|
||||
lambda output: json_loads(output.strip().rstrip(",") + "}}"),
|
||||
lambda output: json_loads(output.strip().rstrip(",") + '"}}'),
|
||||
# more complex patchers
|
||||
lambda output: json.loads(repair_json_string(output), strict=JSON_LOADS_STRICT),
|
||||
lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT),
|
||||
lambda output: json_loads(repair_json_string(output)),
|
||||
lambda output: json_loads(repair_even_worse_json(output)),
|
||||
lambda output: extract_first_json(output + "}}"),
|
||||
lambda output: clean_and_interpret_send_message_json(output),
|
||||
# replace underscores
|
||||
lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT),
|
||||
lambda output: json_loads(replace_escaped_underscores(output)),
|
||||
lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"),
|
||||
]
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
@@ -114,9 +113,9 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
"""
|
||||
airo_func_call = {
|
||||
"function": function_call["name"],
|
||||
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
"params": json_loads(function_call["arguments"]),
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=2)
|
||||
|
||||
# Add a sep for the conversation
|
||||
if self.include_section_separators:
|
||||
@@ -129,7 +128,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
if message["role"] == "user":
|
||||
if self.simplify_json_content:
|
||||
try:
|
||||
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
@@ -206,7 +205,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
"content": None,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
@@ -325,10 +324,10 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
"function": function_call["name"],
|
||||
"params": {
|
||||
"inner_thoughts": inner_thoughts,
|
||||
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
**json_loads(function_call["arguments"]),
|
||||
},
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=2)
|
||||
|
||||
# Add a sep for the conversation
|
||||
if self.include_section_separators:
|
||||
@@ -347,7 +346,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
user_prefix = "USER"
|
||||
if self.simplify_json_content:
|
||||
try:
|
||||
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n{user_prefix}: {content_simple}"
|
||||
except:
|
||||
@@ -447,7 +446,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
"content": inner_thoughts,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import json
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.errors import LLMJSONParsingError
|
||||
from memgpt.local_llm.json_parser import clean_json
|
||||
from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import (
|
||||
LLMChatCompletionWrapper,
|
||||
)
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
PREFIX_HINT = """# Reminders:
|
||||
# Important information about yourself and the user is stored in (limited) core memory
|
||||
@@ -137,10 +135,10 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
"function": function_call["name"],
|
||||
"params": {
|
||||
"inner_thoughts": inner_thoughts,
|
||||
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
**json_loads(function_call["arguments"]),
|
||||
},
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=self.json_indent)
|
||||
|
||||
# NOTE: BOS/EOS chatml tokens are NOT inserted here
|
||||
def _compile_assistant_message(self, message) -> str:
|
||||
@@ -167,15 +165,15 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
if self.simplify_json_content:
|
||||
# Make user messages not JSON but plaintext instead
|
||||
try:
|
||||
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = user_msg_json["message"]
|
||||
except:
|
||||
user_msg_str = message["content"]
|
||||
else:
|
||||
# Otherwise just dump the full json
|
||||
try:
|
||||
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
user_msg_str = json.dumps(user_msg_json, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = json_dumps(user_msg_json, indent=self.json_indent)
|
||||
except:
|
||||
user_msg_str = message["content"]
|
||||
|
||||
@@ -189,8 +187,8 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
prompt = ""
|
||||
try:
|
||||
# indent the function replies
|
||||
function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
function_return_str = json.dumps(function_return_dict, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
function_return_dict = json_loads(message["content"])
|
||||
function_return_str = json_dumps(function_return_dict, indent=self.json_indent)
|
||||
except:
|
||||
function_return_str = message["content"]
|
||||
|
||||
@@ -219,7 +217,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
|
||||
if self.use_system_role_in_user:
|
||||
try:
|
||||
msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
msg_json = json_loads(message["content"])
|
||||
if msg_json["type"] != "user_message":
|
||||
role_str = "system"
|
||||
except:
|
||||
@@ -329,7 +327,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
"content": inner_thoughts,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
@@ -394,10 +392,10 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
|
||||
"function": function_call["name"],
|
||||
"params": {
|
||||
# "inner_thoughts": inner_thoughts,
|
||||
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
**json_loads(function_call["arguments"]),
|
||||
},
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=self.json_indent)
|
||||
|
||||
def output_to_chat_completion_response(self, raw_llm_output, first_message=False):
|
||||
"""NOTE: Modified to expect "inner_thoughts" outside the function
|
||||
@@ -458,7 +456,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
|
||||
"content": inner_thoughts,
|
||||
# "function_call": {
|
||||
# "name": function_name,
|
||||
# "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
# "arguments": json_dumps(function_parameters),
|
||||
# },
|
||||
}
|
||||
|
||||
@@ -466,7 +464,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
|
||||
if function_name is not None:
|
||||
message["function_call"] = {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
}
|
||||
|
||||
return message
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
|
||||
import yaml
|
||||
|
||||
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
@@ -131,10 +130,10 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
"function": function_call["name"],
|
||||
"params": {
|
||||
"inner_thoughts": inner_thoughts,
|
||||
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
**json_loads(function_call["arguments"]),
|
||||
},
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=self.json_indent)
|
||||
|
||||
# NOTE: BOS/EOS chatml tokens are NOT inserted here
|
||||
def _compile_assistant_message(self, message) -> str:
|
||||
@@ -161,15 +160,15 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
if self.simplify_json_content:
|
||||
# Make user messages not JSON but plaintext instead
|
||||
try:
|
||||
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = user_msg_json["message"]
|
||||
except:
|
||||
user_msg_str = message["content"]
|
||||
else:
|
||||
# Otherwise just dump the full json
|
||||
try:
|
||||
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
user_msg_str = json.dumps(user_msg_json, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = json_dumps(user_msg_json, indent=self.json_indent)
|
||||
except:
|
||||
user_msg_str = message["content"]
|
||||
|
||||
@@ -183,8 +182,8 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
prompt = ""
|
||||
try:
|
||||
# indent the function replies
|
||||
function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
function_return_str = json.dumps(function_return_dict, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
function_return_dict = json_loads(message["content"])
|
||||
function_return_str = json_dumps(function_return_dict, indent=self.json_indent)
|
||||
except:
|
||||
function_return_str = message["content"]
|
||||
|
||||
@@ -309,7 +308,7 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
"content": inner_thoughts,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
@@ -127,9 +128,9 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
"""
|
||||
airo_func_call = {
|
||||
"function": function_call["name"],
|
||||
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
"params": json_loads(function_call["arguments"]),
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=2)
|
||||
|
||||
# option (1): from HF README:
|
||||
# <|im_start|>user
|
||||
@@ -156,7 +157,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
if message["role"] == "user":
|
||||
if self.simplify_json_content:
|
||||
try:
|
||||
content_json = (json.loads(message["content"], strict=JSON_LOADS_STRICT),)
|
||||
content_json = (json_loads(message["content"]),)
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {content_simple}"
|
||||
@@ -241,7 +242,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
"content": None,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import json
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.errors import LLMJSONParsingError
|
||||
from memgpt.local_llm.json_parser import clean_json
|
||||
from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import (
|
||||
LLMChatCompletionWrapper,
|
||||
)
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
PREFIX_HINT = """# Reminders:
|
||||
# Important information about yourself and the user is stored in (limited) core memory
|
||||
@@ -137,10 +137,10 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
"function": function_call["name"],
|
||||
"params": {
|
||||
"inner_thoughts": inner_thoughts,
|
||||
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
**json_loads(function_call["arguments"]),
|
||||
},
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=self.json_indent, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=self.json_indent)
|
||||
|
||||
# NOTE: BOS/EOS chatml tokens are NOT inserted here
|
||||
def _compile_assistant_message(self, message) -> str:
|
||||
@@ -167,18 +167,17 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
if self.simplify_json_content:
|
||||
# Make user messages not JSON but plaintext instead
|
||||
try:
|
||||
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = user_msg_json["message"]
|
||||
except:
|
||||
user_msg_str = message["content"]
|
||||
else:
|
||||
# Otherwise just dump the full json
|
||||
try:
|
||||
user_msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
user_msg_str = json.dumps(
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = json_dumps(
|
||||
user_msg_json,
|
||||
indent=self.json_indent,
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
)
|
||||
except:
|
||||
user_msg_str = message["content"]
|
||||
@@ -193,11 +192,10 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
prompt = ""
|
||||
try:
|
||||
# indent the function replies
|
||||
function_return_dict = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
function_return_str = json.dumps(
|
||||
function_return_dict = json_loads(message["content"])
|
||||
function_return_str = json_dumps(
|
||||
function_return_dict,
|
||||
indent=self.json_indent,
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
)
|
||||
except:
|
||||
function_return_str = message["content"]
|
||||
@@ -229,7 +227,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
|
||||
if self.use_system_role_in_user:
|
||||
try:
|
||||
msg_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
msg_json = json_loads(message["content"])
|
||||
if msg_json["type"] != "user_message":
|
||||
role_str = "system"
|
||||
except:
|
||||
@@ -343,7 +341,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
"content": inner_thoughts,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
||||
@@ -85,9 +86,9 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
|
||||
"""
|
||||
airo_func_call = {
|
||||
"function": function_call["name"],
|
||||
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
"params": json_loads(function_call["arguments"]),
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=2)
|
||||
|
||||
# Add a sep for the conversation
|
||||
if self.include_section_separators:
|
||||
@@ -100,7 +101,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
|
||||
if message["role"] == "user":
|
||||
if self.simplify_json_content:
|
||||
try:
|
||||
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
@@ -151,7 +152,7 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
|
||||
"content": raw_llm_output,
|
||||
# "function_call": {
|
||||
# "name": function_name,
|
||||
# "arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
# "arguments": json_dumps(function_parameters),
|
||||
# },
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from ...constants import JSON_ENSURE_ASCII
|
||||
from memgpt.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
@@ -76,9 +77,9 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
def create_function_call(function_call):
|
||||
airo_func_call = {
|
||||
"function": function_call["name"],
|
||||
"params": json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
"params": json_loads(function_call["arguments"]),
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=2)
|
||||
|
||||
for message in messages[1:]:
|
||||
assert message["role"] in ["user", "assistant", "function", "tool"], message
|
||||
@@ -86,7 +87,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
if message["role"] == "user":
|
||||
if self.simplify_json_content:
|
||||
try:
|
||||
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {content_simple}"
|
||||
@@ -171,7 +172,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
"content": None,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
@@ -239,10 +240,10 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
"function": function_call["name"],
|
||||
"params": {
|
||||
"inner_thoughts": inner_thoughts,
|
||||
**json.loads(function_call["arguments"], strict=JSON_LOADS_STRICT),
|
||||
**json_loads(function_call["arguments"]),
|
||||
},
|
||||
}
|
||||
return json.dumps(airo_func_call, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(airo_func_call, indent=2)
|
||||
|
||||
# Add a sep for the conversation
|
||||
if self.include_section_separators:
|
||||
@@ -255,7 +256,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
if message["role"] == "user":
|
||||
if self.simplify_json_content:
|
||||
try:
|
||||
content_json = json.loads(message["content"], strict=JSON_LOADS_STRICT)
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}"
|
||||
except:
|
||||
@@ -340,7 +341,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
"content": inner_thoughts,
|
||||
"function_call": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_parameters, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps(function_parameters),
|
||||
},
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, MEMGPT_DIR
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.local_llm.settings.deterministic_mirostat import (
|
||||
settings as det_miro_settings,
|
||||
)
|
||||
@@ -48,9 +48,7 @@ def get_completions_settings(defaults="simple") -> dict:
|
||||
with open(settings_file, "r", encoding="utf-8") as file:
|
||||
user_settings = json.load(file)
|
||||
if len(user_settings) > 0:
|
||||
printd(
|
||||
f"Updating base settings with the following user settings:\n{json.dumps(user_settings,indent=2, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
)
|
||||
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings,indent=2)}")
|
||||
settings.update(user_settings)
|
||||
else:
|
||||
printd(f"'{settings_file}' was empty, ignoring...")
|
||||
|
||||
@@ -19,12 +19,7 @@ from memgpt.cli.cli import delete_agent, open_folder, quickstart, run, server, v
|
||||
from memgpt.cli.cli_config import add, add_tool, configure, delete, list, list_tools
|
||||
from memgpt.cli.cli_load import app as load_app
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import (
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
JSON_ENSURE_ASCII,
|
||||
JSON_LOADS_STRICT,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
)
|
||||
from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.schemas.enums import OptionState
|
||||
|
||||
@@ -275,14 +270,14 @@ def run_agent_loop(
|
||||
if args_string is None:
|
||||
print("Assistant missing send_message function arguments")
|
||||
break # cancel op
|
||||
args_json = json.loads(args_string, strict=JSON_LOADS_STRICT)
|
||||
args_json = json_loads(args_string)
|
||||
if "message" not in args_json:
|
||||
print("Assistant missing send_message message argument")
|
||||
break # cancel op
|
||||
|
||||
# Once we found our target, rewrite it
|
||||
args_json["message"] = text
|
||||
new_args_string = json.dumps(args_json, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
new_args_string = json_dumps(args_json)
|
||||
message_obj.tool_calls[0].function["arguments"] = new_args_string
|
||||
|
||||
# To persist to the database, all we need to do is "re-insert" into recall memory
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
# https://github.com/openai/openai-python/blob/v0.27.4/openai/openai_object.py
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
|
||||
# import openai
|
||||
|
||||
|
||||
# from openai import api_requestor, util
|
||||
# from openai.openai_response import OpenAIResponse
|
||||
# from openai.util import ApiType
|
||||
from memgpt.utils import json_dumps
|
||||
|
||||
api_requestor = None
|
||||
api_resources = None
|
||||
@@ -342,7 +334,7 @@ class OpenAIObject(dict):
|
||||
|
||||
def __str__(self):
|
||||
obj = self.to_dict_recursive()
|
||||
return json.dumps(obj, sort_keys=True, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(obj, sort_keys=True, indent=2)
|
||||
|
||||
def to_dict(self):
|
||||
return dict(self)
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
from ..constants import FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT, MAX_PAUSE_HEARTBEATS
|
||||
from ..constants import MAX_PAUSE_HEARTBEATS
|
||||
|
||||
request_heartbeat = {
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
}
|
||||
|
||||
# FUNCTIONS_PROMPT_MULTISTEP_NO_HEARTBEATS = FUNCTIONS_PROMPT_MULTISTEP[:-1]
|
||||
FUNCTIONS_CHAINING = {
|
||||
@@ -42,12 +49,8 @@ FUNCTIONS_CHAINING = {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Message to send ChatGPT. Phrase your message as a full English sentence.",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}
|
||||
}.update(request_heartbeat),
|
||||
"required": ["message", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -65,11 +68,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "string",
|
||||
"description": "Content to write to the memory. All unicode (including emojis) are supported.",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["name", "content", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -91,11 +90,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "string",
|
||||
"description": "Content to write to the memory. All unicode (including emojis) are supported.",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["name", "old_content", "new_content", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -113,11 +108,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "integer",
|
||||
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["query", "page", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -135,11 +126,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "integer",
|
||||
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["query", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -161,11 +148,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "integer",
|
||||
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["start_date", "end_date", "page", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -187,11 +170,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "integer",
|
||||
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["start_date", "end_date", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -205,11 +184,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "string",
|
||||
"description": "Content to write to the memory. All unicode (including emojis) are supported.",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["content", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -227,11 +202,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "integer",
|
||||
"description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["query", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -253,11 +224,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "integer",
|
||||
"description": "How many lines to read (defaults to 1).",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["filename", "line_start", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -275,11 +242,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "string",
|
||||
"description": "Content to append to the file.",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["filename", "content", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
@@ -301,11 +264,7 @@ FUNCTIONS_CHAINING = {
|
||||
"type": "string",
|
||||
"description": "A JSON string representing the request payload.",
|
||||
},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT,
|
||||
},
|
||||
},
|
||||
}.update(request_heartbeat),
|
||||
"required": ["method", "url", "request_heartbeat"],
|
||||
},
|
||||
},
|
||||
|
||||
@@ -6,13 +6,13 @@ from typing import List, Optional, Union
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN
|
||||
from memgpt.constants import TOOL_CALL_ID_MAX_LEN
|
||||
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from memgpt.schemas.enums import MessageRole
|
||||
from memgpt.schemas.memgpt_base import MemGPTBase
|
||||
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
|
||||
from memgpt.schemas.openai.chat_completions import ToolCall
|
||||
from memgpt.utils import get_utc_time, is_utc_datetime
|
||||
from memgpt.utils import get_utc_time, is_utc_datetime, json_dumps
|
||||
|
||||
|
||||
def add_inner_thoughts_to_tool_call(
|
||||
@@ -29,7 +29,7 @@ def add_inner_thoughts_to_tool_call(
|
||||
func_args[inner_thoughts_key] = inner_thoughts
|
||||
# create the updated tool call (as a string)
|
||||
updated_tool_call = copy.deepcopy(tool_call)
|
||||
updated_tool_call.function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
updated_tool_call.function.arguments = json_dumps(func_args)
|
||||
return updated_tool_call
|
||||
except json.JSONDecodeError as e:
|
||||
# TODO: change to logging
|
||||
@@ -517,7 +517,7 @@ class Message(BaseMessage):
|
||||
cohere_message = []
|
||||
for tc in self.tool_calls:
|
||||
# TODO better way to pack?
|
||||
function_call_text = json.dumps(tc.to_dict(), ensure_ascii=JSON_ENSURE_ASCII)
|
||||
function_call_text = json_dumps(tc.to_dict())
|
||||
cohere_message.append(
|
||||
{
|
||||
"role": function_call_role,
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import AsyncGenerator, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.orm.user import User
|
||||
from memgpt.orm.utilities import get_db_session
|
||||
from memgpt.server.rest_api.interface import StreamingServerInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.utils import json_dumps
|
||||
|
||||
SSE_PREFIX = "data: "
|
||||
SSE_SUFFIX = "\n\n"
|
||||
@@ -16,8 +20,28 @@ SSE_ARTIFICIAL_DELAY = 0.1
|
||||
def sse_formatter(data: Union[dict, str]) -> str:
|
||||
"""Prefix with 'data: ', and always include double newlines"""
|
||||
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
|
||||
data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data
|
||||
return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
|
||||
data_str = json_dumps(data) if isinstance(data, dict) else data
|
||||
return f"data: {data_str}\n\n"
|
||||
|
||||
|
||||
async def sse_generator(generator: Generator[dict, None, None]) -> Generator[str, None, None]:
|
||||
"""Generator that returns 'data: dict' formatted items, e.g.:
|
||||
|
||||
data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
"""
|
||||
try:
|
||||
for msg in generator:
|
||||
yield sse_formatter(msg)
|
||||
if SSE_ARTIFICIAL_DELAY:
|
||||
await asyncio.sleep(SSE_ARTIFICIAL_DELAY) # Sleep to prevent a tight loop, adjust time as needed
|
||||
except Exception as e:
|
||||
yield sse_formatter({"error": f"{str(e)}"})
|
||||
yield sse_formatter(SSE_FINISH_MSG) # Signal that the stream is complete
|
||||
|
||||
|
||||
async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
|
||||
|
||||
@@ -20,7 +20,6 @@ from memgpt.agent import Agent, save_agent
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.cli.cli_config import get_model_options
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_sources.connectors import DataConnector, load_data
|
||||
|
||||
@@ -66,7 +65,7 @@ from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from memgpt.schemas.usage import MemGPTUsageStatistics
|
||||
from memgpt.schemas.user import User, UserCreate
|
||||
from memgpt.utils import create_random_username
|
||||
from memgpt.utils import create_random_username, json_dumps, json_loads
|
||||
|
||||
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
|
||||
@@ -557,11 +556,9 @@ class SyncServer(LockingServer):
|
||||
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
|
||||
if memgpt_agent.messages[x].get("role") == "assistant":
|
||||
text = command[len("rewrite ") :].strip()
|
||||
args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"), strict=JSON_LOADS_STRICT)
|
||||
args = json_loads(memgpt_agent.messages[x].get("function_call").get("arguments"))
|
||||
args["message"] = text
|
||||
memgpt_agent.messages[x].get("function_call").update(
|
||||
{"arguments": json.dumps(args, ensure_ascii=JSON_ENSURE_ASCII)}
|
||||
)
|
||||
memgpt_agent.messages[x].get("function_call").update({"arguments": json_dumps(args)})
|
||||
break
|
||||
|
||||
# No skip options
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import websockets
|
||||
|
||||
import memgpt.server.ws_api.protocol as protocol
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
from memgpt.server.constants import WS_CLIENT_TIMEOUT, WS_DEFAULT_PORT
|
||||
from memgpt.server.utils import condition_to_stop_receiving, print_server_response
|
||||
|
||||
@@ -27,12 +26,12 @@ async def send_message_and_print_replies(websocket, user_message, agent_id):
|
||||
# Wait for messages in a loop, since the server may send a few
|
||||
while True:
|
||||
response = await asyncio.wait_for(websocket.recv(), WS_CLIENT_TIMEOUT)
|
||||
response = json.loads(response, strict=JSON_LOADS_STRICT)
|
||||
response = json_loads(response)
|
||||
|
||||
if CLEAN_RESPONSES:
|
||||
print_server_response(response)
|
||||
else:
|
||||
print(f"Server response:\n{json.dumps(response, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}")
|
||||
print(f"Server response:\n{json_dumps(response, indent=2)}")
|
||||
|
||||
# Check for a specific condition to break the loop
|
||||
if condition_to_stop_receiving(response):
|
||||
@@ -62,8 +61,8 @@ async def basic_cli_client():
|
||||
await websocket.send(protocol.client_command_create(example_config))
|
||||
# Wait for the response
|
||||
response = await websocket.recv()
|
||||
response = json.loads(response, strict=JSON_LOADS_STRICT)
|
||||
print(f"Server response:\n{json.dumps(response, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}")
|
||||
response = json_loads(response)
|
||||
print(f"Server response:\n{json_dumps(response, indent=2)}")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@@ -1,89 +1,81 @@
|
||||
import json
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.utils import json_dumps
|
||||
|
||||
# Server -> client
|
||||
|
||||
|
||||
def server_error(msg):
|
||||
"""General server error"""
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "server_error",
|
||||
"message": msg,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_command_response(status):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "command_response",
|
||||
"status": status,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_agent_response_error(msg):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "agent_response_error",
|
||||
"message": msg,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_agent_response_start():
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "agent_response_start",
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_agent_response_end():
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "agent_response_end",
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_agent_internal_monologue(msg):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "agent_response",
|
||||
"message_type": "internal_monologue",
|
||||
"message": msg,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_agent_assistant_message(msg):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "agent_response",
|
||||
"message_type": "assistant_message",
|
||||
"message": msg,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def server_agent_function_message(msg):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "agent_response",
|
||||
"message_type": "function_message",
|
||||
"message": msg,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -91,22 +83,20 @@ def server_agent_function_message(msg):
|
||||
|
||||
|
||||
def client_user_message(msg, agent_id=None):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "user_message",
|
||||
"message": msg,
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def client_command_create(config):
|
||||
return json.dumps(
|
||||
return json_dumps(
|
||||
{
|
||||
"type": "command",
|
||||
"command": "create_agent",
|
||||
"config": config,
|
||||
},
|
||||
ensure_ascii=JSON_ENSURE_ASCII,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ class WebSocketServer:
|
||||
|
||||
# Assuming the message is a JSON string
|
||||
try:
|
||||
data = json.loads(message, strict=JSON_LOADS_STRICT)
|
||||
data = json_loads(message)
|
||||
except:
|
||||
print(f"[server] bad data from client:\n{data}")
|
||||
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
|
||||
|
||||
@@ -6,10 +6,9 @@ from .constants import (
|
||||
INITIAL_BOOT_MESSAGE,
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG,
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
|
||||
JSON_ENSURE_ASCII,
|
||||
MESSAGE_SUMMARY_WARNING_STR,
|
||||
)
|
||||
from .utils import get_local_time
|
||||
from .utils import get_local_time, json_dumps
|
||||
|
||||
|
||||
def get_initial_boot_messages(version="startup"):
|
||||
@@ -98,7 +97,7 @@ def get_heartbeat(reason="Automated timer", include_location=False, location_nam
|
||||
if include_location:
|
||||
packaged_message["location"] = location_name
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def get_login_event(last_login="Never (first login)", include_location=False, location_name="San Francisco, CA, USA"):
|
||||
@@ -113,7 +112,7 @@ def get_login_event(last_login="Never (first login)", include_location=False, lo
|
||||
if include_location:
|
||||
packaged_message["location"] = location_name
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def package_user_message(
|
||||
@@ -137,7 +136,7 @@ def package_user_message(
|
||||
if name:
|
||||
packaged_message["name"] = name
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def package_function_response(was_success, response_string, timestamp=None):
|
||||
@@ -148,7 +147,7 @@ def package_function_response(was_success, response_string, timestamp=None):
|
||||
"time": formatted_time,
|
||||
}
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def package_system_message(system_message, message_type="system_alert", time=None):
|
||||
@@ -175,7 +174,7 @@ def package_summarize_message(summary, summary_length, hidden_message_count, tot
|
||||
"time": formatted_time,
|
||||
}
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def package_summarize_message_no_summary(hidden_message_count, timestamp=None, message=None):
|
||||
@@ -194,7 +193,7 @@ def package_summarize_message_no_summary(hidden_message_count, timestamp=None, m
|
||||
"time": formatted_time,
|
||||
}
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
|
||||
def get_token_limit_warning():
|
||||
@@ -205,4 +204,4 @@ def get_token_limit_warning():
|
||||
"time": formatted_time,
|
||||
}
|
||||
|
||||
return json.dumps(packaged_message, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
return json_dumps(packaged_message)
|
||||
|
||||
@@ -28,12 +28,9 @@ from memgpt.constants import (
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
FUNCTION_RETURN_CHAR_LIMIT,
|
||||
JSON_ENSURE_ASCII,
|
||||
JSON_LOADS_STRICT,
|
||||
MEMGPT_DIR,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
)
|
||||
from memgpt.openai_backcompat.openai_object import OpenAIObject
|
||||
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
|
||||
DEBUG = False
|
||||
@@ -782,6 +779,8 @@ def open_folder_in_explorer(folder_path):
|
||||
class OpenAIBackcompatUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
if module == "openai.openai_object":
|
||||
from memgpt.openai_backcompat.openai_object import OpenAIObject
|
||||
|
||||
return OpenAIObject
|
||||
return super().find_class(module, name)
|
||||
|
||||
@@ -873,7 +872,7 @@ def parse_json(string) -> dict:
|
||||
"""Parse JSON string into JSON with both json and demjson"""
|
||||
result = None
|
||||
try:
|
||||
result = json.loads(string, strict=JSON_LOADS_STRICT)
|
||||
result = json_loads(string)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Error parsing json with json package: {e}")
|
||||
@@ -906,7 +905,7 @@ def validate_function_response(function_response_string: any, strict: bool = Fal
|
||||
# Allow dict through since it will be cast to json.dumps()
|
||||
try:
|
||||
# TODO find a better way to do this that won't result in double escapes
|
||||
function_response_string = json.dumps(function_response_string, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
function_response_string = json_dumps(function_response_string)
|
||||
except:
|
||||
raise ValueError(function_response_string)
|
||||
|
||||
@@ -1020,8 +1019,8 @@ def get_human_text(name: str):
|
||||
|
||||
def get_schema_diff(schema_a, schema_b):
|
||||
# Assuming f_schema and linked_function['json_schema'] are your JSON schemas
|
||||
f_schema_json = json.dumps(schema_a, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
linked_function_json = json.dumps(schema_b, indent=2, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
f_schema_json = json_dumps(schema_a)
|
||||
linked_function_json = json_dumps(schema_b)
|
||||
|
||||
# Compute the difference using difflib
|
||||
difference = list(difflib.ndiff(f_schema_json.splitlines(keepends=True), linked_function_json.splitlines(keepends=True)))
|
||||
@@ -1056,3 +1055,11 @@ def create_uuid_from_string(val: str):
|
||||
"""
|
||||
hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest()
|
||||
return uuid.UUID(hex=hex_string)
|
||||
|
||||
|
||||
def json_dumps(data, indent=2):
|
||||
return json.dumps(data, indent=indent, ensure_ascii=False)
|
||||
|
||||
|
||||
def json_loads(data):
|
||||
return json.loads(data, strict=False)
|
||||
|
||||
@@ -34,7 +34,6 @@ from tqdm import tqdm
|
||||
from memgpt import MemGPT, utils
|
||||
from memgpt.cli.cli_config import delete
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
|
||||
# TODO: update personas
|
||||
NESTED_PERSONA = "You are MemGPT DOC-QA bot. Your job is to answer questions about documents that are stored in your archival memory. The answer to the users question will ALWAYS be in your archival memory, so remember to keep searching if you can't find the answer. DO NOT STOP SEARCHING UNTIL YOU VERIFY THAT THE VALUE IS NOT A KEY. Do not stop making nested lookups until this condition is met." # TODO decide on a good persona/human
|
||||
@@ -71,7 +70,7 @@ def archival_memory_text_search(self, query: str, page: Optional[int] = 0) -> Op
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"memory: {d.text}" for d in results]
|
||||
results_str = f"{results_pref} {json.dumps(results_formatted, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
results_str = f"{results_pref} {utils.json_dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
|
||||
|
||||
121
tests/test_agent_function_update.py
Normal file
121
tests/test_agent_function_update.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import inspect
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from memgpt import constants, create_client
|
||||
from memgpt.functions.functions import USER_FUNCTIONS_DIR
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.settings import settings
|
||||
from memgpt.utils import assistant_function_to_tool, json_dumps, json_loads
|
||||
from tests.mock_factory.models import MockUserFactory
|
||||
from tests.utils import create_config, wipe_config
|
||||
|
||||
|
||||
def hello_world(self) -> str:
|
||||
"""Test function for agent to gain access to
|
||||
|
||||
Returns:
|
||||
str: A message for the world
|
||||
"""
|
||||
return "hello, world!"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent():
|
||||
"""Create a test agent that we can call functions on"""
|
||||
wipe_config()
|
||||
global client
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
else:
|
||||
create_config("memgpt_hosted")
|
||||
|
||||
# create memgpt client
|
||||
client = create_client()
|
||||
|
||||
# ensure user exists
|
||||
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
|
||||
if not client.server.get_user(user_id=user_id):
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
agent_state = client.create_agent(
|
||||
preset=settings.preset,
|
||||
)
|
||||
|
||||
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def hello_world_function():
|
||||
with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f:
|
||||
f.write(inspect.getsource(hello_world))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ai_function_call():
|
||||
return Message(
|
||||
**assistant_function_to_tool(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I will now call hello world",
|
||||
"function_call": {
|
||||
"name": "hello_world",
|
||||
"arguments": json_dumps({}),
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_add_function_happy(agent, hello_world_function, ai_function_call):
|
||||
agent.add_function("hello_world")
|
||||
|
||||
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" in agent.functions_python.keys()
|
||||
|
||||
msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)
|
||||
content = json_loads(msgs[-1].to_openai_dict()["content"])
|
||||
assert content["message"] == "hello, world!"
|
||||
assert content["status"] == "OK"
|
||||
assert not function_failed
|
||||
|
||||
|
||||
def test_add_function_already_loaded(agent, hello_world_function):
|
||||
agent.add_function("hello_world")
|
||||
# no exception for duplicate loading
|
||||
agent.add_function("hello_world")
|
||||
|
||||
|
||||
def test_add_function_not_exist(agent):
|
||||
# pytest assert exception
|
||||
with pytest.raises(ValueError):
|
||||
agent.add_function("non_existent")
|
||||
|
||||
|
||||
def test_remove_function_happy(agent, hello_world_function):
|
||||
agent.add_function("hello_world")
|
||||
|
||||
# ensure function is loaded
|
||||
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" in agent.functions_python.keys()
|
||||
|
||||
agent.remove_function("hello_world")
|
||||
|
||||
assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" not in agent.functions_python.keys()
|
||||
|
||||
|
||||
def test_remove_function_not_exist(agent):
|
||||
# do not raise error
|
||||
agent.remove_function("non_existent")
|
||||
|
||||
|
||||
def test_remove_base_function_fails(agent):
|
||||
with pytest.raises(ValueError):
|
||||
agent.remove_function("send_message")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-vv", os.path.abspath(__file__)])
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
|
||||
import memgpt.system as system
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.local_llm.function_parser import patch_function
|
||||
from memgpt.utils import json_dumps
|
||||
|
||||
EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = {
|
||||
"message_history": [
|
||||
@@ -32,7 +32,7 @@ EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = {
|
||||
"content": "I'll append to memory.",
|
||||
"function_call": {
|
||||
"name": "core_memory_append",
|
||||
"arguments": json.dumps({"content": "new_stuff"}, ensure_ascii=JSON_ENSURE_ASCII),
|
||||
"arguments": json_dumps({"content": "new_stuff"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from mempgt.utils import json_loads
|
||||
|
||||
import memgpt.local_llm.json_parser as json_parser
|
||||
from memgpt.constants import JSON_LOADS_STRICT
|
||||
from memgpt.constants import json
|
||||
|
||||
EXAMPLE_ESCAPED_UNDERSCORES = """{
|
||||
"function":"send\_message",
|
||||
@@ -90,7 +90,7 @@ def test_json_parsers():
|
||||
|
||||
for string in test_strings:
|
||||
try:
|
||||
json.loads(string, strict=JSON_LOADS_STRICT)
|
||||
json_loads(string)
|
||||
assert False, f"Test JSON string should have failed basic JSON parsing:\n{string}"
|
||||
except:
|
||||
print("String failed (expectedly)")
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT
|
||||
from memgpt.server.ws_api.server import WebSocketServer
|
||||
from memgpt.utils import json_dumps
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -36,7 +35,7 @@ async def test_websocket_server():
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Initialize the server with a test config
|
||||
print("Sending config to server...")
|
||||
await websocket.send(json.dumps({"type": "initialize", "config": test_config}, ensure_ascii=JSON_ENSURE_ASCII))
|
||||
await websocket.send(json_dumps({"type": "initialize", "config": test_config}))
|
||||
# Wait for the response
|
||||
response = await websocket.recv()
|
||||
print(f"Response from the agent: {response}")
|
||||
@@ -45,7 +44,7 @@ async def test_websocket_server():
|
||||
|
||||
# Send a message to the agent
|
||||
print("Sending message to server...")
|
||||
await websocket.send(json.dumps({"type": "message", "content": "Hello, Agent!"}, ensure_ascii=JSON_ENSURE_ASCII))
|
||||
await websocket.send(json_dumps({"type": "message", "content": "Hello, Agent!"}))
|
||||
# Wait for the response
|
||||
# NOTE: we should be waiting for multiple responses
|
||||
response = await websocket.recv()
|
||||
|
||||
Reference in New Issue
Block a user