feat: add streaming support (both steps and tokens) to the Python REST client + pytest (#1701)

This commit is contained in:
Charles Packer
2024-09-03 16:57:57 -07:00
committed by GitHub
parent 282e7b5289
commit 35ebaef721
5 changed files with 208 additions and 25 deletions

View File

@@ -1,5 +1,5 @@
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Generator, List, Optional, Tuple, Union
import requests
@@ -23,11 +23,11 @@ from memgpt.schemas.block import (
from memgpt.schemas.embedding_config import EmbeddingConfig
# new schemas
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.enums import JobStatus, MessageRole
from memgpt.schemas.job import Job
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.memgpt_response import MemGPTResponse, MemGPTStreamingResponse
from memgpt.schemas.memory import (
ArchivalMemorySummary,
ChatMemory,
@@ -419,15 +419,29 @@ class RESTClient(AbstractClient):
return [Message(**message) for message in response.json()]
# NOTE(review): this region is a rendered diff with the +/- markers stripped; per the
# hunk header (@@ -419,15 +419,29 @@) BOTH the removed and the added versions of
# RESTClient.send_message are interleaved below. Annotations mark which is which.
def send_message(
# --- pre-commit version (removed): single `stream` flag, non-streaming POST only ---
self, agent_id: str, message: str, role: str, name: Optional[str] = None, stream: Optional[bool] = False
) -> MemGPTResponse:
messages = [MessageCreate(role=role, text=message, name=name)]
# --- post-commit version (added): separate stream_steps / stream_tokens flags; may
# --- return either a plain MemGPTResponse or a generator of streaming chunks ---
self,
agent_id: str,
message: str,
role: str,
name: Optional[str] = None,
stream_steps: bool = False,
stream_tokens: bool = False,
) -> Union[MemGPTResponse, Generator[MemGPTStreamingResponse, None, None]]:
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
# --- pre-commit body (removed) ---
# TODO: figure out how to handle stream_steps and stream_tokens
request = MemGPTRequest(messages=messages, stream_steps=stream, return_message_object=True)
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to send message: {response.text}")
return MemGPTResponse(**response.json())
# --- post-commit body (added) ---
# When streaming steps is True, stream_tokens must be False
request = MemGPTRequest(messages=messages, stream_steps=stream_steps, stream_tokens=stream_tokens, return_message_object=True)
if stream_tokens or stream_steps:
from memgpt.client.streaming import _sse_post
request.return_message_object = False
return _sse_post(f"{self.base_url}/api/agents/{agent_id}/messages", request.model_dump(), self.headers)
else:
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to send message: {response.text}")
return MemGPTResponse(**response.json())
# humans / personas
@@ -957,7 +971,8 @@ class LocalClient(AbstractClient):
role: str,
agent_id: Optional[str] = None,
agent_name: Optional[str] = None,
stream: Optional[bool] = False,
stream_steps: bool = False,
stream_tokens: bool = False,
) -> MemGPTResponse:
if not agent_id:
assert agent_name, f"Either agent_id or agent_name must be provided"
@@ -966,7 +981,7 @@ class LocalClient(AbstractClient):
# agent_id = agent_state.id
agent_state = self.get_agent(agent_id=agent_id)
if stream:
if stream_steps or stream_tokens:
# TODO: implement streaming with stream=True/False
raise NotImplementedError
self.interface.clear()

View File

@@ -0,0 +1,90 @@
import json
from typing import Generator
import httpx
from httpx_sse import SSEError, connect_sse
from memgpt.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from memgpt.errors import LLMError
from memgpt.schemas.enums import MessageStreamStatus
from memgpt.schemas.memgpt_message import (
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
)
from memgpt.schemas.memgpt_response import MemGPTStreamingResponse
def _sse_post(url: str, data: dict, headers: dict) -> Generator[MemGPTStreamingResponse, None, None]:
    """POST `data` to `url` and yield the server-sent-event stream as parsed objects.

    Yields:
        InternalMonologue / FunctionCallMessage / FunctionReturn chunks, plus
        MessageStreamStatus sentinel values emitted by the server.

    Raises:
        LLMError: when the server reports an OpenAI context-window overflow.
        httpx.HTTPStatusError: for any other non-2xx response.
        ValueError: when a JSON chunk has an unrecognized message type.
    """
    with httpx.Client() as client:
        with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:

            # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
            if not event_source.response.is_success:
                from memgpt.utils import printd

                # Read once — httpx caches the body, so no need to call .read() twice
                response_bytes = event_source.response.read()
                printd("Caught error before iterating SSE request:", vars(event_source.response))
                printd(response_bytes)

                try:
                    response_dict = json.loads(response_bytes.decode("utf-8"))
                    error_message = response_dict["error"]["message"]
                    # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
                    if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
                        raise LLMError(error_message)
                except LLMError:
                    raise
                except Exception:  # narrowed from bare `except:` so KeyboardInterrupt/SystemExit propagate
                    print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
                # Any non-context-window error surfaces as an HTTP status error
                event_source.response.raise_for_status()

            try:
                for sse in event_source.iter_sse():
                    # Stream-status sentinels are plain strings (not JSON payloads)
                    if sse.data in [status.value for status in MessageStreamStatus]:
                        yield MessageStreamStatus(sse.data)
                    else:
                        chunk_data = json.loads(sse.data)
                        if "internal_monologue" in chunk_data:
                            yield InternalMonologue(**chunk_data)
                        elif "function_call" in chunk_data:
                            yield FunctionCallMessage(**chunk_data)
                        elif "function_return" in chunk_data:
                            yield FunctionReturn(**chunk_data)
                        else:
                            raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

            except SSEError as e:
                print("Caught an error while iterating the SSE stream:", str(e))
                if "application/json" in str(e):  # Check if the error is because of JSON response
                    # TODO figure out a better way to catch the error other than re-trying with a POST
                    response = client.post(url=url, json=data, headers=headers)  # Make the request again to get the JSON response
                    if response.headers["Content-Type"].startswith("application/json"):
                        error_details = response.json()  # Parse the JSON to get the error message
                        print("Request:", vars(response.request))
                        print("POST Error:", error_details)
                        print("Original SSE Error:", str(e))
                    else:
                        print("Failed to retrieve JSON error message via retry.")
                else:
                    print("SSEError not related to 'application/json' content type.")
                # bare `raise` preserves the original traceback (vs `raise e`)
                raise

            except Exception as e:
                if event_source.response.request is not None:
                    print("HTTP Request:", vars(event_source.response.request))
                if event_source.response is not None:
                    print("HTTP Status:", event_source.response.status_code)
                    print("HTTP Headers:", event_source.response.headers)
                print("Exception message:", str(e))
                raise

View File

@@ -2,7 +2,7 @@ import json
from datetime import datetime, timezone
from typing import Literal, Optional, Union
from pydantic import BaseModel, field_serializer
from pydantic import BaseModel, field_serializer, field_validator
# MemGPT API style responses (intended to be easier to use vs getting true Message types)
@@ -79,6 +79,21 @@ class FunctionCallMessage(BaseMemGPTMessage):
FunctionCall: lambda v: v.model_dump(exclude_none=True),
}
# NOTE: this is required to cast dicts into FunctionCallMessage objects
# Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
# (instead of properly casting to FunctionCallDelta instead of FunctionCall)
@field_validator("function_call", mode="before")
@classmethod
def validate_function_call(cls, v):
if isinstance(v, dict):
if "name" in v and "arguments" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"])
elif "name" in v or "arguments" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"))
else:
raise ValueError("function_call must contain either 'name' or 'arguments'")
return v
class FunctionReturn(BaseMemGPTMessage):
"""

View File

@@ -2,6 +2,7 @@ from typing import List, Union
from pydantic import BaseModel, Field
from memgpt.schemas.enums import MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics
@@ -15,3 +16,7 @@ class MemGPTResponse(BaseModel):
..., description="The messages returned by the agent."
)
usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.")
# The streaming response is either [DONE_GEN], [DONE_STEP], [DONE], an error, or a MemGPTMessage
MemGPTStreamingResponse = Union[MemGPTMessage, MessageStreamStatus]

View File

@@ -2,13 +2,18 @@ import os
import threading
import time
import uuid
from typing import Union
import pytest
from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.client.client import LocalClient, RESTClient
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.agent import AgentState
from memgpt.schemas.enums import JobStatus, MessageStreamStatus
from memgpt.schemas.memgpt_message import FunctionCallMessage, InternalMonologue
from memgpt.schemas.memgpt_response import MemGPTStreamingResponse
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics
@@ -77,7 +82,7 @@ def client(request):
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client):
def agent(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name=test_agent_name)
print("AGENT ID", agent_state.id)
yield agent_state
@@ -86,7 +91,7 @@ def agent(client):
client.delete_agent(agent_state.id)
def test_agent(client, agent):
def test_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
# test client.rename_agent
new_name = "RenamedTestAgent"
@@ -101,7 +106,7 @@ def test_agent(client, agent):
assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed"
def test_memory(client, agent):
def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
memory_response = client.get_in_context_memory(agent_id=agent.id)
@@ -117,7 +122,7 @@ def test_memory(client, agent):
), "Memory update failed"
def test_agent_interactions(client, agent):
def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
message = "Hello, agent!"
@@ -134,7 +139,7 @@ def test_agent_interactions(client, agent):
# TODO: add streaming tests
def test_archival_memory(client, agent):
def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
memory_content = "Archival memory content"
@@ -168,7 +173,7 @@ def test_archival_memory(client, agent):
client.get_archival_memory(agent.id)
def test_core_memory(client, agent):
def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user")
print("Response", response)
@@ -176,7 +181,7 @@ def test_core_memory(client, agent):
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
def test_messages(client, agent):
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
@@ -186,7 +191,60 @@ def test_messages(client, agent):
assert len(messages_response) > 0, "Retrieving messages failed"
def test_humans_personas(client, agent):
def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Stream steps + tokens from a REST agent and verify the chunk protocol end-to-end."""
    from typing import get_args  # local import: needed for the runtime Union check below

    if isinstance(client, LocalClient):
        pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")
    assert isinstance(client, RESTClient), client

    # First, try streaming just steps
    # Next, try streaming both steps and tokens
    response = client.send_message(
        agent_id=agent.id,
        message="This is a test. Repeat after me: 'banana'",
        role="user",
        stream_steps=True,
        stream_tokens=True,
    )

    # Some manual checks to run
    # 1. Check that there were inner thoughts
    inner_thoughts_exist = False
    # 2. Check that the agent runs `send_message`
    send_message_ran = False
    # 3. Check that we get all the start/stop/end tokens we want
    #    This includes all of the MessageStreamStatus enums
    done_gen = False
    done_step = False
    done = False

    assert response, "Sending message failed"
    for chunk in response:
        # FIX: MemGPTStreamingResponse is a subscripted typing.Union alias, and
        # isinstance() raises TypeError on subscripted generics. get_args() flattens
        # the Union into its concrete member classes, which isinstance accepts.
        assert isinstance(chunk, get_args(MemGPTStreamingResponse))
        if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
            inner_thoughts_exist = True
        if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
            send_message_ran = True
        if isinstance(chunk, MessageStreamStatus):
            if chunk == MessageStreamStatus.done:
                assert not done, "Message stream already done"
                done = True
            elif chunk == MessageStreamStatus.done_step:
                assert not done_step, "Message stream already done step"
                done_step = True
            elif chunk == MessageStreamStatus.done_generation:
                assert not done_gen, "Message stream already done generation"
                done_gen = True

    assert inner_thoughts_exist, "No inner thoughts found"
    assert send_message_ran, "send_message function call not found"
    assert done, "Message stream not done"
    assert done_step, "Message stream not done step"
    assert done_gen, "Message stream not done generation"
def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
humans_response = client.list_humans()
@@ -221,7 +279,7 @@ def test_humans_personas(client, agent):
# assert tool_response, "Creating tool failed"
def test_config(client, agent):
def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
models_response = client.list_models()
@@ -236,7 +294,7 @@ def test_config(client, agent):
# print("CONFIG", config_response)
def test_sources(client, agent):
def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
# clear sources