From 35ebaef72137e15bf89c4c53d8085b3a6217d865 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Tue, 3 Sep 2024 16:57:57 -0700 Subject: [PATCH] feat: add streaming support (both steps and tokens) to the Python REST client + pytest (#1701) --- memgpt/client/client.py | 41 +++++++++----- memgpt/client/streaming.py | 90 +++++++++++++++++++++++++++++++ memgpt/schemas/memgpt_message.py | 17 +++++- memgpt/schemas/memgpt_response.py | 5 ++ tests/test_client.py | 80 +++++++++++++++++++++++---- 5 files changed, 208 insertions(+), 25 deletions(-) create mode 100644 memgpt/client/streaming.py diff --git a/memgpt/client/client.py b/memgpt/client/client.py index f788374f..e0e55bb7 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -1,5 +1,5 @@ import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Generator, List, Optional, Tuple, Union import requests @@ -23,11 +23,11 @@ from memgpt.schemas.block import ( from memgpt.schemas.embedding_config import EmbeddingConfig # new schemas -from memgpt.schemas.enums import JobStatus +from memgpt.schemas.enums import JobStatus, MessageRole from memgpt.schemas.job import Job from memgpt.schemas.llm_config import LLMConfig from memgpt.schemas.memgpt_request import MemGPTRequest -from memgpt.schemas.memgpt_response import MemGPTResponse +from memgpt.schemas.memgpt_response import MemGPTResponse, MemGPTStreamingResponse from memgpt.schemas.memory import ( ArchivalMemorySummary, ChatMemory, @@ -419,15 +419,29 @@ class RESTClient(AbstractClient): return [Message(**message) for message in response.json()] def send_message( - self, agent_id: str, message: str, role: str, name: Optional[str] = None, stream: Optional[bool] = False - ) -> MemGPTResponse: - messages = [MessageCreate(role=role, text=message, name=name)] + self, + agent_id: str, + message: str, + role: str, + name: Optional[str] = None, + stream_steps: bool = False, + stream_tokens: bool = False, + ) -> Union[MemGPTResponse, Generator[MemGPTStreamingResponse, None, None]]: + messages = [MessageCreate(role=MessageRole(role), text=message, name=name)] # TODO: figure out how to handle stream_steps and stream_tokens - request = MemGPTRequest(messages=messages, stream_steps=stream, return_message_object=True) - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers) - if response.status_code != 200: - raise ValueError(f"Failed to send message: {response.text}") - return MemGPTResponse(**response.json()) + + # When streaming steps is True, stream_tokens must be False + request = MemGPTRequest(messages=messages, stream_steps=stream_steps, stream_tokens=stream_tokens, return_message_object=True) + if stream_tokens or stream_steps: + from memgpt.client.streaming import _sse_post + + request.return_message_object = False + return _sse_post(f"{self.base_url}/api/agents/{agent_id}/messages", request.model_dump(), self.headers) + else: + response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to send message: {response.text}") + return MemGPTResponse(**response.json()) # humans / personas @@ -957,7 +971,8 @@ class LocalClient(AbstractClient): role: str, agent_id: Optional[str] = None, agent_name: Optional[str] = None, - stream: Optional[bool] = False, + stream_steps: bool = False, + stream_tokens: bool = False, ) -> MemGPTResponse: if not agent_id: assert agent_name, f"Either agent_id or agent_name must be provided" @@ -966,7 +981,7 @@ class LocalClient(AbstractClient): # agent_id = agent_state.id agent_state = self.get_agent(agent_id=agent_id) - if stream: + if stream_steps or stream_tokens: # TODO: implement streaming with stream=True/False raise NotImplementedError self.interface.clear() diff --git a/memgpt/client/streaming.py b/memgpt/client/streaming.py new file mode 100644 index 00000000..0b5ee81e --- /dev/null +++ b/memgpt/client/streaming.py @@ -0,0 +1,90 @@ +import json +from typing import Generator + +import httpx +from httpx_sse import SSEError, connect_sse + +from memgpt.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING +from memgpt.errors import LLMError +from memgpt.schemas.enums import MessageStreamStatus +from memgpt.schemas.memgpt_message import ( + FunctionCallMessage, + FunctionReturn, + InternalMonologue, +) +from memgpt.schemas.memgpt_response import MemGPTStreamingResponse + + +def _sse_post(url: str, data: dict, headers: dict) -> Generator[MemGPTStreamingResponse, None, None]: + + with httpx.Client() as client: + with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: + + # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12) + if not event_source.response.is_success: + # handle errors + from memgpt.utils import printd + + printd("Caught error before iterating SSE request:", vars(event_source.response)) + printd(event_source.response.read()) + + try: + response_bytes = event_source.response.read() + response_dict = json.loads(response_bytes.decode("utf-8")) + error_message = response_dict["error"]["message"] + # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions. + if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message: + raise LLMError(error_message) + except LLMError: + raise + except: + print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack") + event_source.response.raise_for_status() + + try: + for sse in event_source.iter_sse(): + # if sse.data == OPENAI_SSE_DONE: + # print("finished") + # break + if sse.data in [status.value for status in MessageStreamStatus]: + # break + # print("sse.data::", sse.data) + yield MessageStreamStatus(sse.data) + else: + chunk_data = json.loads(sse.data) + if "internal_monologue" in chunk_data: + yield InternalMonologue(**chunk_data) + elif "function_call" in chunk_data: + yield FunctionCallMessage(**chunk_data) + elif "function_return" in chunk_data: + yield FunctionReturn(**chunk_data) + else: + raise ValueError(f"Unknown message type in chunk_data: {chunk_data}") + + except SSEError as e: + print("Caught an error while iterating the SSE stream:", str(e)) + if "application/json" in str(e): # Check if the error is because of JSON response + # TODO figure out a better way to catch the error other than re-trying with a POST + response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response + if response.headers["Content-Type"].startswith("application/json"): + error_details = response.json() # Parse the JSON to get the error message + print("Request:", vars(response.request)) + print("POST Error:", error_details) + print("Original SSE Error:", str(e)) + else: + print("Failed to retrieve JSON error message via retry.") + else: + print("SSEError not related to 'application/json' content type.") + + # Optionally re-raise the exception if you need to propagate it + raise e + + except Exception as e: + if event_source.response.request is not None: + print("HTTP Request:", vars(event_source.response.request)) + if event_source.response is not None: + print("HTTP Status:", event_source.response.status_code) + print("HTTP Headers:", event_source.response.headers) + # print("HTTP Body:", event_source.response.text) + print("Exception message:", str(e)) + raise e diff --git a/memgpt/schemas/memgpt_message.py b/memgpt/schemas/memgpt_message.py index 3e7d2fc5..4f9e9b72 100644 --- a/memgpt/schemas/memgpt_message.py +++ b/memgpt/schemas/memgpt_message.py @@ -2,7 +2,7 @@ import json from datetime import datetime, timezone from typing import Literal, Optional, Union -from pydantic import BaseModel, field_serializer +from pydantic import BaseModel, field_serializer, field_validator # MemGPT API style responses (intended to be easier to use vs getting true Message types) @@ -79,6 +79,21 @@ class FunctionCallMessage(BaseMemGPTMessage): FunctionCall: lambda v: v.model_dump(exclude_none=True), } + # NOTE: this is required to cast dicts into FunctionCallMessage objects + # Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None + # (instead of properly casting to FunctionCallDelta instead of FunctionCall) + @field_validator("function_call", mode="before") + @classmethod + def validate_function_call(cls, v): + if isinstance(v, dict): + if "name" in v and "arguments" in v: + return FunctionCall(name=v["name"], arguments=v["arguments"]) + elif "name" in v or "arguments" in v: + return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments")) + else: + raise ValueError("function_call must contain either 'name' or 'arguments'") + return v + class FunctionReturn(BaseMemGPTMessage): """ diff --git a/memgpt/schemas/memgpt_response.py b/memgpt/schemas/memgpt_response.py index 00132e42..114195d5 100644 --- a/memgpt/schemas/memgpt_response.py +++ b/memgpt/schemas/memgpt_response.py @@ -2,6 +2,7 @@ from typing import List, Union from pydantic import BaseModel, Field +from memgpt.schemas.enums import MessageStreamStatus from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage from memgpt.schemas.message import Message from memgpt.schemas.usage import MemGPTUsageStatistics @@ -15,3 +16,7 @@ class MemGPTResponse(BaseModel): ..., description="The messages returned by the agent." ) usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.") + + +# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a MemGPTMessage +MemGPTStreamingResponse = Union[MemGPTMessage, MessageStreamStatus] diff --git a/tests/test_client.py b/tests/test_client.py index c3f5ee7d..b496f7ba 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,13 +2,18 @@ import os import threading import time import uuid +from typing import Union import pytest from dotenv import load_dotenv from memgpt import Admin, create_client +from memgpt.client.client import LocalClient, RESTClient from memgpt.constants import DEFAULT_PRESET -from memgpt.schemas.enums import JobStatus +from memgpt.schemas.agent import AgentState +from memgpt.schemas.enums import JobStatus, MessageStreamStatus +from memgpt.schemas.memgpt_message import FunctionCallMessage, InternalMonologue +from memgpt.schemas.memgpt_response import MemGPTStreamingResponse from memgpt.schemas.message import Message from memgpt.schemas.usage import MemGPTUsageStatistics @@ -77,7 +82,7 @@ def client(request): # Fixture for test agent @pytest.fixture(scope="module") -def agent(client): +def agent(client: Union[LocalClient, RESTClient]): agent_state = client.create_agent(name=test_agent_name) print("AGENT ID", agent_state.id) yield agent_state @@ -86,7 +91,7 @@ def agent(client): client.delete_agent(agent_state.id) -def test_agent(client, agent): +def test_agent(client: Union[LocalClient, RESTClient], agent: AgentState): # test client.rename_agent new_name = "RenamedTestAgent" @@ -101,7 +106,7 @@ def test_agent(client, agent): assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" -def test_memory(client, agent): +def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() memory_response = client.get_in_context_memory(agent_id=agent.id) @@ -117,7 +122,7 @@ def test_memory(client, agent): ), "Memory update failed" -def test_agent_interactions(client, agent): +def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() message = "Hello, agent!" @@ -134,7 +139,7 @@ def test_agent_interactions(client, agent): # TODO: add streaming tests -def test_archival_memory(client, agent): +def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() memory_content = "Archival memory content" @@ -168,7 +173,7 @@ def test_archival_memory(client, agent): client.get_archival_memory(agent.id) -def test_core_memory(client, agent): +def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState): response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user") print("Response", response) @@ -176,7 +181,7 @@ def test_core_memory(client, agent): assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}" -def test_messages(client, agent): +def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") @@ -186,7 +191,60 @@ def test_messages(client, agent): assert len(messages_response) > 0, "Retrieving messages failed" -def test_humans_personas(client, agent): +def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState): + if isinstance(client, LocalClient): + pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") + assert isinstance(client, RESTClient), client + + # First, try streaming just steps + + # Next, try streaming both steps and tokens + response = client.send_message( + agent_id=agent.id, + message="This is a test. Repeat after me: 'banana'", + role="user", + stream_steps=True, + stream_tokens=True, + ) + + # Some manual checks to run + # 1. Check that there were inner thoughts + inner_thoughts_exist = False + # 2. Check that the agent runs `send_message` + send_message_ran = False + # 3. Check that we get all the start/stop/end tokens we want + # This includes all of the MessageStreamStatus enums + done_gen = False + done_step = False + done = False + + # print(response) + assert response, "Sending message failed" + for chunk in response: + assert isinstance(chunk, MemGPTStreamingResponse) + if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "": + inner_thoughts_exist = True + if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message": + send_message_ran = True + if isinstance(chunk, MessageStreamStatus): + if chunk == MessageStreamStatus.done: + assert not done, "Message stream already done" + done = True + elif chunk == MessageStreamStatus.done_step: + assert not done_step, "Message stream already done step" + done_step = True + elif chunk == MessageStreamStatus.done_generation: + assert not done_gen, "Message stream already done generation" + done_gen = True + + assert inner_thoughts_exist, "No inner thoughts found" + assert send_message_ran, "send_message function call not found" + assert done, "Message stream not done" + assert done_step, "Message stream not done step" + assert done_gen, "Message stream not done generation" + + +def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() humans_response = client.list_humans() @@ -221,7 +279,7 @@ def test_humans_personas(client, agent): # assert tool_response, "Creating tool failed" -def test_config(client, agent): +def test_config(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() models_response = client.list_models() @@ -236,7 +294,7 @@ def test_config(client, agent): # print("CONFIG", config_response) -def test_sources(client, agent): +def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): # _reset_config() # clear sources