feat: add more tool functionality for python client (#1361)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
from memgpt.server.rest_api.admin.users import (
|
||||
CreateAPIKeyResponse,
|
||||
CreateUserResponse,
|
||||
@@ -12,6 +13,7 @@ from memgpt.server.rest_api.admin.users import (
|
||||
GetAllUsersResponse,
|
||||
GetAPIKeysResponse,
|
||||
)
|
||||
from memgpt.server.rest_api.tools.index import ListToolsResponse
|
||||
|
||||
|
||||
class Admin:
|
||||
@@ -46,7 +48,7 @@ class Admin:
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
print(response.text, response.status_code)
|
||||
return GetAPIKeysResponse(**response.json())
|
||||
return GetAPIKeysResponse(**response.json()).api_key_list
|
||||
|
||||
def delete_key(self, api_key: str):
|
||||
params = {"api_key": api_key}
|
||||
@@ -77,7 +79,35 @@ class Admin:
|
||||
# TODO: clear out all agents, presets, etc.
|
||||
users = self.get_users().user_list
|
||||
for user in users:
|
||||
keys = self.get_keys(user["user_id"]).api_key_list
|
||||
keys = self.get_keys(user["user_id"])
|
||||
for key in keys:
|
||||
self.delete_key(key)
|
||||
self.delete_user(user["user_id"])
|
||||
|
||||
# tools (currently only available for admin)
|
||||
def create_tool(self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None) -> ToolModel:
|
||||
"""Add a tool implemented in a file path"""
|
||||
source_code = open(file_path, "r").read()
|
||||
data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags}
|
||||
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return ToolModel(**response.json())
|
||||
|
||||
def list_tools(self) -> ListToolsResponse:
|
||||
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
|
||||
return ListToolsResponse(**response.json())
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete tool: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def get_tool(self, name: str):
|
||||
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return ToolModel(**response.json())
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import requests
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
from memgpt.data_types import (
|
||||
AgentState,
|
||||
@@ -23,6 +24,7 @@ from memgpt.models.pydantic_models import (
|
||||
PersonaModel,
|
||||
PresetModel,
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
)
|
||||
|
||||
# import pydantic response objects from memgpt.server.rest_api
|
||||
@@ -50,7 +52,7 @@ from memgpt.server.rest_api.presets.index import (
|
||||
ListPresetsResponse,
|
||||
)
|
||||
from memgpt.server.rest_api.sources.index import ListSourcesResponse
|
||||
from memgpt.server.rest_api.tools.index import CreateToolResponse, ListToolsResponse
|
||||
from memgpt.server.rest_api.tools.index import CreateToolResponse
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -259,7 +261,7 @@ class RESTClient(AbstractClient):
|
||||
def create_agent(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
preset: Optional[str] = None,
|
||||
preset: Optional[str] = None, # TODO: this should actually be re-named preset_name
|
||||
persona: Optional[str] = None,
|
||||
human: Optional[str] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
@@ -267,6 +269,7 @@ class RESTClient(AbstractClient):
|
||||
) -> AgentState:
|
||||
if embedding_config or llm_config:
|
||||
raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API")
|
||||
# TODO: distinguish between name and objects
|
||||
payload = {
|
||||
"config": {
|
||||
"name": name,
|
||||
@@ -329,23 +332,87 @@ class RESTClient(AbstractClient):
|
||||
response_obj = GetAgentResponse(**response.json())
|
||||
return self.get_agent_response_to_state(response_obj)
|
||||
|
||||
# presets
|
||||
def create_preset(self, preset: Preset) -> CreatePresetResponse:
|
||||
# TODO should the arg type here be PresetModel, not Preset?
|
||||
## presets
|
||||
# def create_preset(self, preset: Preset) -> CreatePresetResponse:
|
||||
# # TODO should the arg type here be PresetModel, not Preset?
|
||||
# payload = CreatePresetsRequest(
|
||||
# id=str(preset.id),
|
||||
# name=preset.name,
|
||||
# description=preset.description,
|
||||
# system=preset.system,
|
||||
# persona=preset.persona,
|
||||
# human=preset.human,
|
||||
# persona_name=preset.persona_name,
|
||||
# human_name=preset.human_name,
|
||||
# functions_schema=preset.functions_schema,
|
||||
# )
|
||||
# response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
|
||||
# assert response.status_code == 200, f"Failed to create preset: {response.text}"
|
||||
# return CreatePresetResponse(**response.json())
|
||||
|
||||
def get_preset(self, name: str) -> PresetModel:
|
||||
response = requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to get preset: {response.text}"
|
||||
return PresetModel(**response.json())
|
||||
|
||||
def create_preset(
|
||||
self,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
system_name: Optional[str] = None,
|
||||
persona_name: Optional[str] = None,
|
||||
human_name: Optional[str] = None,
|
||||
tools: Optional[List[ToolModel]] = None,
|
||||
default_tools: bool = True,
|
||||
) -> PresetModel:
|
||||
"""Create an agent preset
|
||||
|
||||
:param name: Name of the preset
|
||||
:type name: str
|
||||
:param system: System prompt (text)
|
||||
:type system: str
|
||||
:param persona: Persona prompt (text)
|
||||
:type persona: Optional[str]
|
||||
:param human: Human prompt (text)
|
||||
:type human: Optional[str]
|
||||
:param tools: List of tools to connect, defaults to None
|
||||
:type tools: Optional[List[Tool]], optional
|
||||
:param default_tools: Whether to automatically include default tools, defaults to True
|
||||
:type default_tools: bool, optional
|
||||
:return: Preset object
|
||||
:rtype: PresetModel
|
||||
"""
|
||||
# provided tools
|
||||
schema = []
|
||||
if tools:
|
||||
for tool in tools:
|
||||
print("CUSOTM TOOL", tool.json_schema)
|
||||
schema.append(tool.json_schema)
|
||||
|
||||
# include default tools
|
||||
default_preset = self.get_preset(name=DEFAULT_PRESET)
|
||||
if default_tools:
|
||||
# TODO
|
||||
# from memgpt.functions.functions import load_function_set
|
||||
# load_function_set()
|
||||
# return
|
||||
for function in default_preset.functions_schema:
|
||||
schema.append(function)
|
||||
|
||||
payload = CreatePresetsRequest(
|
||||
id=str(preset.id),
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
system=preset.system,
|
||||
persona=preset.persona,
|
||||
human=preset.human,
|
||||
persona_name=preset.persona_name,
|
||||
human_name=preset.human_name,
|
||||
functions_schema=preset.functions_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
system_name=system_name,
|
||||
persona_name=persona_name,
|
||||
human_name=human_name,
|
||||
functions_schema=schema,
|
||||
)
|
||||
print(schema)
|
||||
print(human_name, persona_name, system_name, name)
|
||||
print(payload.model_dump())
|
||||
response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to create preset: {response.text}"
|
||||
return CreatePresetResponse(**response.json())
|
||||
return CreatePresetResponse(**response.json()).preset
|
||||
|
||||
def delete_preset(self, preset_id: uuid.UUID):
|
||||
response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers)
|
||||
@@ -518,23 +585,6 @@ class RESTClient(AbstractClient):
|
||||
response = requests.get(f"{self.base_url}/api/config", headers=self.headers)
|
||||
return ConfigResponse(**response.json())
|
||||
|
||||
# tools
|
||||
|
||||
def create_tool(
|
||||
self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None
|
||||
) -> CreateToolResponse:
|
||||
"""Add a tool implemented in a file path"""
|
||||
source_code = open(file_path, "r").read()
|
||||
data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags}
|
||||
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
return CreateToolResponse(**response.json())
|
||||
|
||||
def list_tools(self) -> ListToolsResponse:
|
||||
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
|
||||
return ListToolsResponse(**response.json())
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
def __init__(
|
||||
|
||||
@@ -271,6 +271,7 @@ class MemGPTConfig:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
logger.debug(f"Saved Config: {self.config_path}")
|
||||
print(f"Saved Config: {self.config_path}")
|
||||
|
||||
@staticmethod
|
||||
def exists():
|
||||
|
||||
@@ -11,11 +11,13 @@ from pydantic import BaseModel, Field
|
||||
from memgpt.constants import (
|
||||
DEFAULT_HUMAN,
|
||||
DEFAULT_PERSONA,
|
||||
DEFAULT_PRESET,
|
||||
LLM_MAX_TOKENS,
|
||||
MAX_EMBEDDING_DIM,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
)
|
||||
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.utils import (
|
||||
create_uuid_from_string,
|
||||
get_human_text,
|
||||
@@ -848,7 +850,10 @@ class Preset(BaseModel):
|
||||
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
|
||||
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
|
||||
system: str = Field(..., description="The system prompt of the preset.")
|
||||
system: str = Field(
|
||||
gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset."
|
||||
) # default system prompt is same as default preset name
|
||||
# system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
|
||||
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
|
||||
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||
|
||||
@@ -76,6 +76,26 @@ def write_function(module_name: str, function_name: str, function_code: str):
|
||||
# raise error if function cannot be loaded
|
||||
if not succ:
|
||||
raise ValueError(error)
|
||||
return file_path
|
||||
|
||||
|
||||
def load_function_file(filepath: str) -> dict:
|
||||
file = os.path.basename(filepath)
|
||||
module_name = file[:-3] # Remove '.py' from filename
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, filepath)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
except ModuleNotFoundError as e:
|
||||
# Handle missing module imports
|
||||
missing_package = str(e).split("'")[1] # Extract the name of the missing package
|
||||
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!")
|
||||
print(
|
||||
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
|
||||
)
|
||||
# load all functions in the module
|
||||
function_dict = load_function_set(module)
|
||||
return function_dict
|
||||
|
||||
|
||||
def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict:
|
||||
|
||||
@@ -400,7 +400,10 @@ class MetadataStore:
|
||||
@enforce_types
|
||||
def get_api_key(self, api_key: str) -> Optional[Token]:
|
||||
with self.session_maker() as session:
|
||||
print("getting api key", api_key)
|
||||
print([r.token for r in self.get_all_api_keys_for_user(user_id=uuid.UUID(int=0))])
|
||||
results = session.query(TokenModel).filter(TokenModel.token == api_key).all()
|
||||
print("results", [r.token for r in results])
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
|
||||
@@ -410,15 +413,19 @@ class MetadataStore:
|
||||
def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
tokens = [r.to_record() for r in results]
|
||||
print([t.token for t in tokens])
|
||||
return tokens
|
||||
|
||||
@enforce_types
|
||||
def get_user_from_api_key(self, api_key: str) -> Optional[User]:
|
||||
"""Get the user associated with a given API key"""
|
||||
token = self.get_api_key(api_key=api_key)
|
||||
print("got api key", token.token, token is None)
|
||||
if token is None:
|
||||
raise ValueError(f"Token {api_key} does not exist")
|
||||
else:
|
||||
print(isinstance(token.user_id, uuid.UUID), self.get_user(user_id=token.user_id))
|
||||
return self.get_user(user_id=token.user_id)
|
||||
|
||||
@enforce_types
|
||||
@@ -803,6 +810,12 @@ class MetadataStore:
|
||||
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_tool(self, name: str):
|
||||
with self.session_maker() as session:
|
||||
session.query(ToolModel).filter(ToolModel.name == name).delete()
|
||||
session.commit()
|
||||
|
||||
# job related functions
|
||||
def create_job(self, job: JobModel):
|
||||
with self.session_maker() as session:
|
||||
|
||||
@@ -38,6 +38,7 @@ class PresetModel(BaseModel):
|
||||
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
|
||||
system: str = Field(..., description="The system prompt of the preset.")
|
||||
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
|
||||
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
|
||||
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import importlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from memgpt.data_types import AgentState, Preset
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.functions.functions import load_all_function_sets, load_function_set
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
|
||||
from memgpt.presets.utils import load_all_presets, load_yaml_file
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.utils import (
|
||||
@@ -22,6 +23,34 @@ available_presets = load_all_presets()
|
||||
preset_options = list(available_presets.keys())
|
||||
|
||||
|
||||
def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
|
||||
module_name = "base"
|
||||
full_module_name = f"memgpt.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
|
||||
# function tags
|
||||
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema = load_function_set(module)
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
printd(err)
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
print("BASE FUNCTIONS", functions_to_schema.keys())
|
||||
pprint(functions_to_schema)
|
||||
|
||||
# create tool in db
|
||||
for name, schema in functions_to_schema.items():
|
||||
ms.add_tool(ToolModel(name=name, tags=["base"], source_type="python", json_schema=schema["json_schema"]))
|
||||
|
||||
|
||||
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
|
||||
for persona_file in list_persona_files():
|
||||
text = open(persona_file, "r").read()
|
||||
@@ -89,6 +118,10 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
# make sure humans/personas added
|
||||
add_default_humans_and_personas(user_id=user_id, ms=ms)
|
||||
|
||||
# make sure base functions added
|
||||
# TODO: pull from functions instead
|
||||
add_default_tools(user_id=user_id, ms=ms)
|
||||
|
||||
# add default presets
|
||||
for preset_name in preset_options:
|
||||
if ms.get_preset(user_id=user_id, name=preset_name) is not None:
|
||||
|
||||
@@ -137,6 +137,8 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
Get a list of all API keys for a user
|
||||
"""
|
||||
try:
|
||||
if server.ms.get_user(user_id=user_id) is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
|
||||
processed_tokens = [t.token for t in tokens]
|
||||
except HTTPException:
|
||||
|
||||
@@ -26,6 +26,7 @@ class MessageRoleType(str, Enum):
|
||||
|
||||
class UserMessageRequest(BaseModel):
|
||||
message: str = Field(..., description="The message content to be processed by the agent.")
|
||||
name: str = Field(default="user", description="Name of the message request sender")
|
||||
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
|
||||
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
|
||||
timestamp: Optional[datetime] = Field(
|
||||
|
||||
@@ -39,6 +39,7 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
request: CreateHumanRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
# TODO: disallow duplicate names for humans
|
||||
interface.clear()
|
||||
new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
|
||||
human_id = new_human.id
|
||||
|
||||
@@ -40,6 +40,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
|
||||
request: CreatePersonaRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
# TODO: disallow duplicate names for personas
|
||||
interface.clear()
|
||||
new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
|
||||
persona_id = new_persona.id
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||
from memgpt.data_types import Preset # TODO remove
|
||||
from memgpt.models.pydantic_models import PresetModel
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
@@ -32,17 +33,18 @@ class ListPresetsResponse(BaseModel):
|
||||
class CreatePresetsRequest(BaseModel):
|
||||
# TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
|
||||
name: str = Field(..., description="The name of the preset.")
|
||||
id: Optional[Union[uuid.UUID, str]] = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
|
||||
id: Optional[str] = Field(None, description="The unique identifier of the preset.")
|
||||
# user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
|
||||
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
|
||||
system: str = Field(..., description="The system prompt of the preset.")
|
||||
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||
system: Optional[str] = Field(None, description="The system prompt of the preset.") # TODO: make optional and allow defaults
|
||||
persona: Optional[str] = Field(default=None, description="The persona of the preset.")
|
||||
human: Optional[str] = Field(default=None, description="The human of the preset.")
|
||||
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
|
||||
# TODO
|
||||
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
|
||||
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
|
||||
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
|
||||
|
||||
|
||||
class CreatePresetResponse(BaseModel):
|
||||
@@ -52,6 +54,20 @@ class CreatePresetResponse(BaseModel):
|
||||
def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/presets/{preset_name}", tags=["presets"], response_model=PresetModel)
|
||||
async def get_preset(
|
||||
preset_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""Get a preset."""
|
||||
try:
|
||||
preset = server.get_preset(user_id=user_id, preset_name=preset_name)
|
||||
return preset
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
|
||||
async def list_presets(
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
@@ -77,18 +93,52 @@ def setup_presets_index_router(server: SyncServer, interface: QueuingInterface,
|
||||
try:
|
||||
if isinstance(request.id, str):
|
||||
request.id = uuid.UUID(request.id)
|
||||
# new_preset = PresetModel(
|
||||
|
||||
# check if preset already exists
|
||||
# TODO: move this into a server function to create a preset
|
||||
if server.ms.get_preset(name=request.name, user_id=user_id):
|
||||
raise HTTPException(status_code=400, detail=f"Preset with name {request.name} already exists.")
|
||||
|
||||
# For system/human/persona - if {system/human-personal}_name is None but the text is provied, then create a new data entry
|
||||
if not request.system_name and request.system:
|
||||
# new system provided without name identity
|
||||
system_name = f"system_{request.name}_{str(uuid.uuid4())}"
|
||||
system = request.system
|
||||
# TODO: insert into system table
|
||||
else:
|
||||
system_name = request.system_name if request.system_name else DEFAULT_PRESET
|
||||
system = request.system if request.system else gpt_system.get_system_text(system_name)
|
||||
|
||||
if not request.human_name and request.human:
|
||||
# new human provided without name identity
|
||||
human_name = f"human_{request.name}_{str(uuid.uuid4())}"
|
||||
human = request.human
|
||||
server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id))
|
||||
else:
|
||||
human_name = request.human_name if request.human_name else DEFAULT_HUMAN
|
||||
human = request.human if request.human else get_human_text(human_name)
|
||||
|
||||
if not request.persona_name and request.persona:
|
||||
# new persona provided without name identity
|
||||
persona_name = f"persona_{request.name}_{str(uuid.uuid4())}"
|
||||
persona = request.persona
|
||||
server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id))
|
||||
else:
|
||||
persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA
|
||||
persona = request.persona if request.persona else get_persona_text(persona_name)
|
||||
|
||||
# create preset
|
||||
new_preset = Preset(
|
||||
user_id=user_id,
|
||||
id=request.id,
|
||||
id=request.id if request.id else uuid.uuid4(),
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
system=request.system,
|
||||
persona=request.persona,
|
||||
human=request.human,
|
||||
system=system,
|
||||
persona=persona,
|
||||
persona_name=persona_name,
|
||||
human=human,
|
||||
human_name=human_name,
|
||||
functions_schema=request.functions_schema,
|
||||
persona_name=request.persona_name,
|
||||
human_name=request.human_name,
|
||||
)
|
||||
preset = server.create_preset(preset=new_preset)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
@@ -28,6 +28,33 @@ class CreateToolResponse(BaseModel):
|
||||
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
# get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.delete("/tools/{tool_name}", tags=["tools"])
|
||||
async def delete_tool(
|
||||
tool_name: str,
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
server.ms.delete_tool(name=tool_name)
|
||||
|
||||
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
|
||||
async def get_tool(tool_name: str):
|
||||
"""
|
||||
Get a tool by name
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
tool = server.ms.get_tool(tool_name=tool_name)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
|
||||
return tool
|
||||
|
||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||
async def list_all_tools(
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
@@ -41,7 +68,7 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
|
||||
tools = server.ms.list_tools()
|
||||
return ListToolsResponse(tools=tools)
|
||||
|
||||
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
|
||||
@router.post("/tools", tags=["tools"], response_model=ToolModel)
|
||||
async def create_tool(
|
||||
request: CreateToolRequest = Body(...),
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
@@ -49,20 +76,26 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
|
||||
"""
|
||||
Create a new tool (dummy route)
|
||||
"""
|
||||
from memgpt.functions.functions import write_function
|
||||
from memgpt.functions.functions import load_function_file, write_function
|
||||
|
||||
# check if function already exists
|
||||
if server.ms.get_tool(request.name):
|
||||
raise ValueError(f"Tool with name {request.name} already exists.")
|
||||
|
||||
# write function to ~/.memgt/functions directory
|
||||
write_function(request.name, request.name, request.source_code)
|
||||
file_path = write_function(request.name, request.name, request.source_code)
|
||||
|
||||
# TODO: Use load_function_file to load function schema
|
||||
schema = load_function_file(file_path)
|
||||
assert len(list(schema.keys())) == 1, "Function schema must have exactly one key"
|
||||
json_schema = list(schema.values())[0]["json_schema"]
|
||||
|
||||
print("adding tool", request.name, request.tags, request.source_code)
|
||||
tool = ToolModel(name=request.name, json_schema={}, tags=request.tags, source_code=request.source_code)
|
||||
tool = ToolModel(name=request.name, json_schema=json_schema, tags=request.tags, source_code=request.source_code)
|
||||
tool.id
|
||||
server.ms.add_tool(tool)
|
||||
|
||||
# TODO: insert tool information into DB as ToolModel
|
||||
return CreateToolResponse(tool=ToolModel(name=request.name, json_schema={}, tags=[], source_code=request.source_code))
|
||||
return server.ms.get_tool(request.name)
|
||||
|
||||
return router
|
||||
|
||||
@@ -45,7 +45,6 @@ from memgpt.models.pydantic_models import (
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
)
|
||||
from memgpt.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -214,12 +213,12 @@ class SyncServer(LockingServer):
|
||||
assert self.config.human is not None, "Human must be set in the config"
|
||||
|
||||
# Update storage URI to match passed in settings
|
||||
# TODO: very hack, fix in the future
|
||||
for memory_type in ("archival", "recall", "metadata"):
|
||||
if settings.memgpt_pg_uri_no_default:
|
||||
# override with env
|
||||
setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
|
||||
self.config.save()
|
||||
# (NOTE: no longer needed since envs being used, I think)
|
||||
# for memory_type in ("archival", "recall", "metadata"):
|
||||
# if settings.memgpt_pg_uri:
|
||||
# # override with env
|
||||
# setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
|
||||
# self.config.save()
|
||||
|
||||
# TODO figure out how to handle credentials for the server
|
||||
self.credentials = MemGPTCredentials.load()
|
||||
@@ -1278,6 +1277,7 @@ class SyncServer(LockingServer):
|
||||
def api_key_to_user(self, api_key: str) -> uuid.UUID:
|
||||
"""Decode an API key to a user"""
|
||||
user = self.ms.get_user_from_api_key(api_key=api_key)
|
||||
print("got user", api_key, user.id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=403, detail="Invalid credentials")
|
||||
else:
|
||||
|
||||
@@ -58,19 +58,17 @@ def test_admin_client(admin_client):
|
||||
|
||||
# list keys
|
||||
user1_keys = admin_client.get_keys(user_id)
|
||||
assert len(user1_keys.api_key_list) == 2, f"Expected 2 keys, got {user1_keys}"
|
||||
assert create_key1_response.api_key in user1_keys.api_key_list, f"Expected {create_key1_response.api_key} in {user1_keys.api_key_list}"
|
||||
assert (
|
||||
create_user1_response.api_key in user1_keys.api_key_list
|
||||
), f"Expected {create_user1_response.api_key} in {user1_keys.api_key_list}"
|
||||
assert len(user1_keys) == 2, f"Expected 2 keys, got {user1_keys}"
|
||||
assert create_key1_response.api_key in user1_keys, f"Expected {create_key1_response.api_key} in {user1_keys}"
|
||||
assert create_user1_response.api_key in user1_keys, f"Expected {create_user1_response.api_key} in {user1_keys}"
|
||||
|
||||
# delete key
|
||||
delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
|
||||
assert delete_key1_response.api_key_deleted == create_key1_response.api_key
|
||||
assert len(admin_client.get_keys(user_id).api_key_list) == 1
|
||||
assert len(admin_client.get_keys(user_id)) == 1
|
||||
delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
|
||||
assert delete_key2_response.api_key_deleted == create_key2_response.api_key
|
||||
assert len(admin_client.get_keys(create_user_2_response.user_id).api_key_list) == 1
|
||||
assert len(admin_client.get_keys(create_user_2_response.user_id)) == 1
|
||||
|
||||
# delete users
|
||||
delete_user1_response = admin_client.delete_user(user_id)
|
||||
|
||||
@@ -320,41 +320,41 @@ def test_sources(client, agent):
|
||||
def test_presets(client, agent):
|
||||
_reset_config()
|
||||
|
||||
new_preset = Preset(
|
||||
# user_id=client.user_id,
|
||||
name="pytest_test_preset",
|
||||
description="DUMMY_DESCRIPTION",
|
||||
system="DUMMY_SYSTEM",
|
||||
persona="DUMMY_PERSONA",
|
||||
persona_name="DUMMY_PERSONA_NAME",
|
||||
human="DUMMY_HUMAN",
|
||||
human_name="DUMMY_HUMAN_NAME",
|
||||
functions_schema=[
|
||||
{
|
||||
"name": "send_message",
|
||||
"json_schema": {
|
||||
"name": "send_message",
|
||||
"description": "Sends a message to the human user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
},
|
||||
"tags": ["memgpt-base"],
|
||||
"source_type": "python",
|
||||
"source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
|
||||
}
|
||||
],
|
||||
)
|
||||
# new_preset = Preset(
|
||||
# # user_id=client.user_id,
|
||||
# name="pytest_test_preset",
|
||||
# description="DUMMY_DESCRIPTION",
|
||||
# system="DUMMY_SYSTEM",
|
||||
# persona="DUMMY_PERSONA",
|
||||
# persona_name="DUMMY_PERSONA_NAME",
|
||||
# human="DUMMY_HUMAN",
|
||||
# human_name="DUMMY_HUMAN_NAME",
|
||||
# functions_schema=[
|
||||
# {
|
||||
# "name": "send_message",
|
||||
# "json_schema": {
|
||||
# "name": "send_message",
|
||||
# "description": "Sends a message to the human user.",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
|
||||
# },
|
||||
# "required": ["message"],
|
||||
# },
|
||||
# },
|
||||
# "tags": ["memgpt-base"],
|
||||
# "source_type": "python",
|
||||
# "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
|
||||
# }
|
||||
# ],
|
||||
# )
|
||||
|
||||
# List all presets and make sure the preset is NOT in the list
|
||||
all_presets = client.list_presets()
|
||||
assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
## List all presets and make sure the preset is NOT in the list
|
||||
# all_presets = client.list_presets()
|
||||
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
# Create a preset
|
||||
client.create_preset(preset=new_preset)
|
||||
new_preset = client.create_preset(name="pytest_test_preset")
|
||||
|
||||
# List all presets and make sure the preset is in the list
|
||||
all_presets = client.list_presets()
|
||||
|
||||
@@ -22,7 +22,7 @@ def server():
|
||||
wipe_config()
|
||||
wipe_memgpt_home()
|
||||
|
||||
db_url = settings.pg_db
|
||||
db_url = settings.memgpt_pg_uri
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = TestMGPTConfig(
|
||||
|
||||
Reference in New Issue
Block a user