fix: (1) refactor in Agent.step() to fix out-of-order timestamps, (2) bug fixes with usage of preset/human vs filename values (#1145)

This commit is contained in:
Charles Packer
2024-03-16 20:11:31 -07:00
committed by GitHub
parent 70a9e9b81c
commit 16a9bddb22
16 changed files with 146 additions and 66 deletions

View File

@@ -1,5 +1,6 @@
from memgpt import create_client, Admin
from memgpt.constants import DEFAULT_PRESET, DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.utils import get_human_text, get_persona_text
"""
@@ -23,7 +24,12 @@ def main():
client = create_client(base_url="http://localhost:8283", token=token)
# Create an agent
agent_info = client.create_agent(name="my_agent", preset=DEFAULT_PRESET, persona=DEFAULT_PERSONA, human=DEFAULT_HUMAN)
agent_info = client.create_agent(
name="my_agent",
preset=DEFAULT_PRESET,
persona_name=get_persona_text(DEFAULT_PERSONA),
human_name=get_human_text(DEFAULT_HUMAN),
)
print(f"Created agent: {agent_info.name} with ID {str(agent_info.id)}")
# Send a message to the agent

View File

@@ -199,8 +199,8 @@ class Agent(object):
init_agent_state = AgentState(
name=name if name else create_random_username(),
user_id=created_by,
persona=preset.persona_name,
human=preset.human_name,
persona=preset.persona,
human=preset.human,
llm_config=llm_config,
embedding_config=embedding_config,
preset=preset.name, # TODO link via preset.id instead of name?
@@ -609,47 +609,77 @@ class Agent(object):
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
recreate_message_timestamp: bool = True,  # if True, when input is a Message type, recreates the 'created_at' field
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
"""If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
try:
user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
# Special handling for AutoGen messages with 'name' field
# Treat 'name' as a special field
# If it exists in the input message, elevate it to the 'message' level
name = user_message_json.pop("name", None)
clean_message = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
except Exception as e:
print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
return clean_message, name
def validate_json(user_message_text: str, raise_on_error: bool) -> str:
try:
user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
user_message_json_val = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
return user_message_json_val
except Exception as e:
print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
if raise_on_error:
raise e
try:
# Step 0: add user message
if user_message is not None:
if isinstance(user_message, Message):
user_message_text = user_message.text
# Validate JSON via save/load
user_message_text = validate_json(user_message.text, False)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text)
if name is not None:
# Update Message object
user_message.text = cleaned_user_message_text
user_message.name = name
# Recreate timestamp
if recreate_message_timestamp:
user_message.created_at = datetime.datetime.now()
elif isinstance(user_message, str):
user_message_text = user_message
# Validate JSON via save/load
user_message = validate_json(user_message, False)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)
# If user_message['name'] is not None, it will be handled properly by dict_to_message
# So no need to run strip_name_field_from_user_message
# Create the associated Message object (in the database)
user_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={"role": "user", "content": cleaned_user_message_text, "name": name},
)
else:
raise ValueError(f"Bad type for user_message: {type(user_message)}")
packed_user_message = {"role": "user", "content": user_message_text}
# Special handling for AutoGen messages with 'name' field
try:
user_message_json = json.loads(user_message_text, strict=JSON_LOADS_STRICT)
# Special handling for AutoGen messages with 'name' field
# Treat 'name' as a special field
# If it exists in the input message, elevate it to the 'message' level
if "name" in user_message_json:
packed_user_message["name"] = user_message_json["name"]
user_message_json.pop("name", None)
packed_user_message["content"] = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
except Exception as e:
print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
self.interface.user_message(user_message.text, msg_obj=user_message)
# Create the associated Message object (in the database)
packed_user_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict=packed_user_message,
)
self.interface.user_message(user_message_text, msg_obj=packed_user_message_obj)
input_message_sequence = self.messages + [packed_user_message]
input_message_sequence = self.messages + [user_message.to_openai_dict()]
# Alternatively, the requestor can send an empty user message
else:
input_message_sequence = self.messages
packed_user_message = None
if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
@@ -698,14 +728,7 @@ class Agent(object):
if isinstance(user_message, Message):
all_new_messages = [user_message] + all_response_messages
else:
all_new_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict=packed_user_message,
)
] + all_response_messages
raise ValueError(type(user_message))
else:
all_new_messages = all_response_messages

