fix: (1) refactor in Agent.step() to fix out-of-order timestamps, (2) bug fixes with usage of preset/human vs filename values (#1145)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from memgpt import create_client, Admin
|
||||
from memgpt.constants import DEFAULT_PRESET, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
|
||||
|
||||
"""
|
||||
@@ -23,7 +24,12 @@ def main():
|
||||
client = create_client(base_url="http://localhost:8283", token=token)
|
||||
|
||||
# Create an agent
|
||||
agent_info = client.create_agent(name="my_agent", preset=DEFAULT_PRESET, persona=DEFAULT_PERSONA, human=DEFAULT_HUMAN)
|
||||
agent_info = client.create_agent(
|
||||
name="my_agent",
|
||||
preset=DEFAULT_PRESET,
|
||||
persona_name=get_persona_text(DEFAULT_PERSONA),
|
||||
human_name=get_human_text(DEFAULT_HUMAN),
|
||||
)
|
||||
print(f"Created agent: {agent_info.name} with ID {str(agent_info.id)}")
|
||||
|
||||
# Send a message to the agent
|
||||
|
||||
@@ -199,8 +199,8 @@ class Agent(object):
|
||||
init_agent_state = AgentState(
|
||||
name=name if name else create_random_username(),
|
||||
user_id=created_by,
|
||||
persona=preset.persona_name,
|
||||
human=preset.human_name,
|
||||
persona=preset.persona,
|
||||
human=preset.human,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
preset=preset.name, # TODO link via preset.id instead of name?
|
||||
@@ -609,47 +609,77 @@ class Agent(object):
|
||||
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
|
||||
skip_verify: bool = False,
|
||||
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
|
||||
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
|
||||
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
|
||||
"""Top-level event message handler for the MemGPT agent"""
|
||||
|
||||
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
    """If 'name' exists in the JSON string, remove it and return the cleaned text + name value.

    Args:
        user_message_text: The raw user message, expected to be a JSON-encoded object.

    Returns:
        A ``(clean_message, name)`` tuple. ``clean_message`` is the message re-serialized
        without its top-level ``name`` key and ``name`` is the popped value (``None`` if
        the key was absent). If the text cannot be processed as a JSON object, a warning
        is printed and ``(user_message_text, None)`` is returned unchanged.
    """
    try:
        user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
        # Special handling for AutoGen messages with 'name' field
        # Treat 'name' as a special field
        # If it exists in the input message, elevate it to the 'message' level
        name = user_message_json.pop("name", None)
        clean_message = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
    except Exception as e:
        print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
        # Best-effort fallback: without this early return the function would fall through
        # to a return statement that reads locals which were never assigned, turning a
        # handled parse failure into an UnboundLocalError.
        return user_message_text, None

    return clean_message, name
|
||||
|
||||
def validate_json(user_message_text: str, raise_on_error: bool) -> str:
    """Validate a user message by round-tripping it through JSON load/dump.

    Args:
        user_message_text: The raw user message, expected to be a JSON-encoded object.
        raise_on_error: If True, re-raise the parsing error after printing the warning.

    Returns:
        The re-serialized JSON string. When parsing fails and ``raise_on_error`` is
        False, the original text is returned unchanged so callers always receive a
        string, as the return annotation promises.

    Raises:
        Exception: Whatever the JSON round-trip raised, only when ``raise_on_error`` is True.
    """
    try:
        user_message_json = dict(json.loads(user_message_text, strict=JSON_LOADS_STRICT))
        return json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
    except Exception as e:
        print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
        if raise_on_error:
            # Bare raise keeps the original traceback intact
            raise
        # Previously this path fell off the end of the function and returned None
        return user_message_text
|
||||
|
||||
try:
|
||||
# Step 0: add user message
|
||||
if user_message is not None:
|
||||
if isinstance(user_message, Message):
|
||||
user_message_text = user_message.text
|
||||
# Validate JSON via save/load
|
||||
user_message_text = validate_json(user_message.text, False)
|
||||
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text)
|
||||
|
||||
if name is not None:
|
||||
# Update Message object
|
||||
user_message.text = cleaned_user_message_text
|
||||
user_message.name = name
|
||||
|
||||
# Recreate timestamp
|
||||
if recreate_message_timestamp:
|
||||
user_message.created_at = datetime.datetime.now()
|
||||
|
||||
elif isinstance(user_message, str):
|
||||
user_message_text = user_message
|
||||
# Validate JSON via save/load
|
||||
user_message = validate_json(user_message, False)
|
||||
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)
|
||||
|
||||
# If user_message['name'] is not None, it will be handled properly by dict_to_message
|
||||
# So no need to run strip_name_field_from_user_message
|
||||
|
||||
# Create the associated Message object (in the database)
|
||||
user_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict={"role": "user", "content": cleaned_user_message_text, "name": name},
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Bad type for user_message: {type(user_message)}")
|
||||
|
||||
packed_user_message = {"role": "user", "content": user_message_text}
|
||||
# Special handling for AutoGen messages with 'name' field
|
||||
try:
|
||||
user_message_json = json.loads(user_message_text, strict=JSON_LOADS_STRICT)
|
||||
# Special handling for AutoGen messages with 'name' field
|
||||
# Treat 'name' as a special field
|
||||
# If it exists in the input message, elevate it to the 'message' level
|
||||
if "name" in user_message_json:
|
||||
packed_user_message["name"] = user_message_json["name"]
|
||||
user_message_json.pop("name", None)
|
||||
packed_user_message["content"] = json.dumps(user_message_json, ensure_ascii=JSON_ENSURE_ASCII)
|
||||
except Exception as e:
|
||||
print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
|
||||
self.interface.user_message(user_message.text, msg_obj=user_message)
|
||||
|
||||
# Create the associated Message object (in the database)
|
||||
packed_user_message_obj = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict=packed_user_message,
|
||||
)
|
||||
self.interface.user_message(user_message_text, msg_obj=packed_user_message_obj)
|
||||
|
||||
input_message_sequence = self.messages + [packed_user_message]
|
||||
input_message_sequence = self.messages + [user_message.to_openai_dict()]
|
||||
# Alternatively, the requestor can send an empty user message
|
||||
else:
|
||||
input_message_sequence = self.messages
|
||||
packed_user_message = None
|
||||
|
||||
if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
|
||||
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
|
||||
@@ -698,14 +728,7 @@ class Agent(object):
|
||||
if isinstance(user_message, Message):
|
||||
all_new_messages = [user_message] + all_response_messages
|
||||
else:
|
||||
all_new_messages = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict=packed_user_message,
|
||||
)
|
||||
] + all_response_messages
|
||||
raise ValueError(type(user_message))
|
||||
else:
|
||||
all_new_messages = all_response_messages
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Annotated
|
||||
|
||||
from memgpt import create_client
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
|
||||
# from memgpt.agent import Agent
|
||||
from memgpt.errors import LLMJSONParsingError
|
||||
@@ -55,7 +56,11 @@ def bench(
|
||||
bench_id = uuid.uuid4()
|
||||
|
||||
for i in range(n_tries):
|
||||
agent = client.create_agent(name=f"benchmark_{bench_id}_agent_{i}", persona=PERSONA, human=HUMAN)
|
||||
agent = client.create_agent(
|
||||
name=f"benchmark_{bench_id}_agent_{i}",
|
||||
persona=get_persona_text(PERSONA),
|
||||
human=get_human_text(HUMAN),
|
||||
)
|
||||
|
||||
agent_id = agent.id
|
||||
result, msg = send_message(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import requests
|
||||
from requests.exceptions import RequestException
|
||||
import uuid
|
||||
from typing import Dict, List, Union, Optional, Tuple
|
||||
|
||||
|
||||
@@ -476,6 +476,8 @@ class AgentState:
|
||||
self.name = name
|
||||
self.user_id = user_id
|
||||
self.preset = preset
|
||||
# The INITIAL values of the persona and human
|
||||
# The values inside self.state['persona'], self.state['human'] are the CURRENT values
|
||||
self.persona = persona
|
||||
self.human = human
|
||||
|
||||
|
||||
@@ -324,8 +324,10 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
id=agent_id,
|
||||
name=agent_config["name"],
|
||||
user_id=user.id,
|
||||
persona=agent_config["persona"], # eg 'sam_pov'
|
||||
human=agent_config["human"], # eg 'basic'
|
||||
# persona_name=agent_config["persona"], # eg 'sam_pov'
|
||||
# human_name=agent_config["human"], # eg 'basic'
|
||||
persona=state_dict["memory"]["persona"], # NOTE: hacky (not init, but latest)
|
||||
human=state_dict["memory"]["human"], # NOTE: hacky (not init, but latest)
|
||||
preset=agent_config["preset"], # eg 'memgpt_chat'
|
||||
state=dict(
|
||||
human=state_dict["memory"]["human"],
|
||||
|
||||
@@ -54,7 +54,7 @@ class ToolModel(BaseModel):
|
||||
class AgentStateModel(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
|
||||
name: str = Field(..., description="The name of the agent.")
|
||||
description: str = Field(None, description="The description of the agent.")
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
|
||||
|
||||
# timestamps
|
||||
|
||||
@@ -9,13 +9,15 @@ from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.data_types import AgentState
|
||||
from memgpt.models.pydantic_models import LLMConfigModel, EmbeddingConfigModel, AgentStateModel
|
||||
from memgpt.models.pydantic_models import LLMConfigModel, EmbeddingConfigModel, AgentStateModel, PresetModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListAgentsResponse(BaseModel):
|
||||
num_agents: int = Field(..., description="The number of agents available to the user.")
|
||||
# TODO make return type List[AgentStateModel]
|
||||
# also return - presets: List[PresetModel]
|
||||
agents: List[dict] = Field(..., description="List of agent configurations.")
|
||||
|
||||
|
||||
@@ -25,6 +27,7 @@ class CreateAgentRequest(BaseModel):
|
||||
|
||||
class CreateAgentResponse(BaseModel):
|
||||
agent_state: AgentStateModel = Field(..., description="The state of the newly created agent.")
|
||||
preset: PresetModel = Field(..., description="The preset that the agent was created from.")
|
||||
|
||||
|
||||
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
@@ -62,8 +65,8 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
# TODO turn into a pydantic model
|
||||
name=request.config["name"],
|
||||
preset=request.config["preset"] if "preset" in request.config else None,
|
||||
# persona_name=request.config["persona_name"] if "persona_text" in request.config else None,
|
||||
# human_name=request.config["human_name"] if "human_text" in request.config else None,
|
||||
# persona_name=request.config["persona_name"] if "persona_name" in request.config else None,
|
||||
# human_name=request.config["human_name"] if "human_name" in request.config else None,
|
||||
persona=request.config["persona"] if "persona" in request.config else None,
|
||||
human=request.config["human"] if "human" in request.config else None,
|
||||
# llm_config=LLMConfigModel(
|
||||
@@ -76,6 +79,10 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
raise
|
||||
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
|
||||
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||
|
||||
# TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line
|
||||
preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id)
|
||||
|
||||
return CreateAgentResponse(
|
||||
agent_state=AgentStateModel(
|
||||
id=agent_state.id,
|
||||
@@ -89,11 +96,19 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||
)
|
||||
),
|
||||
preset=PresetModel(
|
||||
name=preset.name,
|
||||
id=preset.id,
|
||||
user_id=preset.user_id,
|
||||
description=preset.description,
|
||||
created_at=preset.created_at,
|
||||
system=preset.system,
|
||||
persona=preset.persona,
|
||||
human=preset.human,
|
||||
functions_schema=preset.functions_schema,
|
||||
),
|
||||
)
|
||||
# return CreateAgentResponse(
|
||||
# agent_state=AgentStateModel(
|
||||
# )
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -16,6 +16,7 @@ import memgpt.server.utils as server_utils
|
||||
import memgpt.system as system
|
||||
from memgpt.agent import Agent, save_agent
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
|
||||
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
from memgpt.cli.cli_config import get_model_options
|
||||
@@ -572,8 +573,8 @@ class SyncServer(LockingServer):
|
||||
user_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
preset: Optional[str] = None,
|
||||
persona: Optional[str] = None,
|
||||
human: Optional[str] = None,
|
||||
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
@@ -600,14 +601,34 @@ class SyncServer(LockingServer):
|
||||
# TODO: fix this db dependency and remove
|
||||
# self.ms.create_agent(agent_state)
|
||||
|
||||
# TODO modify to do creation via preset
|
||||
try:
|
||||
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
|
||||
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
|
||||
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
preset_obj.human = human if human else self.config.human
|
||||
preset_obj.persona = persona if persona else self.config.persona
|
||||
if human is not None:
|
||||
preset_obj.human = human
|
||||
# This is a check for a common bug where users were providing filenames instead of values
|
||||
try:
|
||||
get_human_text(human)
|
||||
raise ValueError(human)
|
||||
raise UserWarning(
|
||||
f"It looks like there is a human file named {human} - did you mean to pass the file contents to the `human` arg?"
|
||||
)
|
||||
except:
|
||||
pass
|
||||
if persona is not None:
|
||||
preset_obj.persona = persona
|
||||
try:
|
||||
get_persona_text(persona)
|
||||
raise ValueError(persona)
|
||||
raise UserWarning(
|
||||
f"It looks like there is a persona file named {persona} - did you mean to pass the file contents to the `persona` arg?"
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
llm_config = llm_config if llm_config else self.server_llm_config
|
||||
embedding_config = embedding_config if embedding_config else self.server_embedding_config
|
||||
@@ -694,6 +715,7 @@ class SyncServer(LockingServer):
|
||||
}
|
||||
return agent_config
|
||||
|
||||
# TODO make return type pydantic
|
||||
def list_agents(self, user_id: uuid.UUID) -> dict:
|
||||
"""List all available agents to a user"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
@@ -711,6 +733,12 @@ class SyncServer(LockingServer):
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
# TODO remove this once the return type is converted to a pydantic model
|
||||
# this is to add persona_name and human_name so that the columns in UI can populate
|
||||
preset = self.ms.get_preset(name=agent_state.preset, user_id=user_id)
|
||||
return_dict["persona_name"] = preset.persona_name
|
||||
return_dict["human_name"] = preset.human_name
|
||||
|
||||
# Add information about tools
|
||||
# TODO memgpt_agent should really have a field of List[ToolModel]
|
||||
# then we could just pull that field and return it here
|
||||
|
||||
@@ -43,8 +43,7 @@ def agent():
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
agent_state = client.create_agent(
|
||||
persona=constants.DEFAULT_PERSONA,
|
||||
human=constants.DEFAULT_HUMAN,
|
||||
preset=constants.DEFAULT_PRESET,
|
||||
)
|
||||
|
||||
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
@@ -25,8 +25,7 @@ def create_test_agent():
|
||||
client = create_client()
|
||||
|
||||
agent_state = client.create_agent(
|
||||
persona=constants.DEFAULT_PERSONA,
|
||||
human=constants.DEFAULT_HUMAN,
|
||||
preset=constants.DEFAULT_PRESET,
|
||||
)
|
||||
|
||||
global agent_obj
|
||||
|
||||
@@ -12,6 +12,7 @@ from memgpt.cli.cli_load import load_directory
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.data_types import User, AgentState, EmbeddingConfig
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from .utils import wipe_config
|
||||
|
||||
@@ -113,8 +114,8 @@ def test_load_directory(
|
||||
user_id=user.id,
|
||||
name="test_agent",
|
||||
preset=TEST_MEMGPT_CONFIG.preset,
|
||||
persona=TEST_MEMGPT_CONFIG.persona,
|
||||
human=TEST_MEMGPT_CONFIG.human,
|
||||
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
|
||||
human=get_human_text(TEST_MEMGPT_CONFIG.human),
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
|
||||
@@ -49,8 +49,8 @@ def test_storage(storage_connector):
|
||||
user_id=user_1.id,
|
||||
name="agent_1",
|
||||
preset=DEFAULT_PRESET,
|
||||
persona=DEFAULT_PERSONA,
|
||||
human=DEFAULT_HUMAN,
|
||||
persona=get_persona_text(DEFAULT_PERSONA),
|
||||
human=get_human_text(DEFAULT_HUMAN),
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
|
||||
@@ -99,8 +99,6 @@ def agent_id(server, user_id):
|
||||
user_id=user_id,
|
||||
name="test_agent",
|
||||
preset="memgpt_chat",
|
||||
human="cs_phd",
|
||||
persona="sam_pov",
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
|
||||
@@ -11,6 +11,7 @@ from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.data_types import User
|
||||
from memgpt.constants import MAX_EMBEDDING_DIM
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@@ -175,8 +176,10 @@ def test_storage(
|
||||
name="agent_1",
|
||||
id=agent_1_id,
|
||||
preset=TEST_MEMGPT_CONFIG.preset,
|
||||
persona=TEST_MEMGPT_CONFIG.persona,
|
||||
human=TEST_MEMGPT_CONFIG.human,
|
||||
# persona_name=TEST_MEMGPT_CONFIG.persona,
|
||||
# human_name=TEST_MEMGPT_CONFIG.human,
|
||||
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
|
||||
human=get_human_text(TEST_MEMGPT_CONFIG.human),
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
|
||||
@@ -24,8 +24,6 @@ def create_test_agent():
|
||||
client = create_client()
|
||||
agent_state = client.create_agent(
|
||||
name=test_agent_name,
|
||||
persona=constants.DEFAULT_PERSONA,
|
||||
human=constants.DEFAULT_HUMAN,
|
||||
)
|
||||
|
||||
global agent_obj
|
||||
|
||||
Reference in New Issue
Block a user