feat: return MemGPTMessage for Python Client (#1738)
This commit is contained in:
@@ -130,10 +130,11 @@ class AbstractClient(object):
|
||||
agent_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
include_full_message: Optional[bool] = False,
|
||||
) -> MemGPTResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
|
||||
def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> MemGPTResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_human(self, name: str, text: str) -> Human:
|
||||
@@ -586,7 +587,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
# agent interactions
|
||||
|
||||
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
|
||||
def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> MemGPTResponse:
|
||||
"""
|
||||
Send a message to an agent as a user
|
||||
|
||||
@@ -597,7 +598,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
response (MemGPTResponse): Response from the agent
|
||||
"""
|
||||
return self.send_message(agent_id, message, role="user")
|
||||
return self.send_message(agent_id, message, role="user", include_full_message=include_full_message)
|
||||
|
||||
def save(self):
|
||||
raise NotImplementedError
|
||||
@@ -690,6 +691,7 @@ class RESTClient(AbstractClient):
|
||||
name: Optional[str] = None,
|
||||
stream_steps: bool = False,
|
||||
stream_tokens: bool = False,
|
||||
include_full_message: Optional[bool] = False,
|
||||
) -> Union[MemGPTResponse, Generator[MemGPTStreamingResponse, None, None]]:
|
||||
"""
|
||||
Send a message to an agent
|
||||
@@ -705,6 +707,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
response (MemGPTResponse): Response from the agent
|
||||
"""
|
||||
# TODO: implement include_full_message
|
||||
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
|
||||
# TODO: figure out how to handle stream_steps and stream_tokens
|
||||
|
||||
@@ -721,7 +724,16 @@ class RESTClient(AbstractClient):
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to send message: {response.text}")
|
||||
return MemGPTResponse(**response.json())
|
||||
response = MemGPTResponse(**response.json())
|
||||
|
||||
# simplify messages
|
||||
if not include_full_message:
|
||||
messages = []
|
||||
for message in response.messages:
|
||||
messages += message.to_memgpt_message()
|
||||
response.messages = messages
|
||||
|
||||
return response
|
||||
|
||||
# humans / personas
|
||||
|
||||
@@ -1356,6 +1368,10 @@ class LocalClient(AbstractClient):
|
||||
self.interface = QueuingInterface(debug=debug)
|
||||
self.server = SyncServer(default_interface_factory=lambda: self.interface)
|
||||
|
||||
# set logging levels
|
||||
memgpt.utils.DEBUG = debug
|
||||
logging.getLogger().setLevel(logging.CRITICAL)
|
||||
|
||||
# create user if does not exist
|
||||
existing_user = self.server.get_user(self.user_id)
|
||||
if not existing_user:
|
||||
@@ -1662,6 +1678,7 @@ class LocalClient(AbstractClient):
|
||||
agent_name: Optional[str] = None,
|
||||
stream_steps: bool = False,
|
||||
stream_tokens: bool = False,
|
||||
include_full_message: Optional[bool] = False,
|
||||
) -> MemGPTResponse:
|
||||
"""
|
||||
Send a message to an agent
|
||||
@@ -1704,9 +1721,18 @@ class LocalClient(AbstractClient):
|
||||
messages = self.interface.to_list()
|
||||
for m in messages:
|
||||
assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
|
||||
return MemGPTResponse(messages=messages, usage=usage)
|
||||
|
||||
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
|
||||
# format messages
|
||||
if include_full_message:
|
||||
memgpt_messages = messages
|
||||
else:
|
||||
memgpt_messages = []
|
||||
for m in messages:
|
||||
memgpt_messages += m.to_memgpt_message()
|
||||
|
||||
return MemGPTResponse(messages=memgpt_messages, usage=usage)
|
||||
|
||||
def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> MemGPTResponse:
|
||||
"""
|
||||
Send a message to an agent as a user
|
||||
|
||||
@@ -1718,7 +1744,7 @@ class LocalClient(AbstractClient):
|
||||
response (MemGPTResponse): Response from the agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
return self.send_message(role="user", agent_id=agent_id, message=message)
|
||||
return self.send_message(role="user", agent_id=agent_id, message=message, include_full_message=include_full_message)
|
||||
|
||||
def run_command(self, agent_id: str, command: str) -> MemGPTResponse:
|
||||
"""
|
||||
@@ -1957,6 +1983,8 @@ class LocalClient(AbstractClient):
|
||||
# parse source code/schema
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
if not tags:
|
||||
tags = []
|
||||
|
||||
# call server function
|
||||
return self.server.create_tool(
|
||||
@@ -2263,7 +2291,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
models (List[EmbeddingConfig]): List of embedding models
|
||||
"""
|
||||
return self.server.list_embedding_models()
|
||||
return [self.server.server_embedding_config]
|
||||
|
||||
def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
|
||||
"""
|
||||
|
||||
@@ -155,11 +155,11 @@ class MemGPTConfig:
|
||||
llm_config_dict = {k: v for k, v in llm_config_dict.items() if v is not None}
|
||||
embedding_config_dict = {k: v for k, v in embedding_config_dict.items() if v is not None}
|
||||
# Correct the types that aren't strings
|
||||
if llm_config_dict["context_window"] is not None:
|
||||
if "context_window" in llm_config_dict and llm_config_dict["context_window"] is not None:
|
||||
llm_config_dict["context_window"] = int(llm_config_dict["context_window"])
|
||||
if embedding_config_dict["embedding_dim"] is not None:
|
||||
if "embedding_dim" in embedding_config_dict and embedding_config_dict["embedding_dim"] is not None:
|
||||
embedding_config_dict["embedding_dim"] = int(embedding_config_dict["embedding_dim"])
|
||||
if embedding_config_dict["embedding_chunk_size"] is not None:
|
||||
if "embedding_chunk_size" in embedding_config_dict and embedding_config_dict["embedding_chunk_size"] is not None:
|
||||
embedding_config_dict["embedding_chunk_size"] = int(embedding_config_dict["embedding_chunk_size"])
|
||||
# Construct the inner properties
|
||||
llm_config = LLMConfig(**llm_config_dict)
|
||||
|
||||
@@ -42,6 +42,7 @@ from memgpt.streaming_interface import (
|
||||
AgentChunkStreamingInterface,
|
||||
AgentRefreshStreamingInterface,
|
||||
)
|
||||
from memgpt.utils import json_dumps
|
||||
|
||||
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"]
|
||||
|
||||
|
||||
@@ -889,7 +889,7 @@ class SyncServer(Server):
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
template: bool = True,
|
||||
template: Optional[bool] = None,
|
||||
name: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
) -> Optional[List[Block]]:
|
||||
|
||||
7
poetry.lock
generated
7
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -7130,6 +7130,11 @@ files = [
|
||||
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
|
||||
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
|
||||
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
|
||||
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
|
||||
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
|
||||
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
|
||||
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
|
||||
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -127,7 +127,7 @@ def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: Agent
|
||||
|
||||
message = "Hello, agent!"
|
||||
print("Sending message", message)
|
||||
response = client.user_message(agent_id=agent.id, message=message)
|
||||
response = client.user_message(agent_id=agent.id, message=message, include_full_message=True)
|
||||
print("Response", response)
|
||||
assert isinstance(response.usage, MemGPTUsageStatistics)
|
||||
assert response.usage.step_count == 1
|
||||
@@ -401,11 +401,7 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
|
||||
"""Test that we can update the details of a message"""
|
||||
|
||||
# create a message
|
||||
message_response = client.send_message(
|
||||
agent_id=agent.id,
|
||||
message="Test message",
|
||||
role="user",
|
||||
)
|
||||
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user", include_full_message=True)
|
||||
print("Messages=", message_response)
|
||||
assert isinstance(message_response, MemGPTResponse)
|
||||
assert isinstance(message_response.messages[-1], Message)
|
||||
|
||||
Reference in New Issue
Block a user