View File

@@ -7,6 +7,7 @@ from typing import Annotated
from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt.utils import get_human_text, get_persona_text
# from memgpt.agent import Agent
from memgpt.errors import LLMJSONParsingError
@@ -55,7 +56,11 @@ def bench(
bench_id = uuid.uuid4()
for i in range(n_tries):
agent = client.create_agent(name=f"benchmark_{bench_id}_agent_{i}", persona=PERSONA, human=HUMAN)
agent = client.create_agent(
name=f"benchmark_{bench_id}_agent_{i}",
persona=get_persona_text(PERSONA),
human=get_human_text(HUMAN),
)
agent_id = agent.id
result, msg = send_message(

View File

@@ -1,5 +1,6 @@
import datetime
import requests
from requests.exceptions import RequestException
import uuid
from typing import Dict, List, Union, Optional, Tuple

View File

@@ -476,6 +476,8 @@ class AgentState:
self.name = name
self.user_id = user_id
self.preset = preset
# The INITIAL values of the persona and human
# The values inside self.state['persona'], self.state['human'] are the CURRENT values
self.persona = persona
self.human = human

View File

@@ -324,8 +324,10 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
id=agent_id,
name=agent_config["name"],
user_id=user.id,
persona=agent_config["persona"], # eg 'sam_pov'
human=agent_config["human"], # eg 'basic'
# persona_name=agent_config["persona"], # eg 'sam_pov'
# human_name=agent_config["human"], # eg 'basic'
persona=state_dict["memory"]["persona"], # NOTE: hacky (not init, but latest)
human=state_dict["memory"]["human"], # NOTE: hacky (not init, but latest)
preset=agent_config["preset"], # eg 'memgpt_chat'
state=dict(
human=state_dict["memory"]["human"],

View File

@@ -54,7 +54,7 @@ class ToolModel(BaseModel):
class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
description: str = Field(None, description="The description of the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
# timestamps

View File

@@ -9,13 +9,15 @@ from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.data_types import AgentState
from memgpt.models.pydantic_models import LLMConfigModel, EmbeddingConfigModel, AgentStateModel
from memgpt.models.pydantic_models import LLMConfigModel, EmbeddingConfigModel, AgentStateModel, PresetModel
router = APIRouter()
class ListAgentsResponse(BaseModel):
num_agents: int = Field(..., description="The number of agents available to the user.")
# TODO make return type List[AgentStateModel]
# also return - presets: List[PresetModel]
agents: List[dict] = Field(..., description="List of agent configurations.")
@@ -25,6 +27,7 @@ class CreateAgentRequest(BaseModel):
class CreateAgentResponse(BaseModel):
agent_state: AgentStateModel = Field(..., description="The state of the newly created agent.")
preset: PresetModel = Field(..., description="The preset that the agent was created from.")
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
@@ -62,8 +65,8 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
# TODO turn into a pydantic model
name=request.config["name"],
preset=request.config["preset"] if "preset" in request.config else None,
# persona_name=request.config["persona_name"] if "persona_text" in request.config else None,
# human_name=request.config["human_name"] if "human_text" in request.config else None,
# persona_name=request.config["persona_name"] if "persona_name" in request.config else None,
# human_name=request.config["human_name"] if "human_name" in request.config else None,
persona=request.config["persona"] if "persona" in request.config else None,
human=request.config["human"] if "human" in request.config else None,
# llm_config=LLMConfigModel(
@@ -76,6 +79,10 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
raise
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
# TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line
preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id)
return CreateAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
@@ -89,11 +96,19 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, just look up the preset instead
)
),
preset=PresetModel(
name=preset.name,
id=preset.id,
user_id=preset.user_id,
description=preset.description,
created_at=preset.created_at,
system=preset.system,
persona=preset.persona,
human=preset.human,
functions_schema=preset.functions_schema,
),
)
# return CreateAgentResponse(
# agent_state=AgentStateModel(
# )
except Exception as e:
print(str(e))
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -16,6 +16,7 @@ import memgpt.server.utils as server_utils
import memgpt.system as system
from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.utils import get_human_text, get_persona_text
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options
@@ -572,8 +573,8 @@ class SyncServer(LockingServer):
user_id: uuid.UUID,
name: Optional[str] = None,
preset: Optional[str] = None,
persona: Optional[str] = None,
human: Optional[str] = None,
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
interface: Union[AgentInterface, None] = None,
@@ -600,14 +601,34 @@ class SyncServer(LockingServer):
# TODO: fix this db dependency and remove
# self.ms.create_agent(agent_state)
# TODO modify to do creation via preset
try:
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
# Overwrite fields in the preset if they were specified
preset_obj.human = human if human else self.config.human
preset_obj.persona = persona if persona else self.config.persona
if human is not None:
preset_obj.human = human
# This is a check for a common bug where users were providing filenames instead of values
try:
get_human_text(human)
raise ValueError(human)
raise UserWarning(
f"It looks like there is a human file named {human} - did you mean to pass the file contents to the `human` arg?"
)
except:
pass
if persona is not None:
preset_obj.persona = persona
try:
get_persona_text(persona)
raise ValueError(persona)
raise UserWarning(
f"It looks like there is a persona file named {persona} - did you mean to pass the file contents to the `persona` arg?"
)
except:
pass
llm_config = llm_config if llm_config else self.server_llm_config
embedding_config = embedding_config if embedding_config else self.server_embedding_config
@@ -694,6 +715,7 @@ class SyncServer(LockingServer):
}
return agent_config
# TODO make return type pydantic
def list_agents(self, user_id: uuid.UUID) -> dict:
"""List all available agents to a user"""
if self.ms.get_user(user_id=user_id) is None:
@@ -711,6 +733,12 @@ class SyncServer(LockingServer):
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
# TODO remove this eventually when the return type gets pydantic-ified
# this is to add persona_name and human_name so that the columns in UI can populate
preset = self.ms.get_preset(name=agent_state.preset, user_id=user_id)
return_dict["persona_name"] = preset.persona_name
return_dict["human_name"] = preset.human_name
# Add information about tools
# TODO memgpt_agent should really have a field of List[ToolModel]
# then we could just pull that field and return it here

