diff --git a/memgpt/agent.py b/memgpt/agent.py index 26a903bf..d24181a7 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -403,6 +403,7 @@ class Agent(object): message_sequence: List[Message], function_call: str = "auto", first_message: bool = False, # hint + stream: bool = False, # TODO move to config? ) -> chat_completion_response.ChatCompletionResponse: """Get response from LLM API""" try: @@ -414,6 +415,9 @@ class Agent(object): function_call=function_call, # hint first_message=first_message, + # streaming + stream=stream, + stream_inferface=self.interface, ) # special case for 'length' if response.choices[0].finish_reason == "length": @@ -628,6 +632,7 @@ class Agent(object): skip_verify: bool = False, return_dicts: bool = True, # if True, return dicts, if False, return Message objects recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field + stream: bool = False, # TODO move to config? ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]: """Top-level event message handler for the MemGPT agent""" @@ -710,6 +715,7 @@ class Agent(object): response = self._get_ai_reply( message_sequence=input_message_sequence, first_message=True, # passed through to the prompt formatter + stream=stream, ) if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break @@ -721,6 +727,7 @@ class Agent(object): else: response = self._get_ai_reply( message_sequence=input_message_sequence, + stream=stream, ) # Step 2: check if LLM wanted to call a function diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index fe1361c5..02260050 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -13,7 +13,9 @@ import typer import questionary from memgpt.log import logger -from memgpt.interface import CLIInterface as interface # for printing to terminal + +# from memgpt.interface import CLIInterface as interface # for printing to terminal +from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal from memgpt.cli.cli_config import configure import memgpt.presets.presets as presets import memgpt.utils as utils @@ -445,6 +447,8 @@ def run( debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False, no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False, yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False, + # streaming + stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False, ): """Start chatting with an MemGPT agent @@ -710,7 +714,9 @@ def run( from memgpt.main import run_agent_loop print() # extra space - run_agent_loop(memgpt_agent, config, first, ms, no_verify) # TODO: add back no_verify + run_agent_loop( + memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream + ) # TODO: add back no_verify def delete_agent( diff --git a/memgpt/config.py b/memgpt/config.py index 1aff760d..1df21fff 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -90,6 +90,7 @@ class MemGPTConfig: def load(cls) -> "MemGPTConfig": # avoid circular import from memgpt.migrate import config_is_compatible, VERSION_CUTOFF + from memgpt.utils import printd if not config_is_compatible(allow_empty=True): error_message = " ".join( @@ -110,7 +111,7 @@ class MemGPTConfig: # insure all configuration directories exist cls.create_config_dir() - print(f"Loading config from {config_path}") + printd(f"Loading config from {config_path}") if os.path.exists(config_path): # read existing config config.read(config_path) diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index 12444a07..fbbaa2f6 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -3,17 +3,18 @@ import time import requests import os import time -from typing import List +from typing import List, Optional, Union from memgpt.credentials import MemGPTCredentials from memgpt.local_llm.chat_completion_proxy import get_chat_completion from memgpt.constants import CLI_WARNING_PREFIX from memgpt.models.chat_completion_response import ChatCompletionResponse from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype +from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface from memgpt.data_types import AgentState, Message -from memgpt.llm_api.openai import openai_chat_completions_request +from memgpt.llm_api.openai import openai_chat_completions_request, openai_chat_completions_process_stream from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE from memgpt.llm_api.google_ai import ( google_ai_chat_completions_request, @@ -126,14 +127,17 @@ def retry_with_exponential_backoff( def create( agent_state: AgentState, messages: List[Message], - functions=None, - functions_python=None, - function_call="auto", + functions: list = None, + functions_python: list = None, + function_call: str = "auto", # hint - first_message=False, + first_message: bool = False, # use tool naming? # if false, will use deprecated 'functions' style - use_tool_naming=True, + use_tool_naming: bool = True, + # streaming? + stream: bool = False, + stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None, ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" from memgpt.utils import printd @@ -169,11 +173,25 @@ def create( function_call=function_call, user=str(agent_state.user_id), ) - return openai_chat_completions_request( - url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions - api_key=credentials.openai_key, - data=data, - ) + + if stream: + data.stream = True + assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance( + stream_inferface, AgentRefreshStreamingInterface + ), type(stream_inferface) + return openai_chat_completions_process_stream( + url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + api_key=credentials.openai_key, + chat_completion_request=data, + stream_inferface=stream_inferface, + ) + else: + data.stream = False + return openai_chat_completions_request( + url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + api_key=credentials.openai_key, + chat_completion_request=data, + ) # azure elif agent_state.llm_config.model_endpoint_type == "azure": diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index f24b61b8..8b4d3474 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -1,11 +1,28 @@ import requests -import time -from typing import Union, Optional +import json +import httpx +from httpx_sse import connect_sse +from httpx_sse._exceptions import SSEError +from typing import Union, Optional, Generator -from memgpt.models.chat_completion_response import ChatCompletionResponse +from memgpt.models.chat_completion_response import ( + ChatCompletionResponse, + Choice, + Message, + ToolCall, + FunctionCall, + UsageStatistics, + ChatCompletionChunkResponse, +) from memgpt.models.chat_completion_request import ChatCompletionRequest from memgpt.models.embedding_response import EmbeddingResponse -from memgpt.utils import smart_urljoin +from memgpt.utils import smart_urljoin, get_utc_time +from memgpt.local_llm.utils import num_tokens_from_messages, num_tokens_from_functions +from memgpt.interface import AgentInterface +from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface + + +OPENAI_SSE_DONE = "[DONE]" def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict: @@ -58,13 +75,233 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional raise e -def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse: - """https://platform.openai.com/docs/guides/text-generation?lang=curl""" +def openai_chat_completions_process_stream( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, + stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None, +) -> ChatCompletionResponse: + """Process a streaming completion response, and return a ChatCompletionRequest at the end. + + To "stream" the response in MemGPT, we want to call a streaming-compatible interface function + on the chunks received from the OpenAI-compatible server POST SSE response. + """ + assert chat_completion_request.stream == True + + # Count the prompt tokens + # TODO move to post-request? + chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages] + # print(chat_history) + + prompt_tokens = num_tokens_from_messages( + messages=chat_history, + model=chat_completion_request.model, + ) + # We also need to add the cost of including the functions list to the input prompt + if chat_completion_request.tools is not None: + assert chat_completion_request.functions is None + prompt_tokens += num_tokens_from_functions( + functions=[t.function.model_dump() for t in chat_completion_request.tools], + model=chat_completion_request.model, + ) + elif chat_completion_request.functions is not None: + assert chat_completion_request.tools is None + prompt_tokens += num_tokens_from_functions( + functions=[f.model_dump() for f in chat_completion_request.functions], + model=chat_completion_request.model, + ) + + TEMP_STREAM_RESPONSE_ID = "temp_id" + TEMP_STREAM_FINISH_REASON = "temp_null" + TEMP_STREAM_TOOL_CALL_ID = "temp_id" + chat_completion_response = ChatCompletionResponse( + id=TEMP_STREAM_RESPONSE_ID, + choices=[], + created=get_utc_time(), + model=chat_completion_request.model, + usage=UsageStatistics( + completion_tokens=0, + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens, + ), + ) + + if stream_inferface: + stream_inferface.stream_start() + + n_chunks = 0 # approx == n_tokens + try: + for chunk_idx, chat_completion_chunk in enumerate( + openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request) + ): + assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) + # print(chat_completion_chunk) + + if stream_inferface: + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.process_chunk(chat_completion_chunk) + elif isinstance(stream_inferface, AgentRefreshStreamingInterface): + stream_inferface.process_refresh(chat_completion_response) + else: + raise TypeError(stream_inferface) + + if chunk_idx == 0: + # initialize the choice objects which we will increment with the deltas + num_choices = len(chat_completion_chunk.choices) + assert num_choices > 0 + chat_completion_response.choices = [ + Choice( + finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten + index=i, + message=Message( + role="assistant", + ), + ) + for i in range(len(chat_completion_chunk.choices)) + ] + + # add the choice delta + assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk + for chunk_choice in chat_completion_chunk.choices: + if chunk_choice.finish_reason is not None: + chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason + + if chunk_choice.logprobs is not None: + chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs + + accum_message = chat_completion_response.choices[chunk_choice.index].message + message_delta = chunk_choice.delta + + if message_delta.content is not None: + content_delta = message_delta.content + if accum_message.content is None: + accum_message.content = content_delta + else: + accum_message.content += content_delta + + if message_delta.tool_calls is not None: + tool_calls_delta = message_delta.tool_calls + + # If this is the first tool call showing up in a chunk, initialize the list with it + if accum_message.tool_calls is None: + accum_message.tool_calls = [ + ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) + for _ in range(len(tool_calls_delta)) + ] + + for tool_call_delta in tool_calls_delta: + if tool_call_delta.id is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? + accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id + if tool_call_delta.function is not None: + if tool_call_delta.function.name is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? + accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name + if tool_call_delta.function.arguments is not None: + accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments + + if message_delta.function_call is not None: + raise NotImplementedError(f"Old function_call style not support with stream=True") + + # overwrite response fields based on latest chunk + chat_completion_response.id = chat_completion_chunk.id + chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint + chat_completion_response.created = chat_completion_chunk.created + chat_completion_response.model = chat_completion_chunk.model + + # increment chunk counter + n_chunks += 1 + + except Exception as e: + if stream_inferface: + stream_inferface.stream_end() + print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}") + raise e + finally: + if stream_inferface: + stream_inferface.stream_end() + + # make sure we didn't leave temp stuff in + assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) + assert all( + [ + all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True + for c in chat_completion_response.choices + ] + ) + assert chat_completion_response.id != TEMP_STREAM_RESPONSE_ID + + # compute token usage before returning + # TODO try actually computing the #tokens instead of assuming the chunks is the same + chat_completion_response.usage.completion_tokens = n_chunks + chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks + + # printd(chat_completion_response) + return chat_completion_response + + +def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]: + + with httpx.Client() as client: + with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: + try: + for sse in event_source.iter_sse(): + # printd(sse.event, sse.data, sse.id, sse.retry) + if sse.data == OPENAI_SSE_DONE: + # print("finished") + break + else: + chunk_data = json.loads(sse.data) + # print("chunk_data::", chunk_data) + chunk_object = ChatCompletionChunkResponse(**chunk_data) + # print("chunk_object::", chunk_object) + # id=chunk_data["id"], + # choices=[ChunkChoice], + # model=chunk_data["model"], + # system_fingerprint=chunk_data["system_fingerprint"] + # ) + yield chunk_object + + except SSEError as e: + if "application/json" in str(e): # Check if the error is because of JSON response + response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response + if response.headers["Content-Type"].startswith("application/json"): + error_details = response.json() # Parse the JSON to get the error message + print("Error:", error_details) + print("Reqeust:", vars(response.request)) + else: + print("Failed to retrieve JSON error message.") + else: + print("SSEError not related to 'application/json' content type.") + + # Optionally re-raise the exception if you need to propagate it + raise e + + except Exception as e: + if event_source.response.request is not None: + print("HTTP Request:", vars(event_source.response.request)) + if event_source.response is not None: + print("HTTP Status:", event_source.response.status_code) + print("HTTP Headers:", event_source.response.headers) + # print("HTTP Body:", event_source.response.text) + print("Exception message:", str(e)) + raise e + + +def openai_chat_completions_request_stream( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, +) -> Generator[ChatCompletionChunkResponse, None, None]: from memgpt.utils import printd url = smart_urljoin(url, "chat/completions") headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - data = data.model_dump(exclude_none=True) + data = chat_completion_request.model_dump(exclude_none=True) + + printd("Request:\n", json.dumps(data, indent=2)) # If functions == None, strip from the payload if "functions" in data and data["functions"] is None: @@ -77,21 +314,59 @@ def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletion printd(f"Sending request to {url}") try: - # Example code to trigger a rate limit response: - # mock_response = requests.Response() - # mock_response.status_code = 429 - # http_error = requests.exceptions.HTTPError("429 Client Error: Too Many Requests") - # http_error.response = mock_response - # raise http_error + return _sse_post(url=url, data=data, headers=headers) + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}, payload={data}") + raise http_err + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e - # Example code to trigger a context overflow response (for an 8k model) - # data["messages"][-1]["content"] = " ".join(["repeat after me this is not a fluke"] * 1000) +def openai_chat_completions_request( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, +) -> ChatCompletionResponse: + """Send a ChatCompletion request to an OpenAI-compatible server + + If request.stream == True, will yield ChatCompletionChunkResponses + If request.stream == False, will return a ChatCompletionResponse + + https://platform.openai.com/docs/guides/text-generation?lang=curl + """ + from memgpt.utils import printd + + url = smart_urljoin(url, "chat/completions") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + data = chat_completion_request.model_dump(exclude_none=True) + + printd("Request:\n", json.dumps(data, indent=2)) + + # If functions == None, strip from the payload + if "functions" in data and data["functions"] is None: + data.pop("functions") + data.pop("function_call", None) # extra safe, should exist always (default="auto") + + if "tools" in data and data["tools"] is None: + data.pop("tools") + data.pop("tool_choice", None) # extra safe, should exist always (default="auto") + + printd(f"Sending request to {url}") + try: response = requests.post(url, headers=headers, json=data) printd(f"response = {response}") response.raise_for_status() # Raises HTTPError for 4XX/5XX status + response = response.json() # convert to dict from string printd(f"response.json = {response}") + response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default return response except requests.exceptions.HTTPError as http_err: diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py index 8306e91b..15c85c28 100644 --- a/memgpt/local_llm/utils.py +++ b/memgpt/local_llm/utils.py @@ -1,6 +1,7 @@ import os import requests import tiktoken +from typing import List import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin @@ -74,6 +75,148 @@ def count_tokens(s: str, model: str = "gpt-4") -> int: return len(encoding.encode(s)) +def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): + """Return the number of tokens used by a list of functions. + + Copied from https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11 + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for function in functions: + function_tokens = len(encoding.encode(function["name"])) + function_tokens += len(encoding.encode(function["description"])) + + if "parameters" in function: + parameters = function["parameters"] + if "properties" in parameters: + for propertiesKey in parameters["properties"]: + function_tokens += len(encoding.encode(propertiesKey)) + v = parameters["properties"][propertiesKey] + for field in v: + if field == "type": + function_tokens += 2 + function_tokens += len(encoding.encode(v["type"])) + elif field == "description": + function_tokens += 2 + function_tokens += len(encoding.encode(v["description"])) + elif field == "enum": + function_tokens -= 3 + for o in v["enum"]: + function_tokens += 3 + function_tokens += len(encoding.encode(o)) + else: + print(f"Warning: not supported field {field}") + function_tokens += 11 + + num_tokens += function_tokens + + num_tokens += 12 + return num_tokens + + +def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"): + """Based on above code (num_tokens_from_functions). + + Example to encode: + [{ + 'id': '8b6707cf-2352-4804-93db-0423f', + 'type': 'function', + 'function': { + 'name': 'send_message', + 'arguments': '{\n "message": "More human than human is our motto."\n}' + } + }] + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for tool_call in tool_calls: + function_tokens = len(encoding.encode(tool_call["id"])) + function_tokens += 2 + len(encoding.encode(tool_call["type"])) + function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"])) + function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"])) + + num_tokens += function_tokens + + # TODO adjust? + num_tokens += 12 + return num_tokens + + +def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int: + """Return the number of tokens used by a list of messages. + + From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + + For counting tokens in function calling RESPONSES, see: + https://hmarr.com/blog/counting-openai-tokens/, https://github.com/hmarr/openai-chat-tokens + + For counting tokens in function calling REQUESTS, see: + https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11 + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + if model in { + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = -1 # if there's a name, the role is omitted + elif "gpt-3.5-turbo" in model: + # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4" in model: + # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + return num_tokens_from_messages(messages, model="gpt-4-0613") + else: + raise NotImplementedError( + f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + ) + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + try: + + if isinstance(value, list) and key == "tool_calls": + num_tokens += num_tokens_from_tool_calls(tool_calls=value, model=model) + # special case for tool calling (list) + # num_tokens += len(encoding.encode(value["name"])) + # num_tokens += len(encoding.encode(value["arguments"])) + + else: + num_tokens += len(encoding.encode(value)) + + if key == "name": + num_tokens += tokens_per_name + + except TypeError as e: + print(f"tiktoken encoding failed on: {value}") + raise e + + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + return num_tokens + + def get_available_wrappers() -> dict: return { "experimental-wrapper-neural-chat-grammar-noforce": configurable_wrapper.ConfigurableJSONWrapper( diff --git a/memgpt/main.py b/memgpt/main.py index feb83b6e..c6329aa5 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -10,11 +10,10 @@ import typer from rich.console import Console from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, JSON_ENSURE_ASCII, JSON_LOADS_STRICT, REQ_HEARTBEAT_MESSAGE -console = Console() - -from memgpt.agent import save_agent from memgpt.agent_store.storage import StorageConnector, TableType -from memgpt.interface import CLIInterface as interface # for printing to terminal + +# from memgpt.interface import CLIInterface as interface # for printing to terminal +from memgpt.streaming_interface import AgentRefreshStreamingInterface from memgpt.config import MemGPTConfig import memgpt.agent as agent import memgpt.system as system @@ -27,6 +26,8 @@ from memgpt.metadata import MetadataStore # import benchmark from memgpt.benchmark.benchmark import bench +# interface = interface() + app = typer.Typer(pretty_exceptions_enable=False) app.command(name="run")(run) app.command(name="version")(version) @@ -47,7 +48,7 @@ app.command(name="benchmark")(bench) app.command(name="delete-agent")(delete_agent) -def clear_line(strip_ui=False): +def clear_line(console, strip_ui=False): if strip_ui: return if os.name == "nt": # for windows @@ -57,7 +58,19 @@ def clear_line(strip_ui=False): sys.stdout.flush() -def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False): +def run_agent_loop( + memgpt_agent: agent.Agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False +): + if isinstance(memgpt_agent.interface, AgentRefreshStreamingInterface): + # memgpt_agent.interface.toggle_streaming(on=stream) + if not stream: + memgpt_agent.interface = memgpt_agent.interface.nonstreaming_interface + + if hasattr(memgpt_agent.interface, "console"): + console = memgpt_agent.interface.console + else: + console = Console() + counter = 0 user_input = None skip_next_user_input = False @@ -65,8 +78,8 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, USER_GOES_FIRST = first if not USER_GOES_FIRST: - console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]") - clear_line(strip_ui) + console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]\n") + clear_line(console, strip_ui=strip_ui) print() multiline_input = False @@ -74,12 +87,16 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, while True: if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): # Ask for user input + if not stream: + print() user_input = questionary.text( "Enter your message:", multiline=multiline_input, qmark=">", ).ask() - clear_line(strip_ui) + clear_line(console, strip_ui=strip_ui) + if not stream: + print() # Gracefully exit on Ctrl-C/D if user_input is None: @@ -157,13 +174,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, command = user_input.strip().split() amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 if amount == 0: - interface.print_messages(memgpt_agent._messages, dump=True) + memgpt_agent.interface.print_messages(memgpt_agent._messages, dump=True) else: - interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) + memgpt_agent.interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) continue elif user_input.lower() == "/dumpraw": - interface.print_messages_raw(memgpt_agent._messages) + memgpt_agent.interface.print_messages_raw(memgpt_agent._messages) continue elif user_input.lower() == "/memory": @@ -194,9 +211,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, else: print(f"Popping last {pop_amount} messages from stack") for _ in range(min(pop_amount, len(memgpt_agent.messages))): - memgpt_agent._messages.pop() - # Persist the state - save_agent(agent=memgpt_agent, ms=ms) + memgpt_agent.messages.pop() continue elif user_input.lower() == "/retry": @@ -218,13 +233,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, for x in range(len(memgpt_agent.messages) - 1, 0, -1): if memgpt_agent.messages[x].get("role") == "assistant": text = user_input[len("/rethink ") :].strip() - - # Do the /rethink-ing - message_obj = memgpt_agent._messages[x] - message_obj.text = text - - # To persist to the database, all we need to do is "re-insert" into recall memory - memgpt_agent.persistence_manager.recall_memory.storage.update(record=message_obj) + memgpt_agent.messages[x].update({"content": text}) break continue @@ -321,7 +330,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, # No skip options elif user_input.lower() == "/wipe": - memgpt_agent = agent.Agent(interface) + memgpt_agent = agent.Agent(memgpt_agent.interface) user_message = None elif user_input.lower() == "/heartbeat": @@ -354,7 +363,10 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, def process_agent_step(user_message, no_verify): new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step( - user_message, first_message=False, skip_verify=no_verify + user_message, + first_message=False, + skip_verify=no_verify, + stream=stream, ) skip_next_user_input = False @@ -376,9 +388,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) break else: - with console.status("[bold cyan]Thinking...") as status: + if stream: + # Don't display the "Thinking..." if streaming new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - break + else: + with console.status("[bold cyan]Thinking...") as status: + new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) + break except KeyboardInterrupt: print("User interrupt occurred.") retry = questionary.confirm("Retry agent.step()?").ask() diff --git a/memgpt/models/chat_completion_response.py b/memgpt/models/chat_completion_response.py index 5369aeb3..983fd629 100644 --- a/memgpt/models/chat_completion_response.py +++ b/memgpt/models/chat_completion_response.py @@ -55,6 +55,8 @@ class UsageStatistics(BaseModel): class ChatCompletionResponse(BaseModel): + """https://platform.openai.com/docs/api-reference/chat/object""" + id: str choices: List[Choice] created: datetime.datetime @@ -64,3 +66,64 @@ class ChatCompletionResponse(BaseModel): # object: str = Field(default="chat.completion") object: Literal["chat.completion"] = "chat.completion" usage: UsageStatistics + + +class FunctionCallDelta(BaseModel): + # arguments: Optional[str] = None + name: Optional[str] = None + arguments: str + # name: str + + +class ToolCallDelta(BaseModel): + index: int + id: Optional[str] = None + # "Currently, only function is supported" + type: Literal["function"] = "function" + # function: ToolCallFunction + function: Optional[FunctionCallDelta] = None + + +class MessageDelta(BaseModel): + """Partial delta stream of a Message + + Example ChunkResponse: + { + 'id': 'chatcmpl-9EOCkKdicNo1tiL1956kPvCnL2lLS', + 'object': 'chat.completion.chunk', + 'created': 1713216662, + 'model': 'gpt-4-0613', + 'system_fingerprint': None, + 'choices': [{ + 'index': 0, + 'delta': {'content': 'User'}, + 'logprobs': None, + 'finish_reason': None + }] + } + """ + + content: Optional[str] = None + tool_calls: Optional[List[ToolCallDelta]] = None + # role: Optional[str] = None + function_call: Optional[FunctionCallDelta] = None # Deprecated + + +class ChunkChoice(BaseModel): + finish_reason: Optional[str] = None # NOTE: when streaming will be null + index: int + delta: MessageDelta + logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None + + +class ChatCompletionChunkResponse(BaseModel): + """https://platform.openai.com/docs/api-reference/chat/streaming""" + + id: str + choices: List[ChunkChoice] + created: datetime.datetime + model: str + # system_fingerprint: str # docs say this is mandatory, but in reality API returns None + system_fingerprint: Optional[str] = None + # object: str = Field(default="chat.completion") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" diff --git a/memgpt/streaming_interface.py b/memgpt/streaming_interface.py new file mode 100644 index 00000000..928810da --- /dev/null +++ b/memgpt/streaming_interface.py @@ -0,0 +1,398 @@ +from abc import ABC, abstractmethod +import json +import re +import sys +from typing import List, Optional + +# from colorama import Fore, Style, init +from rich.console import Console +from rich.live import Live +from rich.markup import escape +from rich.style import Style +from rich.text import Text + +from memgpt.utils import printd +from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT +from memgpt.data_types import Message +from memgpt.models.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse +from memgpt.interface import AgentInterface, CLIInterface + +# init(autoreset=True) + +# DEBUG = True # puts full message outputs in the terminal +DEBUG = False # only dumps important messages in the terminal + +STRIP_UI = False + + +class AgentChunkStreamingInterface(ABC): + """Interfaces handle MemGPT-related events (observer pattern) + + The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata. + """ + + @abstractmethod + def user_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT receives a user message""" + raise NotImplementedError + + @abstractmethod + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT generates some internal monologue""" + raise NotImplementedError + + @abstractmethod + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT uses send_message""" + raise NotImplementedError + + @abstractmethod + def function_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT calls a function""" + raise NotImplementedError + + @abstractmethod + def process_chunk(self, chunk: ChatCompletionChunkResponse): + """Process a streaming chunk from an OpenAI-compatible server""" + raise NotImplementedError + + @abstractmethod + def stream_start(self): + """Any setup required before streaming begins""" + raise NotImplementedError + + @abstractmethod + def stream_end(self): + """Any cleanup required after streaming ends""" + raise NotImplementedError + + +class StreamingCLIInterface(AgentChunkStreamingInterface): + """Version of the CLI interface that attaches to a stream generator and prints along the way. + + When a chunk is received, we write the delta to the buffer. If the buffer type has changed, + we write out a newline + set the formatting for the new line. + + The two buffer types are: + (1) content (inner thoughts) + (2) tool_calls (function calling) + + NOTE: this assumes that the deltas received in the chunks are in-order, e.g. + that once 'content' deltas stop streaming, they won't be received again. See notes + on alternative version of the StreamingCLIInterface that does not have this same problem below: + + An alternative implementation could instead maintain the partial message state, and on each + process chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen. + """ + + # CLIInterface is static/stateless + nonstreaming_interface = CLIInterface() + + def __init__(self): + """The streaming CLI interface state for determining which buffer is currently being written to""" + + self.streaming_buffer_type = None + + def _flush(self): + pass + + def process_chunk(self, chunk: ChatCompletionChunkResponse): + assert len(chunk.choices) == 1, chunk + + message_delta = chunk.choices[0].delta + + # Starting a new buffer line + if not self.streaming_buffer_type: + assert not ( + message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls) + ), f"Error: got both content and tool_calls in message stream\n{message_delta}" + + if message_delta.content is not None: + # Write out the prefix for inner thoughts + print("Inner thoughts: ", end="", flush=True) + elif message_delta.tool_calls is not None: + assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}" + # Write out the prefix for function calling + print("Calling function: ", end="", flush=True) + + # Potentially switch/flush a buffer line + else: + pass + + # Write out the delta + if message_delta.content is not None: + if self.streaming_buffer_type and self.streaming_buffer_type != "content": + print() + self.streaming_buffer_type = "content" + + # Simple, just write out to the buffer + print(message_delta.content, end="", flush=True) + + elif message_delta.tool_calls is not None: + if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls": + print() + self.streaming_buffer_type = "tool_calls" + + assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}" + function_call = message_delta.tool_calls[0].function + + # Slightly more complex - want to write parameters in a certain way (paren-style) + # function_name(function_args) + if function_call.name: + # NOTE: need to account for closing the brace later + print(f"{function_call.name}(", end="", flush=True) + if function_call.arguments: + print(function_call.arguments, end="", flush=True) + + def stream_start(self): + # should be handled by stream_end(), but just in case + self.streaming_buffer_type = None + + def stream_end(self): + if self.streaming_buffer_type is not None: + # TODO: should have a separate self.tool_call_open_paren flag + if self.streaming_buffer_type == "tool_calls": + print(")", end="", flush=True) + + print() # newline to move the cursor + self.streaming_buffer_type = None # reset buffer tracker + + @staticmethod + def important_message(msg: str): + StreamingCLIInterface.nonstreaming_interface(msg) + + @staticmethod + def warning_message(msg: str): + StreamingCLIInterface.nonstreaming_interface(msg) + + @staticmethod + def internal_monologue(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def assistant_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def memory_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def system_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def print_messages(message_sequence: List[Message], dump=False): + StreamingCLIInterface.nonstreaming_interface(message_sequence, dump) + + @staticmethod + def print_messages_simple(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence) + + @staticmethod + def print_messages_raw(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence) + + @staticmethod + def step_yield(): + pass + + +class AgentRefreshStreamingInterface(ABC): + """Same as the ChunkStreamingInterface, but + + The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata. + """ + + @abstractmethod + def user_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT receives a user message""" + raise NotImplementedError + + @abstractmethod + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT generates some internal monologue""" + raise NotImplementedError + + @abstractmethod + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT uses send_message""" + raise NotImplementedError + + @abstractmethod + def function_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT calls a function""" + raise NotImplementedError + + @abstractmethod + def process_refresh(self, response: ChatCompletionResponse): + """Process a streaming chunk from an OpenAI-compatible server""" + raise NotImplementedError + + @abstractmethod + def stream_start(self): + """Any setup required before streaming begins""" + raise NotImplementedError + + @abstractmethod + def stream_end(self): + """Any cleanup required after streaming ends""" + raise NotImplementedError + + @abstractmethod + def toggle_streaming(self, on: bool): + """Toggle streaming on/off (off = regular CLI interface)""" + raise NotImplementedError + + +class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): + """Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step. + + We maintain the partial message state in the interface state, and on each + process chunk we: + (1) update the partial message state, + (2) refresh/rewrite the state to the screen. + """ + + nonstreaming_interface = CLIInterface + + def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True): + """Initialize the streaming CLI interface state.""" + self.console = Console() + + # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates + self.live = Live("", console=self.console, refresh_per_second=10) + # self.live.start() # Start the Live display context and keep it running + + # Use italics / emoji? + self.fancy = fancy + + self.streaming = True + self.separate_send_message = separate_send_message + self.disable_inner_mono_call = disable_inner_mono_call + + def toggle_streaming(self, on: bool): + self.streaming = on + if on: + self.separate_send_message = True + self.disable_inner_mono_call = True + else: + self.separate_send_message = False + self.disable_inner_mono_call = False + + def update_output(self, content: str): + """Update the displayed output with new content.""" + # We use the `Live` object's update mechanism to refresh content without clearing the console + if not self.fancy: + content = escape(content) + self.live.update(self.console.render_str(content), refresh=True) + + def process_refresh(self, response: ChatCompletionResponse): + """Process the response to rewrite the current output buffer.""" + if not response.choices: + self.update_output("💭 [italic]...[/italic]") + return # Early exit if there are no choices + + choice = response.choices[0] + inner_thoughts = choice.message.content if choice.message.content else "" + tool_calls = choice.message.tool_calls if choice.message.tool_calls else [] + + if self.fancy: + message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else "" + else: + message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else "" + + if tool_calls: + function_call = tool_calls[0].function + function_name = function_call.name # Function name, can be an empty string + function_args = function_call.arguments # Function arguments, can be an empty string + if message_string: + message_string += "\n" + # special case here for send_message + if self.separate_send_message and function_name == "send_message": + try: + message = json.loads(function_args)["message"] + except: + prefix = '{\n "message": "' + if len(function_args) < len(prefix): + message = "..." + elif function_args.startswith(prefix): + message = function_args[len(prefix) :] + else: + message = function_args + message_string += f"🤖 [bold yellow]{message}[/bold yellow]" + else: + message_string += f"{function_name}({function_args})" + + self.update_output(message_string) + + def stream_start(self): + if self.streaming: + print() + self.live.start() # Start the Live display context and keep it running + self.update_output("💭 [italic]...[/italic]") + + def stream_end(self): + if self.streaming: + if self.live.is_started: + self.live.stop() + print() + self.live = Live("", console=self.console, refresh_per_second=10) + + @staticmethod + def important_message(msg: str): + StreamingCLIInterface.nonstreaming_interface.important_message(msg) + + @staticmethod + def warning_message(msg: str): + StreamingCLIInterface.nonstreaming_interface.warning_message(msg) + + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + if self.disable_inner_mono_call: + return + StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj) + + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + if self.separate_send_message: + return + StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj) + + @staticmethod + def memory_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj) + + @staticmethod + def system_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj) + + @staticmethod + def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj) + + @staticmethod + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj) + + @staticmethod + def print_messages(message_sequence: List[Message], dump=False): + StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump) + + @staticmethod + def print_messages_simple(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence) + + @staticmethod + def print_messages_raw(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence) + + @staticmethod + def step_yield(): + pass diff --git a/poetry.lock b/poetry.lock index dad0f731..2edac8eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1524,6 +1524,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "httpx-sse" +version = "0.4.0" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "huggingface-hub" version = "0.22.2" @@ -6094,4 +6105,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "5c36931d717323eab3eea32bf383b27578ea8f3467fd230ce543af364caffa92" +content-hash = "a9635dccf8bd7d826f776e36a9d6fbc845a1b7de0586d06c6a9ce7230a5a14bc" diff --git a/pyproject.toml b/pyproject.toml index 9458624e..f6ee2064 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ llama-index-embeddings-openai = "^0.1.1" llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true} llama-index-embeddings-azure-openai = "^0.1.6" python-multipart = "^0.0.9" +httpx-sse = "^0.4.0" [tool.poetry.extras] local = ["llama-index-embeddings-huggingface"]