feat: return MemGPTMessage for Python Client (#1738)

This commit is contained in:
Sarah Wooders
2024-09-10 21:28:26 -07:00
committed by GitHub
parent 9fadd22fb8
commit c8fcbff39e
6 changed files with 49 additions and 19 deletions

View File

@@ -130,10 +130,11 @@ class AbstractClient(object):
agent_id: Optional[str] = None,
name: Optional[str] = None,
stream: Optional[bool] = False,
include_full_message: Optional[bool] = False,
) -> MemGPTResponse:
raise NotImplementedError
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> MemGPTResponse:
raise NotImplementedError
def create_human(self, name: str, text: str) -> Human:
@@ -586,7 +587,7 @@ class RESTClient(AbstractClient):
# agent interactions
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> MemGPTResponse:
"""
Send a message to an agent as a user
@@ -597,7 +598,7 @@ class RESTClient(AbstractClient):
Returns:
response (MemGPTResponse): Response from the agent
"""
return self.send_message(agent_id, message, role="user")
return self.send_message(agent_id, message, role="user", include_full_message=include_full_message)
def save(self):
raise NotImplementedError
@@ -690,6 +691,7 @@ class RESTClient(AbstractClient):
name: Optional[str] = None,
stream_steps: bool = False,
stream_tokens: bool = False,
include_full_message: Optional[bool] = False,
) -> Union[MemGPTResponse, Generator[MemGPTStreamingResponse, None, None]]:
"""
Send a message to an agent
@@ -705,6 +707,7 @@ class RESTClient(AbstractClient):
Returns:
response (MemGPTResponse): Response from the agent
"""
# TODO: implement include_full_message
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
# TODO: figure out how to handle stream_steps and stream_tokens
@@ -721,7 +724,16 @@ class RESTClient(AbstractClient):
)
if response.status_code != 200:
raise ValueError(f"Failed to send message: {response.text}")
return MemGPTResponse(**response.json())
response = MemGPTResponse(**response.json())
# simplify messages
if not include_full_message:
messages = []
for message in response.messages:
messages += message.to_memgpt_message()
response.messages = messages
return response
# humans / personas
@@ -1356,6 +1368,10 @@ class LocalClient(AbstractClient):
self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface_factory=lambda: self.interface)
# set logging levels
memgpt.utils.DEBUG = debug
logging.getLogger().setLevel(logging.CRITICAL)
# create user if does not exist
existing_user = self.server.get_user(self.user_id)
if not existing_user:
@@ -1662,6 +1678,7 @@ class LocalClient(AbstractClient):
agent_name: Optional[str] = None,
stream_steps: bool = False,
stream_tokens: bool = False,
include_full_message: Optional[bool] = False,
) -> MemGPTResponse:
"""
Send a message to an agent
@@ -1704,9 +1721,18 @@ class LocalClient(AbstractClient):
messages = self.interface.to_list()
for m in messages:
assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
return MemGPTResponse(messages=messages, usage=usage)
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
# format messages
if include_full_message:
memgpt_messages = messages
else:
memgpt_messages = []
for m in messages:
memgpt_messages += m.to_memgpt_message()
return MemGPTResponse(messages=memgpt_messages, usage=usage)
def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> MemGPTResponse:
"""
Send a message to an agent as a user
@@ -1718,7 +1744,7 @@ class LocalClient(AbstractClient):
response (MemGPTResponse): Response from the agent
"""
self.interface.clear()
return self.send_message(role="user", agent_id=agent_id, message=message)
return self.send_message(role="user", agent_id=agent_id, message=message, include_full_message=include_full_message)
def run_command(self, agent_id: str, command: str) -> MemGPTResponse:
"""
@@ -1957,6 +1983,8 @@ class LocalClient(AbstractClient):
# parse source code/schema
source_code = parse_source_code(func)
source_type = "python"
if not tags:
tags = []
# call server function
return self.server.create_tool(
@@ -2263,7 +2291,7 @@ class LocalClient(AbstractClient):
Returns:
models (List[EmbeddingConfig]): List of embedding models
"""
return self.server.list_embedding_models()
return [self.server.server_embedding_config]
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
"""

View File

@@ -155,11 +155,11 @@ class MemGPTConfig:
llm_config_dict = {k: v for k, v in llm_config_dict.items() if v is not None}
embedding_config_dict = {k: v for k, v in embedding_config_dict.items() if v is not None}
# Correct the types that aren't strings
if llm_config_dict["context_window"] is not None:
if "context_window" in llm_config_dict and llm_config_dict["context_window"] is not None:
llm_config_dict["context_window"] = int(llm_config_dict["context_window"])
if embedding_config_dict["embedding_dim"] is not None:
if "embedding_dim" in embedding_config_dict and embedding_config_dict["embedding_dim"] is not None:
embedding_config_dict["embedding_dim"] = int(embedding_config_dict["embedding_dim"])
if embedding_config_dict["embedding_chunk_size"] is not None:
if "embedding_chunk_size" in embedding_config_dict and embedding_config_dict["embedding_chunk_size"] is not None:
embedding_config_dict["embedding_chunk_size"] = int(embedding_config_dict["embedding_chunk_size"])
# Construct the inner properties
llm_config = LLMConfig(**llm_config_dict)

View File

@@ -42,6 +42,7 @@ from memgpt.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
from memgpt.utils import json_dumps
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"]

View File

@@ -889,7 +889,7 @@ class SyncServer(Server):
self,
user_id: Optional[str] = None,
label: Optional[str] = None,
template: bool = True,
template: Optional[bool] = None,
name: Optional[str] = None,
id: Optional[str] = None,
) -> Optional[List[Block]]:

7
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -7130,6 +7130,11 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
]
[package.dependencies]

View File

@@ -127,7 +127,7 @@ def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: Agent
message = "Hello, agent!"
print("Sending message", message)
response = client.user_message(agent_id=agent.id, message=message)
response = client.user_message(agent_id=agent.id, message=message, include_full_message=True)
print("Response", response)
assert isinstance(response.usage, MemGPTUsageStatistics)
assert response.usage.step_count == 1
@@ -401,11 +401,7 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
"""Test that we can update the details of a message"""
# create a message
message_response = client.send_message(
agent_id=agent.id,
message="Test message",
role="user",
)
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user", include_full_message=True)
print("Messages=", message_response)
assert isinstance(message_response, MemGPTResponse)
assert isinstance(message_response.messages[-1], Message)