View File

@@ -43,8 +43,7 @@ def agent():
client.server.create_user({"id": user_id})
agent_state = client.create_agent(
persona=constants.DEFAULT_PERSONA,
human=constants.DEFAULT_HUMAN,
preset=constants.DEFAULT_PRESET,
)
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)

View File

@@ -25,8 +25,7 @@ def create_test_agent():
client = create_client()
agent_state = client.create_agent(
persona=constants.DEFAULT_PERSONA,
human=constants.DEFAULT_HUMAN,
preset=constants.DEFAULT_PRESET,
)
global agent_obj

View File

@@ -12,6 +12,7 @@ from memgpt.cli.cli_load import load_directory
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
from memgpt.data_types import User, AgentState, EmbeddingConfig
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config
@@ -113,8 +114,8 @@ def test_load_directory(
user_id=user.id,
name="test_agent",
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
human=get_human_text(TEST_MEMGPT_CONFIG.human),
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)

View File

@@ -49,8 +49,8 @@ def test_storage(storage_connector):
user_id=user_1.id,
name="agent_1",
preset=DEFAULT_PRESET,
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
persona=get_persona_text(DEFAULT_PERSONA),
human=get_human_text(DEFAULT_HUMAN),
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)

View File

@@ -99,8 +99,6 @@ def agent_id(server, user_id):
user_id=user_id,
name="test_agent",
preset="memgpt_chat",
human="cs_phd",
persona="sam_pov",
)
print(f"Created agent\n{agent_state}")
yield agent_state.id

View File

@@ -11,6 +11,7 @@ from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore
from memgpt.data_types import User
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.utils import get_human_text, get_persona_text
from datetime import datetime, timedelta
@@ -175,8 +176,10 @@ def test_storage(
name="agent_1",
id=agent_1_id,
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
# persona_name=TEST_MEMGPT_CONFIG.persona,
# human_name=TEST_MEMGPT_CONFIG.human,
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
human=get_human_text(TEST_MEMGPT_CONFIG.human),
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)

View File

@@ -24,8 +24,6 @@ def create_test_agent():
client = create_client()
agent_state = client.create_agent(
name=test_agent_name,
persona=constants.DEFAULT_PERSONA,
human=constants.DEFAULT_HUMAN,
)
global agent_obj