diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index d57cd4e0..2199da8a 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -1,9 +1,10 @@ import uuid -from typing import Optional +from typing import List, Optional import requests from requests import HTTPError +from memgpt.models.pydantic_models import ToolModel from memgpt.server.rest_api.admin.users import ( CreateAPIKeyResponse, CreateUserResponse, @@ -12,6 +13,7 @@ from memgpt.server.rest_api.admin.users import ( GetAllUsersResponse, GetAPIKeysResponse, ) +from memgpt.server.rest_api.tools.index import ListToolsResponse class Admin: @@ -46,7 +48,7 @@ class Admin: if response.status_code != 200: raise HTTPError(response.json()) print(response.text, response.status_code) - return GetAPIKeysResponse(**response.json()) + return GetAPIKeysResponse(**response.json()).api_key_list def delete_key(self, api_key: str): params = {"api_key": api_key} @@ -77,7 +79,35 @@ class Admin: # TODO: clear out all agents, presets, etc. users = self.get_users().user_list for user in users: - keys = self.get_keys(user["user_id"]).api_key_list + keys = self.get_keys(user["user_id"]) for key in keys: self.delete_key(key) self.delete_user(user["user_id"]) + + # tools (currently only available for admin) + def create_tool(self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None) -> ToolModel: + """Add a tool implemented in a file path""" + source_code = open(file_path, "r").read() + data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags} + response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to create tool: {response.text}") + return ToolModel(**response.json()) + + def list_tools(self) -> ListToolsResponse: + response = requests.get(f"{self.base_url}/api/tools", headers=self.headers) + return ListToolsResponse(**response.json()) + + def delete_tool(self, name: str): + response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to delete tool: {response.text}") + return response.json() + + def get_tool(self, name: str): + response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers) + if response.status_code == 404: + return None + elif response.status_code != 200: + raise ValueError(f"Failed to get tool: {response.text}") + return ToolModel(**response.json()) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index f3e74bbb..7ef05f95 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union import requests from memgpt.config import MemGPTConfig +from memgpt.constants import DEFAULT_PRESET from memgpt.data_sources.connectors import DataConnector from memgpt.data_types import ( AgentState, @@ -23,6 +24,7 @@ from memgpt.models.pydantic_models import ( PersonaModel, PresetModel, SourceModel, + ToolModel, ) # import pydantic response objects from memgpt.server.rest_api @@ -50,7 +52,7 @@ from memgpt.server.rest_api.presets.index import ( ListPresetsResponse, ) from memgpt.server.rest_api.sources.index import ListSourcesResponse -from memgpt.server.rest_api.tools.index import CreateToolResponse, ListToolsResponse +from memgpt.server.rest_api.tools.index import CreateToolResponse from memgpt.server.server import SyncServer @@ -259,7 +261,7 @@ class RESTClient(AbstractClient): def create_agent( self, name: Optional[str] = None, - preset: Optional[str] = None, + preset: Optional[str] = None, # TODO: this should actually be re-named preset_name persona: Optional[str] = None, human: Optional[str] = None, embedding_config: Optional[EmbeddingConfig] = None, @@ -267,6 +269,7 @@ class RESTClient(AbstractClient): ) -> AgentState: if embedding_config or llm_config: raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API") + # TODO: distinguish between name and objects payload = { "config": { "name": name, @@ -329,23 +332,87 @@ class RESTClient(AbstractClient): response_obj = GetAgentResponse(**response.json()) return self.get_agent_response_to_state(response_obj) - # presets - def create_preset(self, preset: Preset) -> CreatePresetResponse: - # TODO should the arg type here be PresetModel, not Preset? + ## presets + # def create_preset(self, preset: Preset) -> CreatePresetResponse: + # # TODO should the arg type here be PresetModel, not Preset? + # payload = CreatePresetsRequest( + # id=str(preset.id), + # name=preset.name, + # description=preset.description, + # system=preset.system, + # persona=preset.persona, + # human=preset.human, + # persona_name=preset.persona_name, + # human_name=preset.human_name, + # functions_schema=preset.functions_schema, + # ) + # response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers) + # assert response.status_code == 200, f"Failed to create preset: {response.text}" + # return CreatePresetResponse(**response.json()) + + def get_preset(self, name: str) -> PresetModel: + response = requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers) + assert response.status_code == 200, f"Failed to get preset: {response.text}" + return PresetModel(**response.json()) + + def create_preset( + self, + name: str, + description: Optional[str] = None, + system_name: Optional[str] = None, + persona_name: Optional[str] = None, + human_name: Optional[str] = None, + tools: Optional[List[ToolModel]] = None, + default_tools: bool = True, + ) -> PresetModel: + """Create an agent preset + + :param name: Name of the preset + :type name: str + :param system: System prompt (text) + :type system: str + :param persona: Persona prompt (text) + :type persona: Optional[str] + :param human: Human prompt (text) + :type human: Optional[str] + :param tools: List of tools to connect, defaults to None + :type tools: Optional[List[Tool]], optional + :param default_tools: Whether to automatically include default tools, defaults to True + :type default_tools: bool, optional + :return: Preset object + :rtype: PresetModel + """ + # provided tools + schema = [] + if tools: + for tool in tools: + print("CUSOTM TOOL", tool.json_schema) + schema.append(tool.json_schema) + + # include default tools + default_preset = self.get_preset(name=DEFAULT_PRESET) + if default_tools: + # TODO + # from memgpt.functions.functions import load_function_set + # load_function_set() + # return + for function in default_preset.functions_schema: + schema.append(function) + payload = CreatePresetsRequest( - id=str(preset.id), - name=preset.name, - description=preset.description, - system=preset.system, - persona=preset.persona, - human=preset.human, - persona_name=preset.persona_name, - human_name=preset.human_name, - functions_schema=preset.functions_schema, + name=name, + description=description, + system_name=system_name, + persona_name=persona_name, + human_name=human_name, + functions_schema=schema, ) + print(schema) + print(human_name, persona_name, system_name, name) + print(payload.model_dump()) response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers) assert response.status_code == 200, f"Failed to create preset: {response.text}" - return CreatePresetResponse(**response.json()) + return CreatePresetResponse(**response.json()).preset def delete_preset(self, preset_id: uuid.UUID): response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers) @@ -518,23 +585,6 @@ class RESTClient(AbstractClient): response = requests.get(f"{self.base_url}/api/config", headers=self.headers) return ConfigResponse(**response.json()) - # tools - - def create_tool( - self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None - ) -> CreateToolResponse: - """Add a tool implemented in a file path""" - source_code = open(file_path, "r").read() - data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags} - response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers) - if response.status_code != 200: - raise ValueError(f"Failed to create tool: {response.text}") - return CreateToolResponse(**response.json()) - - def list_tools(self) -> ListToolsResponse: - response = requests.get(f"{self.base_url}/api/tools", headers=self.headers) - return ListToolsResponse(**response.json()) - class LocalClient(AbstractClient): def __init__( diff --git a/memgpt/config.py b/memgpt/config.py index 30e4e8df..4d0bb971 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -271,6 +271,7 @@ class MemGPTConfig: with open(self.config_path, "w", encoding="utf-8") as f: config.write(f) logger.debug(f"Saved Config: {self.config_path}") + print(f"Saved Config: {self.config_path}") @staticmethod def exists(): diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 682270e3..4d60e3ba 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -11,11 +11,13 @@ from pydantic import BaseModel, Field from memgpt.constants import ( DEFAULT_HUMAN, DEFAULT_PERSONA, + DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM, TOOL_CALL_ID_MAX_LEN, ) from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG +from memgpt.prompts import gpt_system from memgpt.utils import ( create_uuid_from_string, get_human_text, @@ -848,7 +850,10 @@ class Preset(BaseModel): user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.") description: Optional[str] = Field(None, description="The description of the preset.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.") - system: str = Field(..., description="The system prompt of the preset.") + system: str = Field( + gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset." + ) # default system prompt is same as default preset name + # system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.") persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py index 78656d18..233ed26e 100644 --- a/memgpt/functions/functions.py +++ b/memgpt/functions/functions.py @@ -76,6 +76,26 @@ def write_function(module_name: str, function_name: str, function_code: str): # raise error if function cannot be loaded if not succ: raise ValueError(error) + return file_path + + +def load_function_file(filepath: str) -> dict: + file = os.path.basename(filepath) + module_name = file[:-3] # Remove '.py' from filename + try: + spec = importlib.util.spec_from_file_location(module_name, filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + except ModuleNotFoundError as e: + # Handle missing module imports + missing_package = str(e).split("'")[1] # Extract the name of the missing package + print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!") + print( + f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT." + ) + # load all functions in the module + function_dict = load_function_set(module) + return function_dict def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict: diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 37820bcd..411e7d1f 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -400,7 +400,10 @@ class MetadataStore: @enforce_types def get_api_key(self, api_key: str) -> Optional[Token]: with self.session_maker() as session: + print("getting api key", api_key) + print([r.token for r in self.get_all_api_keys_for_user(user_id=uuid.UUID(int=0))]) results = session.query(TokenModel).filter(TokenModel.token == api_key).all() + print("results", [r.token for r in results]) if len(results) == 0: return None assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result @@ -410,15 +413,19 @@ class MetadataStore: def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]: with self.session_maker() as session: results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all() - return [r.to_record() for r in results] + tokens = [r.to_record() for r in results] + print([t.token for t in tokens]) + return tokens @enforce_types def get_user_from_api_key(self, api_key: str) -> Optional[User]: """Get the user associated with a given API key""" token = self.get_api_key(api_key=api_key) + print("got api key", token.token, token is None) if token is None: raise ValueError(f"Token {api_key} does not exist") else: + print(isinstance(token.user_id, uuid.UUID), self.get_user(user_id=token.user_id)) return self.get_user(user_id=token.user_id) @enforce_types @@ -803,6 +810,12 @@ class MetadataStore: session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete() session.commit() + @enforce_types + def delete_tool(self, name: str): + with self.session_maker() as session: + session.query(ToolModel).filter(ToolModel.name == name).delete() + session.commit() + # job related functions def create_job(self, job: JobModel): with self.session_maker() as session: diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 05592404..217d6db7 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -38,6 +38,7 @@ class PresetModel(BaseModel): description: Optional[str] = Field(None, description="The description of the preset.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.") system: str = Field(..., description="The system prompt of the preset.") + system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.") persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index 41f6889a..14df189f 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -1,13 +1,14 @@ +import importlib import os import uuid from typing import List from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from memgpt.data_types import AgentState, Preset -from memgpt.functions.functions import load_all_function_sets +from memgpt.functions.functions import load_all_function_sets, load_function_set from memgpt.interface import AgentInterface from memgpt.metadata import MetadataStore -from memgpt.models.pydantic_models import HumanModel, PersonaModel +from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel from memgpt.presets.utils import load_all_presets, load_yaml_file from memgpt.prompts import gpt_system from memgpt.utils import ( @@ -22,6 +23,34 @@ available_presets = load_all_presets() preset_options = list(available_presets.keys()) +def add_default_tools(user_id: uuid.UUID, ms: MetadataStore): + module_name = "base" + full_module_name = f"memgpt.functions.function_sets.{module_name}" + try: + module = importlib.import_module(full_module_name) + except Exception as e: + # Handle other general exceptions + raise e + + # function tags + + try: + # Load the function set + functions_to_schema = load_function_set(module) + except ValueError as e: + err = f"Error loading function set '{module_name}': {e}" + printd(err) + + from pprint import pprint + + print("BASE FUNCTIONS", functions_to_schema.keys()) + pprint(functions_to_schema) + + # create tool in db + for name, schema in functions_to_schema.items(): + ms.add_tool(ToolModel(name=name, tags=["base"], source_type="python", json_schema=schema["json_schema"])) + + def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore): for persona_file in list_persona_files(): text = open(persona_file, "r").read() @@ -89,6 +118,10 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore): # make sure humans/personas added add_default_humans_and_personas(user_id=user_id, ms=ms) + # make sure base functions added + # TODO: pull from functions instead + add_default_tools(user_id=user_id, ms=ms) + # add default presets for preset_name in preset_options: if ms.get_preset(user_id=user_id, name=preset_name) is not None: diff --git a/memgpt/server/rest_api/admin/users.py b/memgpt/server/rest_api/admin/users.py index b4c485f8..97ff9de5 100644 --- a/memgpt/server/rest_api/admin/users.py +++ b/memgpt/server/rest_api/admin/users.py @@ -137,6 +137,8 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): Get a list of all API keys for a user """ try: + if server.ms.get_user(user_id=user_id) is None: + raise HTTPException(status_code=404, detail=f"User does not exist") tokens = server.ms.get_all_api_keys_for_user(user_id=user_id) processed_tokens = [t.token for t in tokens] except HTTPException: diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 07ab789c..c1faff4e 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -26,6 +26,7 @@ class MessageRoleType(str, Enum): class UserMessageRequest(BaseModel): message: str = Field(..., description="The message content to be processed by the agent.") + name: str = Field(default="user", description="Name of the message request sender") stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.") role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')") timestamp: Optional[datetime] = Field( diff --git a/memgpt/server/rest_api/humans/index.py b/memgpt/server/rest_api/humans/index.py index ea5f3ad1..a7f012a9 100644 --- a/memgpt/server/rest_api/humans/index.py +++ b/memgpt/server/rest_api/humans/index.py @@ -39,6 +39,7 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p request: CreateHumanRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): + # TODO: disallow duplicate names for humans interface.clear() new_human = HumanModel(text=request.text, name=request.name, user_id=user_id) human_id = new_human.id diff --git a/memgpt/server/rest_api/personas/index.py b/memgpt/server/rest_api/personas/index.py index ea46ee9e..b8f2503c 100644 --- a/memgpt/server/rest_api/personas/index.py +++ b/memgpt/server/rest_api/personas/index.py @@ -40,6 +40,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, request: CreatePersonaRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): + # TODO: disallow duplicate names for personas interface.clear() new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id) persona_id = new_persona.id diff --git a/memgpt/server/rest_api/presets/index.py b/memgpt/server/rest_api/presets/index.py index 5b60d6bc..4702371f 100644 --- a/memgpt/server/rest_api/presets/index.py +++ b/memgpt/server/rest_api/presets/index.py @@ -1,14 +1,15 @@ import uuid from functools import partial -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from fastapi import APIRouter, Body, Depends, HTTPException, status from fastapi.responses import JSONResponse from pydantic import BaseModel, Field -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA +from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET from memgpt.data_types import Preset # TODO remove -from memgpt.models.pydantic_models import PresetModel +from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel +from memgpt.prompts import gpt_system from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer @@ -32,17 +33,18 @@ class ListPresetsResponse(BaseModel): class CreatePresetsRequest(BaseModel): # TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)? name: str = Field(..., description="The name of the preset.") - id: Optional[Union[uuid.UUID, str]] = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") + id: Optional[str] = Field(None, description="The unique identifier of the preset.") # user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.") description: Optional[str] = Field(None, description="The description of the preset.") # created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.") - system: str = Field(..., description="The system prompt of the preset.") - persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") - human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") + system: Optional[str] = Field(None, description="The system prompt of the preset.") # TODO: make optional and allow defaults + persona: Optional[str] = Field(default=None, description="The persona of the preset.") + human: Optional[str] = Field(default=None, description="The human of the preset.") functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") # TODO persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") human_name: Optional[str] = Field(None, description="The name of the human of the preset.") + system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.") class CreatePresetResponse(BaseModel): @@ -52,6 +54,20 @@ class CreatePresetResponse(BaseModel): def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) + @router.get("/presets/{preset_name}", tags=["presets"], response_model=PresetModel) + async def get_preset( + preset_name: str, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """Get a preset.""" + try: + preset = server.get_preset(user_id=user_id, preset_name=preset_name) + return preset + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + @router.get("/presets", tags=["presets"], response_model=ListPresetsResponse) async def list_presets( user_id: uuid.UUID = Depends(get_current_user_with_server), @@ -77,18 +93,52 @@ def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, try: if isinstance(request.id, str): request.id = uuid.UUID(request.id) - # new_preset = PresetModel( + + # check if preset already exists + # TODO: move this into a server function to create a preset + if server.ms.get_preset(name=request.name, user_id=user_id): + raise HTTPException(status_code=400, detail=f"Preset with name {request.name} already exists.") + + # For system/human/persona - if {system/human-personal}_name is None but the text is provied, then create a new data entry + if not request.system_name and request.system: + # new system provided without name identity + system_name = f"system_{request.name}_{str(uuid.uuid4())}" + system = request.system + # TODO: insert into system table + else: + system_name = request.system_name if request.system_name else DEFAULT_PRESET + system = request.system if request.system else gpt_system.get_system_text(system_name) + + if not request.human_name and request.human: + # new human provided without name identity + human_name = f"human_{request.name}_{str(uuid.uuid4())}" + human = request.human + server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id)) + else: + human_name = request.human_name if request.human_name else DEFAULT_HUMAN + human = request.human if request.human else get_human_text(human_name) + + if not request.persona_name and request.persona: + # new persona provided without name identity + persona_name = f"persona_{request.name}_{str(uuid.uuid4())}" + persona = request.persona + server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id)) + else: + persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA + persona = request.persona if request.persona else get_persona_text(persona_name) + + # create preset new_preset = Preset( user_id=user_id, - id=request.id, + id=request.id if request.id else uuid.uuid4(), name=request.name, description=request.description, - system=request.system, - persona=request.persona, - human=request.human, + system=system, + persona=persona, + persona_name=persona_name, + human=human, + human_name=human_name, functions_schema=request.functions_schema, - persona_name=request.persona_name, - human_name=request.human_name, ) preset = server.create_preset(preset=new_preset) diff --git a/memgpt/server/rest_api/tools/index.py b/memgpt/server/rest_api/tools/index.py index 0f3ae2b7..2c8fc736 100644 --- a/memgpt/server/rest_api/tools/index.py +++ b/memgpt/server/rest_api/tools/index.py @@ -1,6 +1,6 @@ from typing import List, Literal, Optional -from fastapi import APIRouter, Body +from fastapi import APIRouter, Body, HTTPException from pydantic import BaseModel, Field from memgpt.models.pydantic_models import ToolModel @@ -28,6 +28,33 @@ class CreateToolResponse(BaseModel): def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str): # get_current_user_with_server = partial(partial(get_current_user, server), password) + @router.delete("/tools/{tool_name}", tags=["tools"]) + async def delete_tool( + tool_name: str, + # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific + ): + """ + Delete a tool by name + """ + # Clear the interface + interface.clear() + # tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific + server.ms.delete_tool(name=tool_name) + + @router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel) + async def get_tool(tool_name: str): + """ + Get a tool by name + """ + # Clear the interface + interface.clear() + # tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific + tool = server.ms.get_tool(tool_name=tool_name) + if tool is None: + # return 404 error + raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") + return tool + @router.get("/tools", tags=["tools"], response_model=ListToolsResponse) async def list_all_tools( # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific @@ -41,7 +68,7 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa tools = server.ms.list_tools() return ListToolsResponse(tools=tools) - @router.post("/tools", tags=["tools"], response_model=CreateToolResponse) + @router.post("/tools", tags=["tools"], response_model=ToolModel) async def create_tool( request: CreateToolRequest = Body(...), # user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific @@ -49,20 +76,26 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa """ Create a new tool (dummy route) """ - from memgpt.functions.functions import write_function + from memgpt.functions.functions import load_function_file, write_function # check if function already exists if server.ms.get_tool(request.name): raise ValueError(f"Tool with name {request.name} already exists.") # write function to ~/.memgt/functions directory - write_function(request.name, request.name, request.source_code) + file_path = write_function(request.name, request.name, request.source_code) + + # TODO: Use load_function_file to load function schema + schema = load_function_file(file_path) + assert len(list(schema.keys())) == 1, "Function schema must have exactly one key" + json_schema = list(schema.values())[0]["json_schema"] print("adding tool", request.name, request.tags, request.source_code) - tool = ToolModel(name=request.name, json_schema={}, tags=request.tags, source_code=request.source_code) + tool = ToolModel(name=request.name, json_schema=json_schema, tags=request.tags, source_code=request.source_code) + tool.id server.ms.add_tool(tool) # TODO: insert tool information into DB as ToolModel - return CreateToolResponse(tool=ToolModel(name=request.name, json_schema={}, tags=[], source_code=request.source_code)) + return server.ms.get_tool(request.name) return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 00c776d4..b557625c 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -45,7 +45,6 @@ from memgpt.models.pydantic_models import ( SourceModel, ToolModel, ) -from memgpt.settings import settings logger = logging.getLogger(__name__) @@ -214,12 +213,12 @@ class SyncServer(LockingServer): assert self.config.human is not None, "Human must be set in the config" # Update storage URI to match passed in settings - # TODO: very hack, fix in the future - for memory_type in ("archival", "recall", "metadata"): - if settings.memgpt_pg_uri_no_default: - # override with env - setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri) - self.config.save() + # (NOTE: no longer needed since envs being used, I think) + # for memory_type in ("archival", "recall", "metadata"): + # if settings.memgpt_pg_uri: + # # override with env + # setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri) + # self.config.save() # TODO figure out how to handle credentials for the server self.credentials = MemGPTCredentials.load() @@ -1278,6 +1277,7 @@ class SyncServer(LockingServer): def api_key_to_user(self, api_key: str) -> uuid.UUID: """Decode an API key to a user""" user = self.ms.get_user_from_api_key(api_key=api_key) + print("got user", api_key, user.id) if user is None: raise HTTPException(status_code=403, detail="Invalid credentials") else: diff --git a/tests/test_admin_client.py b/tests/test_admin_client.py index 6b1970ad..80573da6 100644 --- a/tests/test_admin_client.py +++ b/tests/test_admin_client.py @@ -58,19 +58,17 @@ def test_admin_client(admin_client): # list keys user1_keys = admin_client.get_keys(user_id) - assert len(user1_keys.api_key_list) == 2, f"Expected 2 keys, got {user1_keys}" - assert create_key1_response.api_key in user1_keys.api_key_list, f"Expected {create_key1_response.api_key} in {user1_keys.api_key_list}" - assert ( - create_user1_response.api_key in user1_keys.api_key_list - ), f"Expected {create_user1_response.api_key} in {user1_keys.api_key_list}" + assert len(user1_keys) == 2, f"Expected 2 keys, got {user1_keys}" + assert create_key1_response.api_key in user1_keys, f"Expected {create_key1_response.api_key} in {user1_keys}" + assert create_user1_response.api_key in user1_keys, f"Expected {create_user1_response.api_key} in {user1_keys}" # delete key delete_key1_response = admin_client.delete_key(create_key1_response.api_key) assert delete_key1_response.api_key_deleted == create_key1_response.api_key - assert len(admin_client.get_keys(user_id).api_key_list) == 1 + assert len(admin_client.get_keys(user_id)) == 1 delete_key2_response = admin_client.delete_key(create_key2_response.api_key) assert delete_key2_response.api_key_deleted == create_key2_response.api_key - assert len(admin_client.get_keys(create_user_2_response.user_id).api_key_list) == 1 + assert len(admin_client.get_keys(create_user_2_response.user_id)) == 1 # delete users delete_user1_response = admin_client.delete_user(user_id) diff --git a/tests/test_client.py b/tests/test_client.py index 798862b7..5b37f3dd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -320,41 +320,41 @@ def test_sources(client, agent): def test_presets(client, agent): _reset_config() - new_preset = Preset( - # user_id=client.user_id, - name="pytest_test_preset", - description="DUMMY_DESCRIPTION", - system="DUMMY_SYSTEM", - persona="DUMMY_PERSONA", - persona_name="DUMMY_PERSONA_NAME", - human="DUMMY_HUMAN", - human_name="DUMMY_HUMAN_NAME", - functions_schema=[ - { - "name": "send_message", - "json_schema": { - "name": "send_message", - "description": "Sends a message to the human user.", - "parameters": { - "type": "object", - "properties": { - "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."} - }, - "required": ["message"], - }, - }, - "tags": ["memgpt-base"], - "source_type": "python", - "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n', - } - ], - ) + # new_preset = Preset( + # # user_id=client.user_id, + # name="pytest_test_preset", + # description="DUMMY_DESCRIPTION", + # system="DUMMY_SYSTEM", + # persona="DUMMY_PERSONA", + # persona_name="DUMMY_PERSONA_NAME", + # human="DUMMY_HUMAN", + # human_name="DUMMY_HUMAN_NAME", + # functions_schema=[ + # { + # "name": "send_message", + # "json_schema": { + # "name": "send_message", + # "description": "Sends a message to the human user.", + # "parameters": { + # "type": "object", + # "properties": { + # "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."} + # }, + # "required": ["message"], + # }, + # }, + # "tags": ["memgpt-base"], + # "source_type": "python", + # "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n', + # } + # ], + # ) - # List all presets and make sure the preset is NOT in the list - all_presets = client.list_presets() - assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) + ## List all presets and make sure the preset is NOT in the list + # all_presets = client.list_presets() + # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) # Create a preset - client.create_preset(preset=new_preset) + new_preset = client.create_preset(name="pytest_test_preset") # List all presets and make sure the preset is in the list all_presets = client.list_presets() diff --git a/tests/test_server.py b/tests/test_server.py index ee73de34..9fee78d0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -22,7 +22,7 @@ def server(): wipe_config() wipe_memgpt_home() - db_url = settings.pg_db + db_url = settings.memgpt_pg_uri if os.getenv("OPENAI_API_KEY"): config = TestMGPTConfig(