diff --git a/memgpt/agent.py b/memgpt/agent.py index 57aa22f2..752d3ba1 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -21,14 +21,7 @@ from memgpt.constants import ( MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, MESSAGE_SUMMARY_WARNING_FRAC, ) -from memgpt.data_types import ( - AgentState, - EmbeddingConfig, - LLMConfig, - Message, - Passage, - Preset, -) +from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage from memgpt.interface import AgentInterface from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error from memgpt.memory import ArchivalMemory @@ -36,6 +29,7 @@ from memgpt.memory import CoreMemory as InContextMemory from memgpt.memory import RecallMemory, summarize_messages from memgpt.metadata import MetadataStore from memgpt.models import chat_completion_response +from memgpt.models.pydantic_models import ToolModel from memgpt.persistence_manager import LocalStateManager from memgpt.system import ( get_initial_boot_messages, @@ -45,7 +39,6 @@ from memgpt.system import ( ) from memgpt.utils import ( count_tokens, - create_random_username, create_uuid_from_string, get_local_time, get_schema_diff, @@ -200,76 +193,44 @@ class Agent(object): self, interface: AgentInterface, # agents can be created from providing agent_state - agent_state: Optional[AgentState] = None, - # or from providing a preset (requires preset + extra fields) - preset: Optional[Preset] = None, - created_by: Optional[uuid.UUID] = None, - name: Optional[str] = None, - llm_config: Optional[LLMConfig] = None, - embedding_config: Optional[EmbeddingConfig] = None, + agent_state: AgentState, + tools: List[ToolModel], # extras messages_total: Optional[int] = None, # TODO remove? first_message_verify_mono: bool = True, # TODO move to config? 
): - # An agent can be created from a Preset object - if preset is not None: - assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)" - assert created_by is not None, "Must provide created_by field when creating an Agent from a Preset" - assert llm_config is not None, "Must provide llm_config field when creating an Agent from a Preset" - assert embedding_config is not None, "Must provide embedding_config field when creating an Agent from a Preset" - # if agent_state is also provided, override any preset values - init_agent_state = AgentState( - name=name if name else create_random_username(), - user_id=created_by, - persona=preset.persona, - human=preset.human, - llm_config=llm_config, - embedding_config=embedding_config, - preset=preset.name, # TODO link via preset.id instead of name? - state={ - "persona": preset.persona, - "human": preset.human, - "system": preset.system, - "functions": preset.functions_schema, - "messages": None, - }, - ) - - # An agent can also be created directly from AgentState - elif agent_state is not None: - assert preset is None, "Can create an agent from a Preset or AgentState (but both were provided)" - assert agent_state.state is not None and agent_state.state != {}, "AgentState.state cannot be empty" - - # Assume the agent_state passed in is formatted correctly - init_agent_state = agent_state - - else: - raise ValueError("Both Preset and AgentState were null (must provide one or the other)") + # tools + for tool in tools: + assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools" + for tool_name in agent_state.tools: + assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list" + # Store the functions schemas (this is passed as an argument to ChatCompletion) + self.functions = [] + self.functions_python = {} + env = {} + env.update(globals()) + for tool in tools: + # WARNING: name may not be consistent? 
+ exec(tool.module, env) + self.functions_python[tool.name] = env[tool.name] + self.functions.append(tool.json_schema) + assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python # Hold a copy of the state that was used to init the agent - self.agent_state = init_agent_state + self.agent_state = agent_state # gpt-4, gpt-3.5-turbo, ... self.model = self.agent_state.llm_config.model # Store the system instructions (used to rebuild memory) - if "system" not in self.agent_state.state: - raise ValueError("'system' not found in provided AgentState") - self.system = self.agent_state.state["system"] - - if "functions" not in self.agent_state.state: - raise ValueError(f"'functions' not found in provided AgentState") - # Store the functions schemas (this is passed as an argument to ChatCompletion) - self.functions = self.agent_state.state["functions"] # these are the schema - # Link the actual python functions corresponding to the schemas - self.functions_python = {k: v["python_function"] for k, v in link_functions(function_schemas=self.functions).items()} - assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python + self.system = self.agent_state.system # Initialize the memory object - if "persona" not in self.agent_state.state: + # TODO: support more general memory types + if "persona" not in self.agent_state.state: # TODO: remove raise ValueError(f"'persona' not found in provided AgentState") - if "human" not in self.agent_state.state: + if "human" not in self.agent_state.state: # TODO: remove raise ValueError(f"'human' not found in provided AgentState") self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"]) @@ -283,7 +244,6 @@ class Agent(object): self.interface = interface # Create the persistence manager object based on the AgentState info - # TODO self.persistence_manager = LocalStateManager(agent_state=self.agent_state) # State needed 
for heartbeat pausing @@ -1053,25 +1013,41 @@ class Agent(object): # self.ms.update_agent(agent=new_agent_state) def update_state(self) -> AgentState: - updated_state = { + # updated_state = { + # "persona": self.memory.persona, + # "human": self.memory.human, + # "system": self.system, + # "functions": self.functions, + # "messages": [str(msg.id) for msg in self._messages], + # } + memory = { + "system": self.system, "persona": self.memory.persona, "human": self.memory.human, - "system": self.system, - "functions": self.functions, - "messages": [str(msg.id) for msg in self._messages], + "messages": [str(msg.id) for msg in self._messages], # TODO: move out into AgentState.message_ids + } + + # TODO: add this field + metadata = { # TODO + "human_name": self.agent_state.persona, + "persona_name": self.agent_state.human, } self.agent_state = AgentState( name=self.agent_state.name, user_id=self.agent_state.user_id, - persona=self.agent_state.persona, - human=self.agent_state.human, + tools=self.agent_state.tools, + system=self.system, + persona=self.agent_state.persona, # TODO: remove + human=self.agent_state.human, # TODO: remove + ## "model_state" llm_config=self.agent_state.llm_config, embedding_config=self.agent_state.embedding_config, preset=self.agent_state.preset, id=self.agent_state.id, created_at=self.agent_state.created_at, - state=updated_state, + ## "agent_state" + state=memory, ) return self.agent_state diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 6191049b..58ccab98 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -18,7 +18,7 @@ from memgpt.cli.cli_config import configure from memgpt.config import MemGPTConfig from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR from memgpt.credentials import MemGPTCredentials -from memgpt.data_types import EmbeddingConfig, LLMConfig, User +from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, User from memgpt.log import get_logger from memgpt.metadata import MetadataStore from 
memgpt.migrate import migrate_all_agents, migrate_all_sources @@ -582,9 +582,10 @@ def run( # Update the agent with any overrides ms.update_agent(agent_state) + tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools] # create agent - memgpt_agent = Agent(agent_state=agent_state, interface=interface()) + memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools) else: # create new agent # create new agent config: override defaults with args if provided @@ -650,13 +651,25 @@ def run( typer.secho(f"-> 🤖 Using persona profile: '{preset_obj.persona_name}'", fg=typer.colors.WHITE) typer.secho(f"-> 🧑 Using human profile: '{preset_obj.human_name}'", fg=typer.colors.WHITE) - memgpt_agent = Agent( - interface=interface(), + agent_state = AgentState( name=agent_name, - created_by=user.id, - preset=preset_obj, + user_id=user.id, + tools=list([schema["name"] for schema in preset_obj.functions_schema]), + system=preset_obj.system, llm_config=llm_config, embedding_config=embedding_config, + human=preset_obj.human, + persona=preset_obj.persona, + preset=preset_obj.name, + state={"messages": None, "persona": preset_obj.persona, "human": preset_obj.human}, + ) + print("tools", agent_state.tools) + tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools] + + memgpt_agent = Agent( + interface=interface(), + agent_state=agent_state, + tools=tools, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, ) diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index a078de6d..d6c46948 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -13,7 +13,7 @@ from memgpt.server.rest_api.admin.users import ( GetAllUsersResponse, GetAPIKeysResponse, ) -from memgpt.server.rest_api.tools.index import ListToolsResponse +from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse class Admin: @@ -84,11 
+84,40 @@ class Admin: self.delete_key(key) self.delete_user(user["user_id"]) - # tools (currently only available for admin) - def create_tool(self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None) -> ToolModel: - """Add a tool implemented in a file path""" - source_code = open(file_path, "r", encoding="utf-8").read() - data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags} + def create_tool( + self, + func, + name: Optional[str] = None, + update: Optional[bool] = True, # TODO: actually use this + tags: Optional[List[str]] = None, + ): + """Create a tool + + Args: + func (callable): The function to create a tool for. + tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. + update (bool, optional): Update the tool if it already exists. Defaults to True. + + Returns: + Tool object + """ + import inspect + + from memgpt.functions.schema_generator import generate_schema + + # TODO: check if tool already exists + # TODO: how to load modules? 
+ # parse source code/schema + source_code = inspect.getsource(func) + json_schema = generate_schema(func, name) + source_type = "python" + tool_name = json_schema["name"] + + # create data + data = {"name": tool_name, "source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema} + CreateToolRequest(**data) # validate data:w + + # make REST request response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create tool: {response.text}") @@ -96,7 +125,7 @@ class Admin: def list_tools(self) -> ListToolsResponse: response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers) - return ListToolsResponse(**response.json()) + return ListToolsResponse(**response.json()).tools def delete_tool(self, name: str): response = requests.delete(f"{self.base_url}/admin/tools/{name}", headers=self.headers) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index d6e90303..cf2fa99c 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union import requests from memgpt.config import MemGPTConfig -from memgpt.constants import DEFAULT_PRESET +from memgpt.constants import BASE_TOOLS, DEFAULT_PRESET from memgpt.data_sources.connectors import DataConnector from memgpt.data_types import ( AgentState, @@ -26,6 +26,7 @@ from memgpt.models.pydantic_models import ( SourceModel, ToolModel, ) +from memgpt.presets.presets import load_module_tools # import pydantic response objects from memgpt.server.rest_api from memgpt.server.rest_api.agents.command import CommandResponse @@ -52,7 +53,6 @@ from memgpt.server.rest_api.presets.index import ( ListPresetsResponse, ) from memgpt.server.rest_api.sources.index import ListSourcesResponse -from memgpt.server.rest_api.tools.index import CreateToolResponse from memgpt.server.server import SyncServer @@ -186,12 +186,6 @@ class 
AbstractClient(object): """List all tools.""" raise NotImplementedError - def create_tool( - self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None - ) -> CreateToolResponse: - """Create a tool.""" - raise NotImplementedError - # data sources def list_sources(self): @@ -266,9 +260,32 @@ class RESTClient(AbstractClient): human: Optional[str] = None, embedding_config: Optional[EmbeddingConfig] = None, llm_config: Optional[LLMConfig] = None, + # tools + tools: Optional[List[str]] = None, + include_base_tools: Optional[bool] = True, ) -> AgentState: + """ + Create an agent + + Args: + name (str): Name of the agent + tools (List[str]): List of tools (by name) to attach to the agent + include_base_tools (bool): Whether to include base tools (default: `True`) + + Returns: + agent_state (AgentState): State of the the created agent. + + """ if embedding_config or llm_config: raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API") + + # construct list of tools + tool_names = [] + if tools: + tool_names += tools + if include_base_tools: + tool_names += BASE_TOOLS + # TODO: distinguish between name and objects payload = { "config": { @@ -276,6 +293,7 @@ class RESTClient(AbstractClient): "preset": preset, "persona": persona, "human": human, + "function_names": tool_names, } } response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers) @@ -310,6 +328,8 @@ class RESTClient(AbstractClient): llm_config=llm_config, embedding_config=embedding_config, state=response.agent_state.state, + system=response.agent_state.system, + tools=response.agent_state.tools, # load datetime from timestampe created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc), ) @@ -657,6 +677,8 @@ class LocalClient(AbstractClient): if name and self.agent_exists(agent_name=name): raise ValueError(f"Agent with name {name} already exists 
(user_id={self.user_id})") + # TODO: implement tools support + self.interface.clear() agent_state = self.server.create_agent( user_id=self.user_id, @@ -664,6 +686,7 @@ class LocalClient(AbstractClient): preset=preset, persona=persona, human=human, + tools=[tool.name for tool in load_module_tools()], ) return agent_state diff --git a/memgpt/config.py b/memgpt/config.py index a6fc4f7e..7e64a823 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -181,6 +181,7 @@ class MemGPTConfig: # create new config anon_clientid = MemGPTConfig.generate_uuid() config = cls(anon_clientid=anon_clientid, config_path=config_path) + config.create_config_dir() # create dirs return config diff --git a/memgpt/constants.py b/memgpt/constants.py index e6ece4ac..86f8f423 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -21,6 +21,18 @@ DEFAULT_PERSONA = "sam_pov" DEFAULT_HUMAN = "basic" DEFAULT_PRESET = "memgpt_chat" +# Tools +BASE_TOOLS = [ + "send_message", + "core_memory_replace", + "core_memory_append", + "pause_heartbeats", + "conversation_search", + "conversation_search_date", + "archival_memory_insert", + "archival_memory_search", +] + # LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET} diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 4d60e3ba..e2984cae 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -754,11 +754,16 @@ class AgentState: self, name: str, user_id: uuid.UUID, - persona: str, # the filename where the persona was originally sourced from - human: str, # the filename where the human was originally sourced from + # tools + tools: List[str], # list of tools by name + # system prompt + system: str, + # config + persona: str, # the filename where the persona was originally sourced from # TODO: remove + human: str, # the filename where 
the human was originally sourced from # TODO: remove llm_config: LLMConfig, embedding_config: EmbeddingConfig, - preset: str, + preset: str, # TODO: remove # (in-context) state contains: # persona: str # the current persona text # human: str # the current human text @@ -768,6 +773,8 @@ class AgentState: id: Optional[uuid.UUID] = None, state: Optional[dict] = None, created_at: Optional[datetime] = None, + # messages (TODO: implement this) + # _metadata: Optional[dict] = None, ): if id is None: self.id = uuid.uuid4() @@ -779,6 +786,7 @@ class AgentState: # TODO(swooders) we need to handle the case where name is None here # in AgentConfig we autogenerate a name, not sure what the correct thing w/ DBs is, what about NounAdjective combos? Like giphy does? BoredGiraffe etc self.name = name + assert self.name, f"AgentState name must be a non-empty string" self.user_id = user_id self.preset = preset # The INITIAL values of the persona and human @@ -794,6 +802,15 @@ class AgentState: # state self.state = {} if not state else state + # tools + self.tools = tools + + # system + self.system = system + + # metadata + # self._metadata = _metadata + class Source: def __init__( diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py index c75c74a9..3d96bcb1 100644 --- a/memgpt/functions/functions.py +++ b/memgpt/functions/functions.py @@ -28,6 +28,7 @@ def load_function_set(module: ModuleType) -> dict: generated_schema = generate_schema(attr) function_dict[attr_name] = { + "module": inspect.getsource(module), "python_function": attr, "json_schema": generated_schema, } diff --git a/memgpt/functions/schema_generator.py b/memgpt/functions/schema_generator.py index 71f017a6..11ada7ab 100644 --- a/memgpt/functions/schema_generator.py +++ b/memgpt/functions/schema_generator.py @@ -1,6 +1,6 @@ import inspect import typing -from typing import get_args, get_origin +from typing import Optional, get_args, get_origin from docstring_parser import parse from pydantic import 
BaseModel @@ -83,7 +83,7 @@ def pydantic_model_to_open_ai(model): } -def generate_schema(function): +def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None): # Get the signature of the function sig = inspect.signature(function) @@ -92,11 +92,13 @@ def generate_schema(function): # Prepare the schema dictionary schema = { - "name": function.__name__, - "description": docstring.short_description, + "name": function.__name__ if name is None else name, + "description": docstring.short_description if description is None else description, "parameters": {"type": "object", "properties": {}, "required": []}, } + # TODO: ensure that 'agent' keyword is reserved for `Agent` class + for param in sig.parameters.values(): # Exclude 'self' parameter if param.name == "self": diff --git a/memgpt/metadata.py b/memgpt/metadata.py index ddf226a6..33cf9b44 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -1,6 +1,5 @@ """ Metadata store for user/agent/data_source information""" -import inspect as python_inspect import os import secrets import traceback @@ -35,7 +34,6 @@ from memgpt.data_types import ( Token, User, ) -from memgpt.functions.functions import load_all_function_sets from memgpt.models.pydantic_models import ( HumanModel, JobModel, @@ -179,6 +177,7 @@ class AgentModel(Base): name = Column(String, nullable=False) persona = Column(String) human = Column(String) + system = Column(String) preset = Column(String) created_at = Column(DateTime(timezone=True), server_default=func.now()) @@ -189,6 +188,9 @@ class AgentModel(Base): # state state = Column(JSON) + # tools + tools = Column(JSON) + def __repr__(self) -> str: return f"" @@ -204,6 +206,8 @@ class AgentModel(Base): llm_config=self.llm_config, embedding_config=self.embedding_config, state=self.state, + tools=self.tools, + system=self.system, ) @@ -545,6 +549,13 @@ class MetadataStore: session.commit() session.refresh(persona) + @enforce_types + def update_tool(self, tool: 
ToolModel): + with self.session_maker() as session: + session.add(tool) + session.commit() + session.refresh(tool) + @enforce_types def delete_agent(self, agent_id: uuid.UUID): with self.session_maker() as session: @@ -595,20 +606,8 @@ class MetadataStore: # def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools def list_tools(self) -> List[ToolModel]: with self.session_maker() as session: - available_functions = load_all_function_sets() - results = [ - ToolModel( - name=k, - json_schema=v["json_schema"], - tags=v["tags"], - source_type="python", - source_code=python_inspect.getsource(v["python_function"]), - ) - for k, v in available_functions.items() - ] + results = session.query(ToolModel).all() return results - # results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all() - # return [r.to_record() for r in results] @enforce_types def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 217d6db7..4a813a55 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -1,3 +1,4 @@ +# tool imports import uuid from datetime import datetime from enum import Enum @@ -48,11 +49,12 @@ class PresetModel(BaseModel): class ToolModel(SQLModel, table=True): # TODO move into database - name: str = Field(..., description="The name of the function.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True) + name: str = Field(..., description="The name of the function.", primary_key=True) + id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.") tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.") source_type: Optional[str] = Field(None, description="The type of the source code.") source_code: Optional[str] = Field(..., description="The source code of the 
function.") + module: Optional[str] = Field(None, description="The module of the function.") json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.") @@ -89,7 +91,9 @@ class AgentStateModel(BaseModel): preset: str = Field(..., description="The preset used by the agent.") persona: str = Field(..., description="The persona used by the agent.") human: str = Field(..., description="The human used by the agent.") - functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.") + tools: List[str] = Field(..., description="The tools used by the agent.") + system: str = Field(..., description="The system prompt used by the agent.") + # functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.") # llm information llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.") diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index 8c0e252e..114e0d8a 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -1,4 +1,5 @@ import importlib +import inspect import os import uuid from typing import List @@ -23,8 +24,8 @@ available_presets = load_all_presets() preset_options = list(available_presets.keys()) -def add_default_tools(user_id: uuid.UUID, ms: MetadataStore): - module_name = "base" +def load_module_tools(module_name="base"): + # return List[ToolModel] from base.py tools full_module_name = f"memgpt.functions.function_sets.{module_name}" try: module = importlib.import_module(full_module_name) @@ -42,8 +43,29 @@ def add_default_tools(user_id: uuid.UUID, ms: MetadataStore): printd(err) # create tool in db + tools = [] for name, schema in functions_to_schema.items(): - ms.add_tool(ToolModel(name=name, tags=["base"], source_type="python", json_schema=schema["json_schema"])) + # print([str(inspect.getsource(line)) for line in schema["imports"]]) + source_code = 
inspect.getsource(schema["python_function"]) + tools.append( + ToolModel( + name=name, + tags=["base"], + source_type="python", + module=schema["module"], + source_code=source_code, + json_schema=schema["json_schema"], + ) + ) + return tools + + +def add_default_tools(user_id: uuid.UUID, ms: MetadataStore): + module_name = "base" + for tool in load_module_tools(module_name=module_name): + existing_tool = ms.get_tool(tool.name) + if not existing_tool: + ms.add_tool(tool) def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore): diff --git a/memgpt/server/rest_api/admin/tools.py b/memgpt/server/rest_api/admin/tools.py index 35fc4c5d..f3f5dd9f 100644 --- a/memgpt/server/rest_api/admin/tools.py +++ b/memgpt/server/rest_api/admin/tools.py @@ -15,7 +15,7 @@ class ListToolsResponse(BaseModel): class CreateToolRequest(BaseModel): - name: str = Field(..., description="The name of the function.") + json_schema: dict = Field(..., description="JSON schema of the tool.") source_code: str = Field(..., description="The source code of the function.") source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.") tags: Optional[List[str]] = Field(None, description="Metadata tags.") @@ -74,28 +74,13 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface): # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific ): """ - Create a new tool (dummy route) + Create a new tool """ - from memgpt.functions.functions import load_function_file, write_function - - # check if function already exists - if server.ms.get_tool(request.name): - raise ValueError(f"Tool with name {request.name} already exists.") - - # write function to ~/.memgt/functions directory - file_path = write_function(request.name, request.name, request.source_code) - - # TODO: Use load_function_file to load function schema - schema = load_function_file(file_path) - assert len(list(schema.keys())) == 1, 
"Function schema must have exactly one key" - json_schema = list(schema.values())[0]["json_schema"] - - print("adding tool", request.name, request.tags, request.source_code) - tool = ToolModel(name=request.name, json_schema=json_schema, tags=request.tags, source_code=request.source_code) - tool.id - server.ms.add_tool(tool) - - # TODO: insert tool information into DB as ToolModel - return server.ms.get_tool(request.name) + try: + return server.create_tool( + json_schema=request.json_schema, source_code=request.source_code, source_type=request.source_type, tags=request.tags + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}") return router diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index 3f67195d..2050404b 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -86,7 +86,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, embedding_config=embedding_config, state=agent_state.state, created_at=int(agent_state.created_at.timestamp()), - functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead + tools=agent_state.tools, + system=agent_state.system, ), last_run_at=None, # TODO sources=attached_sources, @@ -131,7 +132,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, embedding_config=embedding_config, state=agent_state.state, created_at=int(agent_state.created_at.timestamp()), - functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead + tools=agent_state.tools, + system=agent_state.system, ), last_run_at=None, # TODO sources=attached_sources, diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index 62fd341b..03bed270 100644 --- a/memgpt/server/rest_api/agents/index.py +++ 
b/memgpt/server/rest_api/agents/index.py @@ -14,6 +14,7 @@ from memgpt.models.pydantic_models import ( from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer +from memgpt.settings import settings router = APIRouter() @@ -26,6 +27,7 @@ class ListAgentsResponse(BaseModel): class CreateAgentRequest(BaseModel): + # TODO: modify this (along with front end) config: dict = Field(..., description="The agent configuration object.") @@ -60,27 +62,42 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p """ interface.clear() + # Parse request + # TODO: don't just use JSON in the future + human_name = request.config["human_name"] if "human_name" in request.config else None + human = request.config["human"] if "human" in request.config else None + persona_name = request.config["persona_name"] if "persona_name" in request.config else None + persona = request.config["persona"] if "persona" in request.config else None + preset = request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset + tool_names = request.config["function_names"] + + print("PRESET", preset) + try: agent_state = server.create_agent( user_id=user_id, # **request.config # TODO turn into a pydantic model name=request.config["name"], - preset=request.config["preset"] if "preset" in request.config else None, - persona_name=request.config["persona_name"] if "persona_name" in request.config else None, - human_name=request.config["human_name"] if "human_name" in request.config else None, - persona=request.config["persona"] if "persona" in request.config else None, - human=request.config["human"] if "human" in request.config else None, + preset=preset, + persona_name=persona_name, + human_name=human_name, + persona=persona, + human=human, # llm_config=LLMConfigModel( # model=request.config['model'], # ) - 
function_names=request.config["function_names"].split(",") if "function_names" in request.config else None, + # tools + tools=tool_names, + # function_names=request.config["function_names"].split(",") if "function_names" in request.config else None, ) llm_config = LLMConfigModel(**vars(agent_state.llm_config)) embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) # TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line + # TODO: remove preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id) + print("SYSTEM", agent_state.system) return CreateAgentResponse( agent_state=AgentStateModel( @@ -94,7 +111,8 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p embedding_config=embedding_config, state=agent_state.state, created_at=int(agent_state.created_at.timestamp()), - functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead + tools=tool_names, + system=agent_state.system, ), preset=PresetModel( name=preset.name, diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 40c0f39c..4ab05b0f 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -46,6 +46,7 @@ from memgpt.models.pydantic_models import ( SourceModel, ToolModel, ) +from memgpt.utils import create_random_username logger = get_logger(__name__) @@ -332,7 +333,8 @@ class SyncServer(LockingServer): # Instantiate an agent object using the state retrieved logger.info(f"Creating an agent object") - memgpt_agent = Agent(agent_state=agent_state, interface=interface) + tool_objs = [self.ms.get_tool(name) for name in agent_state.tools] # get tool objects + memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) # Add the agent to the in-memory store and return its reference logger.info(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}") @@ -645,17 +647,22 @@ class 
SyncServer(LockingServer): def create_agent( self, user_id: uuid.UUID, + tools: List[str], # list of tool names (handles) to include + # system: str, # system prompt + metadata: Optional[dict] = {}, # includes human/persona names name: Optional[str] = None, - preset: Optional[str] = None, - persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value - human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value - persona_name: Optional[str] = None, - human_name: Optional[str] = None, + preset: Optional[str] = None, # TODO: remove eventually + # model config llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, + # interface interface: Union[AgentInterface, None] = None, - # persistence_manager: Union[PersistenceManager, None] = None, - function_names: Optional[List[str]] = None, # TODO remove + # TODO: refactor this to be a more general memory configuration + system: Optional[str] = None, # prompt value + persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value + human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value + persona_name: Optional[str] = None, # TODO: remove + human_name: Optional[str] = None, # TODO: remove ) -> AgentState: """Create a new agent using a config""" if self.ms.get_user(user_id=user_id) is None: @@ -668,6 +675,9 @@ class SyncServer(LockingServer): # if persistence_manager is None: # persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config) + if name is None: + name = create_random_username() + logger.debug(f"Attempting to find user: {user_id}") user = self.ms.get_user(user_id=user_id) if not user: @@ -684,6 +694,14 @@ class SyncServer(LockingServer): assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist" logger.debug(f"Attempting to create agent from preset:\n{preset_obj}") + # system prompt + if system is None: + 
system = preset_obj.system + else: + preset_obj.system = system + preset_override = True + print("system", preset_obj.system, system) + # Overwrite fields in the preset if they were specified if human is not None and human != preset_obj.human: preset_override = True @@ -718,15 +736,8 @@ class SyncServer(LockingServer): llm_config = llm_config if llm_config else self.server_llm_config embedding_config = embedding_config if embedding_config else self.server_embedding_config - # TODO remove (https://github.com/cpacker/MemGPT/issues/1138) - if function_names is not None: - preset_override = True - # available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific - available_tools = self.ms.list_tools() - available_tools_names = [t.name for t in available_tools] - assert all([f_name in available_tools_names for f_name in function_names]) - preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names] - print("overriding preset_obj tools with:", preset_obj.functions_schema) + # get tools + tool_objs = [self.ms.get_tool(name) for name in tools] # If the user overrode any parts of the preset, we need to create a new preset to refer back to if preset_override: @@ -735,13 +746,24 @@ class SyncServer(LockingServer): # Then write out to the database for storage self.ms.create_preset(preset=preset_obj) - agent = Agent( - interface=interface, - preset=preset_obj, + agent_state = AgentState( name=name, - created_by=user.id, + user_id=user_id, + persona=preset_obj.persona_name, # TODO: remove + human=preset_obj.human_name, # TODO: remove + tools=tools, # name=id for tools llm_config=llm_config, embedding_config=embedding_config, + system=system, + preset=preset, # TODO: remove + state={"persona": preset_obj.persona, "human": preset_obj.human, "system": system, "messages": None}, + ) + + agent = Agent( + interface=interface, + agent_state=agent_state, + tools=tool_objs, + # embedding_config=embedding_config, # 
gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False, ) @@ -852,11 +874,11 @@ class SyncServer(LockingServer): # TODO remove this eventually when return type get pydanticfied # this is to add persona_name and human_name so that the columns in UI can populate - preset = self.ms.get_preset(name=agent_state.preset, user_id=user_id) # TODO hack for frontend, remove # (top level .persona is persona_name, and nested memory.persona is the state) - return_dict["persona"] = preset.persona_name - return_dict["human"] = preset.human_name + # TODO: eventually modify this to be contained in the metadata + return_dict["persona"] = agent_state.persona + return_dict["human"] = agent_state.human # Add information about tools # TODO memgpt_agent should really have a field of List[ToolModel] @@ -1437,8 +1459,36 @@ class SyncServer(LockingServer): return sources_with_metadata - def create_tool(self, name: str, user_id: uuid.UUID) -> ToolModel: # TODO: add other fields - """Create a new tool""" + def create_tool( + self, json_schema: dict, source_code: str, source_type: str, tags: Optional[List[str]] = None, exists_ok: Optional[bool] = True + ) -> ToolModel: # TODO: add other fields + """Create a new tool - def delete_tool(self, tool_id: uuid.UUID, user_id: uuid.UUID): + Args: + TODO + + Returns: + tool (ToolModel): Tool object + """ + name = json_schema["name"] + tool = self.ms.get_tool(name) + if tool: # check if function already exists + if exists_ok: + # update existing tool + tool.json_schema = json_schema + tool.tags = tags + tool.source_code = source_code + tool.source_type = source_type + self.ms.update_tool(tool) + else: + raise ValueError(f"Tool with name {name} already exists.") + else: + # create new tool + tool = ToolModel(name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type) + self.ms.add_tool(tool) + 
+ return self.ms.get_tool(name) + + def delete_tool(self, name: str): """Delete a tool""" + self.ms.delete_tool(name) diff --git a/memgpt/settings.py b/memgpt/settings.py index bd7ce616..745bca0c 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -19,6 +19,9 @@ class Settings(BaseSettings): pg_uri: Optional[str] = None # option to specifiy full uri cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"] + # agent configuration defaults + default_preset: Optional[str] = "memgpt_chat" + @property def memgpt_pg_uri(self) -> str: if self.pg_uri: diff --git a/poetry.lock b/poetry.lock index 3f6b50b8..965a790a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2999,6 +2999,7 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -3874,13 +3875,13 @@ files = [ [[package]] name = "pydantic" -version = "2.7.3" +version = "2.7.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"}, - {file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"}, + {file = 
"pydantic-2.7.4-py3-none-any.whl", hash = "sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"}, + {file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"}, ] [package.dependencies] @@ -6340,4 +6341,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "bfa14c084ae06f7d5ceb561406794d93f90808c20b098af13110f4ebe38c7928" +content-hash = "904e243813980d61b67db2d1dc96cc782299bcdfb80bb122e4f2013f11f2a9c4" diff --git a/pyproject.toml b/pyproject.toml index f1603249..4b88db03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ chromadb = "^0.5.0" sqlalchemy-json = "^0.7.0" fastapi = {version = "^0.104.1", optional = true} uvicorn = {version = "^0.24.0.post1", optional = true} -pydantic = "^2.5.2" +pydantic = "^2.7.4" pyautogen = {version = "0.2.22", optional = true} html2text = "^2020.1.16" docx2txt = "^0.8" diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py index a4bf1cd4..b9ab60eb 100644 --- a/tests/test_agent_function_update.py +++ b/tests/test_agent_function_update.py @@ -40,9 +40,7 @@ def agent(): if not client.server.get_user(user_id=user_id): client.server.create_user({"id": user_id}) - agent_state = client.create_agent( - preset=constants.DEFAULT_PRESET, - ) + agent_state = client.create_agent() return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index b7c52727..e5c49f6b 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -4,7 +4,7 @@ import uuid import pytest import memgpt.functions.function_sets.base as base_functions -from memgpt import constants, create_client +from memgpt import create_client from tests import TEST_MEMGPT_CONFIG from .utils import create_config, wipe_config @@ -25,9 +25,7 @@ def agent_obj(): client = 
create_client() - agent_state = client.create_agent( - preset=constants.DEFAULT_PRESET, - ) + agent_state = client.create_agent() global agent_obj user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) diff --git a/tests/test_concurrent_connections.py b/tests/test_concurrent_connections.py index 5aecf275..060acfc0 100644 --- a/tests/test_concurrent_connections.py +++ b/tests/test_concurrent_connections.py @@ -110,9 +110,7 @@ def test_concurrent_messages(admin_client): response = admin_client.create_user() token = response.api_key client = create_client(base_url=admin_client.base_url, token=token) - agent = client.create_agent( - name=test_agent_name, - ) + agent = client.create_agent() print("Agent created", agent.id) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index b48d30df..d60ed69a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -3,11 +3,11 @@ import os import uuid from memgpt.agent import Agent -from memgpt.data_types import Message +from memgpt.data_types import AgentState, Message from memgpt.embeddings import embedding_model from memgpt.llm_api.llm_api_tools import create from memgpt.models.pydantic_models import EmbeddingConfigModel, LLMConfigModel -from memgpt.presets.presets import load_preset +from memgpt.presets.presets import load_module_tools from memgpt.prompts import gpt_system messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")] @@ -26,13 +26,22 @@ def run_llm_endpoint(filename): print(config_data) llm_config = LLMConfigModel(**config_data) embedding_config = EmbeddingConfigModel(**json.load(open(embedding_config_path))) + agent_state = AgentState( + name="test_agent", + tools=[tool.name for tool in load_module_tools()], + system="", + persona="", + human="", + preset="memgpt_chat", + embedding_config=embedding_config, + llm_config=llm_config, + user_id=uuid.UUID(int=1), + state={"persona": "", "human": "", "messages": None}, + ) agent 
= Agent( interface=None, - preset=load_preset("memgpt_chat", user_id=uuid.UUID(int=1)), - name="test_agent", - created_by=uuid.UUID(int=1), - llm_config=llm_config, - embedding_config=embedding_config, + tools=load_module_tools(), + agent_state=agent_state, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True, ) diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 85b866fa..08df688b 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -8,13 +8,12 @@ from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.cli.cli_load import load_directory from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials -from memgpt.data_types import AgentState, EmbeddingConfig, User +from memgpt.data_types import EmbeddingConfig, User from memgpt.metadata import MetadataStore # from memgpt.data_sources.connectors import DirectoryConnector, load_data # import memgpt from memgpt.settings import settings -from memgpt.utils import get_human_text, get_persona_text from tests import TEST_MEMGPT_CONFIG from .utils import create_config, wipe_config, with_qdrant_storage @@ -131,15 +130,17 @@ def test_load_directory( # config.save() # create user and agent - agent = AgentState( - user_id=user.id, - name="test_agent", - preset=TEST_MEMGPT_CONFIG.preset, - persona=get_persona_text(TEST_MEMGPT_CONFIG.persona), - human=get_human_text(TEST_MEMGPT_CONFIG.human), - llm_config=TEST_MEMGPT_CONFIG.default_llm_config, - embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, - ) + # agent = AgentState( + # user_id=user.id, + # name="test_agent", + # preset=TEST_MEMGPT_CONFIG.preset, + # persona=get_persona_text(TEST_MEMGPT_CONFIG.persona), + # human=get_human_text(TEST_MEMGPT_CONFIG.human), + # llm_config=TEST_MEMGPT_CONFIG.default_llm_config, + # embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, + # tools=[], + # system="", + # ) 
ms.delete_user(user.id) ms.create_user(user) # ms.create_agent(agent) @@ -223,7 +224,6 @@ def test_load_directory( # cleanup ms.delete_user(user.id) - ms.delete_agent(agent.id) ms.delete_source(sources[0].id) # revert to openai config diff --git a/tests/test_migrate.py b/tests/test_migrate.py deleted file mode 100644 index b7e6c042..00000000 --- a/tests/test_migrate.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import shutil -import uuid - -from memgpt.migrate import migrate_all_agents -from memgpt.server.server import SyncServer - -from .utils import create_config, wipe_config - - -def test_migrate_0211(): - wipe_config() - if os.getenv("OPENAI_API_KEY"): - create_config("openai") - else: - create_config("memgpt_hosted") - - data_dir = "tests/data/memgpt-0.2.11" - tmp_dir = f"tmp_{str(uuid.uuid4())}" - shutil.copytree(data_dir, tmp_dir) - print("temporary directory:", tmp_dir) - # os.environ["MEMGPT_CONFIG_PATH"] = os.path.join(data_dir, "config") - # print(f"MEMGPT_CONFIG_PATH={os.environ['MEMGPT_CONFIG_PATH']}") - try: - agent_res = migrate_all_agents(tmp_dir, debug=True) - assert len(agent_res["failed_migrations"]) == 0, f"Failed migrations: {agent_res}" - - # NOTE: source tests had to be removed since it is no longer possible to migrate llama index vector indices - # source_res = migrate_all_sources(tmp_dir) - # assert len(source_res["failed_migrations"]) == 0, f"Failed migrations: {source_res}" - - # TODO: assert everything is in the DB - - server = SyncServer() - for agent_name in agent_res["migration_candidates"]: - if agent_name not in agent_res["failed_migrations"]: - # assert agent data exists - agent_state = server.ms.get_agent(agent_name=agent_name, user_id=agent_res["user_id"]) - assert agent_state is not None, f"Missing agent {agent_name}" - - # assert in context messages exist - message_ids = server.get_in_context_message_ids(user_id=agent_res["user_id"], agent_id=agent_state.id) - assert len(message_ids) > 0 - - # assert recall memories exist - 
messages = server.get_agent_messages( - user_id=agent_state.user_id, - agent_id=agent_state.id, - start=0, - count=1000, - ) - assert len(messages) > 0 - - # for source_name in source_res["migration_candidates"]: - # if source_name not in source_res["failed_migrations"]: - # # assert source data exists - # source = server.ms.get_source(source_name=source_name, user_id=source_res["user_id"]) - # assert source is not None - except Exception as e: - raise e - finally: - shutil.rmtree(tmp_dir) diff --git a/tests/test_new_cli.py b/tests/test_new_cli.py index dafe062d..714dda99 100644 --- a/tests/test_new_cli.py +++ b/tests/test_new_cli.py @@ -6,6 +6,26 @@ import unittest.mock import pytest from memgpt.cli.cli_config import add, delete, list +from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials +from tests.utils import create_config + + +def _reset_config(): + + if os.getenv("OPENAI_API_KEY"): + create_config("openai") + credentials = MemGPTCredentials( + openai_key=os.getenv("OPENAI_API_KEY"), + ) + else: # hosted + create_config("memgpt_hosted") + credentials = MemGPTCredentials() + + config = MemGPTConfig.load() + config.save() + credentials.save() + print("_reset_config :: ", config.config_path) @pytest.mark.skip(reason="This is a helper function.") @@ -31,6 +51,7 @@ def reset_env_variables(server_url, token): def test_crud_human(capsys): + _reset_config() server_url, token = unset_env_variables() diff --git a/tests/test_server.py b/tests/test_server.py index 35763692..a49a57f4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -9,6 +9,7 @@ import memgpt.utils as utils utils.DEBUG = True from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials +from memgpt.presets.presets import load_module_tools from memgpt.server.server import SyncServer from memgpt.settings import settings @@ -73,6 +74,7 @@ def agent_id(server, user_id): user_id=user_id, name="test_agent", preset="memgpt_chat", + 
tools=[tool.name for tool in load_module_tools()], ) print(f"Created agent\n{agent_state}") yield agent_state.id diff --git a/tests/test_storage.py b/tests/test_storage.py index c5aa870d..fe411ff6 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -192,12 +192,12 @@ def test_storage( human=get_human_text(TEST_MEMGPT_CONFIG.human), llm_config=TEST_MEMGPT_CONFIG.default_llm_config, embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config, + system="", + tools=[], state={ "persona": "", "human": "", - "system": "", - "functions": [], - "messages": [], + "messages": None, }, ) ms.create_user(user)