feat: add more tool functionality for python client (#1361)

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders
2024-05-13 12:05:32 -07:00
committed by GitHub
parent c5ac6850c7
commit 5d7fb14530
18 changed files with 346 additions and 107 deletions

View File

@@ -1,9 +1,10 @@
import uuid
from typing import Optional
from typing import List, Optional
import requests
from requests import HTTPError
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.admin.users import (
CreateAPIKeyResponse,
CreateUserResponse,
@@ -12,6 +13,7 @@ from memgpt.server.rest_api.admin.users import (
GetAllUsersResponse,
GetAPIKeysResponse,
)
from memgpt.server.rest_api.tools.index import ListToolsResponse
class Admin:
@@ -46,7 +48,7 @@ class Admin:
if response.status_code != 200:
raise HTTPError(response.json())
print(response.text, response.status_code)
return GetAPIKeysResponse(**response.json())
return GetAPIKeysResponse(**response.json()).api_key_list
def delete_key(self, api_key: str):
params = {"api_key": api_key}
@@ -77,7 +79,35 @@ class Admin:
# TODO: clear out all agents, presets, etc.
users = self.get_users().user_list
for user in users:
keys = self.get_keys(user["user_id"]).api_key_list
keys = self.get_keys(user["user_id"])
for key in keys:
self.delete_key(key)
self.delete_user(user["user_id"])
# tools (currently only available for admin)
def create_tool(self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None) -> ToolModel:
"""Add a tool implemented in a file path"""
source_code = open(file_path, "r").read()
data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags}
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return ToolModel(**response.json())
def list_tools(self) -> ListToolsResponse:
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
return ListToolsResponse(**response.json())
def delete_tool(self, name: str):
response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to delete tool: {response.text}")
return response.json()
def get_tool(self, name: str):
response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return ToolModel(**response.json())

View File

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union
import requests
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.data_sources.connectors import DataConnector
from memgpt.data_types import (
AgentState,
@@ -23,6 +24,7 @@ from memgpt.models.pydantic_models import (
PersonaModel,
PresetModel,
SourceModel,
ToolModel,
)
# import pydantic response objects from memgpt.server.rest_api
@@ -50,7 +52,7 @@ from memgpt.server.rest_api.presets.index import (
ListPresetsResponse,
)
from memgpt.server.rest_api.sources.index import ListSourcesResponse
from memgpt.server.rest_api.tools.index import CreateToolResponse, ListToolsResponse
from memgpt.server.rest_api.tools.index import CreateToolResponse
from memgpt.server.server import SyncServer
@@ -259,7 +261,7 @@ class RESTClient(AbstractClient):
def create_agent(
self,
name: Optional[str] = None,
preset: Optional[str] = None,
preset: Optional[str] = None, # TODO: this should actually be re-named preset_name
persona: Optional[str] = None,
human: Optional[str] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@@ -267,6 +269,7 @@ class RESTClient(AbstractClient):
) -> AgentState:
if embedding_config or llm_config:
raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API")
# TODO: distinguish between name and objects
payload = {
"config": {
"name": name,
@@ -329,23 +332,87 @@ class RESTClient(AbstractClient):
response_obj = GetAgentResponse(**response.json())
return self.get_agent_response_to_state(response_obj)
# presets
def create_preset(self, preset: Preset) -> CreatePresetResponse:
# TODO should the arg type here be PresetModel, not Preset?
## presets
# def create_preset(self, preset: Preset) -> CreatePresetResponse:
# # TODO should the arg type here be PresetModel, not Preset?
# payload = CreatePresetsRequest(
# id=str(preset.id),
# name=preset.name,
# description=preset.description,
# system=preset.system,
# persona=preset.persona,
# human=preset.human,
# persona_name=preset.persona_name,
# human_name=preset.human_name,
# functions_schema=preset.functions_schema,
# )
# response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
# assert response.status_code == 200, f"Failed to create preset: {response.text}"
# return CreatePresetResponse(**response.json())
def get_preset(self, name: str) -> PresetModel:
response = requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers)
assert response.status_code == 200, f"Failed to get preset: {response.text}"
return PresetModel(**response.json())
def create_preset(
self,
name: str,
description: Optional[str] = None,
system_name: Optional[str] = None,
persona_name: Optional[str] = None,
human_name: Optional[str] = None,
tools: Optional[List[ToolModel]] = None,
default_tools: bool = True,
) -> PresetModel:
"""Create an agent preset
:param name: Name of the preset
:type name: str
:param system: System prompt (text)
:type system: str
:param persona: Persona prompt (text)
:type persona: Optional[str]
:param human: Human prompt (text)
:type human: Optional[str]
:param tools: List of tools to connect, defaults to None
:type tools: Optional[List[Tool]], optional
:param default_tools: Whether to automatically include default tools, defaults to True
:type default_tools: bool, optional
:return: Preset object
:rtype: PresetModel
"""
# provided tools
schema = []
if tools:
for tool in tools:
print("CUSOTM TOOL", tool.json_schema)
schema.append(tool.json_schema)
# include default tools
default_preset = self.get_preset(name=DEFAULT_PRESET)
if default_tools:
# TODO
# from memgpt.functions.functions import load_function_set
# load_function_set()
# return
for function in default_preset.functions_schema:
schema.append(function)
payload = CreatePresetsRequest(
id=str(preset.id),
name=preset.name,
description=preset.description,
system=preset.system,
persona=preset.persona,
human=preset.human,
persona_name=preset.persona_name,
human_name=preset.human_name,
functions_schema=preset.functions_schema,
name=name,
description=description,
system_name=system_name,
persona_name=persona_name,
human_name=human_name,
functions_schema=schema,
)
print(schema)
print(human_name, persona_name, system_name, name)
print(payload.model_dump())
response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
assert response.status_code == 200, f"Failed to create preset: {response.text}"
return CreatePresetResponse(**response.json())
return CreatePresetResponse(**response.json()).preset
def delete_preset(self, preset_id: uuid.UUID):
response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers)
@@ -518,23 +585,6 @@ class RESTClient(AbstractClient):
response = requests.get(f"{self.base_url}/api/config", headers=self.headers)
return ConfigResponse(**response.json())
# tools
def create_tool(
self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None
) -> CreateToolResponse:
"""Add a tool implemented in a file path"""
source_code = open(file_path, "r").read()
data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags}
response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return CreateToolResponse(**response.json())
def list_tools(self) -> ListToolsResponse:
response = requests.get(f"{self.base_url}/api/tools", headers=self.headers)
return ListToolsResponse(**response.json())
class LocalClient(AbstractClient):
def __init__(

View File

@@ -271,6 +271,7 @@ class MemGPTConfig:
with open(self.config_path, "w", encoding="utf-8") as f:
config.write(f)
logger.debug(f"Saved Config: {self.config_path}")
print(f"Saved Config: {self.config_path}")
@staticmethod
def exists():

View File

@@ -11,11 +11,13 @@ from pydantic import BaseModel, Field
from memgpt.constants import (
DEFAULT_HUMAN,
DEFAULT_PERSONA,
DEFAULT_PRESET,
LLM_MAX_TOKENS,
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.prompts import gpt_system
from memgpt.utils import (
create_uuid_from_string,
get_human_text,
@@ -848,7 +850,10 @@ class Preset(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
system: str = Field(
gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset."
) # default system prompt is same as default preset name
# system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")

View File

@@ -76,6 +76,26 @@ def write_function(module_name: str, function_name: str, function_code: str):
# raise error if function cannot be loaded
if not succ:
raise ValueError(error)
return file_path
def load_function_file(filepath: str) -> dict:
file = os.path.basename(filepath)
module_name = file[:-3] # Remove '.py' from filename
try:
spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except ModuleNotFoundError as e:
# Handle missing module imports
missing_package = str(e).split("'")[1] # Extract the name of the missing package
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{filepath}'!")
print(
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
)
# load all functions in the module
function_dict = load_function_set(module)
return function_dict
def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict:

View File

@@ -400,7 +400,10 @@ class MetadataStore:
@enforce_types
def get_api_key(self, api_key: str) -> Optional[Token]:
with self.session_maker() as session:
print("getting api key", api_key)
print([r.token for r in self.get_all_api_keys_for_user(user_id=uuid.UUID(int=0))])
results = session.query(TokenModel).filter(TokenModel.token == api_key).all()
print("results", [r.token for r in results])
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
@@ -410,15 +413,19 @@ class MetadataStore:
def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]:
with self.session_maker() as session:
results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all()
return [r.to_record() for r in results]
tokens = [r.to_record() for r in results]
print([t.token for t in tokens])
return tokens
@enforce_types
def get_user_from_api_key(self, api_key: str) -> Optional[User]:
"""Get the user associated with a given API key"""
token = self.get_api_key(api_key=api_key)
print("got api key", token.token, token is None)
if token is None:
raise ValueError(f"Token {api_key} does not exist")
else:
print(isinstance(token.user_id, uuid.UUID), self.get_user(user_id=token.user_id))
return self.get_user(user_id=token.user_id)
@enforce_types
@@ -803,6 +810,12 @@ class MetadataStore:
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
session.commit()
@enforce_types
def delete_tool(self, name: str):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.name == name).delete()
session.commit()
# job related functions
def create_job(self, job: JobModel):
with self.session_maker() as session:

View File

@@ -38,6 +38,7 @@ class PresetModel(BaseModel):
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")

View File

@@ -1,13 +1,14 @@
import importlib
import os
import uuid
from typing import List
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.data_types import AgentState, Preset
from memgpt.functions.functions import load_all_function_sets
from memgpt.functions.functions import load_all_function_sets, load_function_set
from memgpt.interface import AgentInterface
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
from memgpt.presets.utils import load_all_presets, load_yaml_file
from memgpt.prompts import gpt_system
from memgpt.utils import (
@@ -22,6 +23,34 @@ available_presets = load_all_presets()
preset_options = list(available_presets.keys())
def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
module_name = "base"
full_module_name = f"memgpt.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
raise e
# function tags
try:
# Load the function set
functions_to_schema = load_function_set(module)
except ValueError as e:
err = f"Error loading function set '{module_name}': {e}"
printd(err)
from pprint import pprint
print("BASE FUNCTIONS", functions_to_schema.keys())
pprint(functions_to_schema)
# create tool in db
for name, schema in functions_to_schema.items():
ms.add_tool(ToolModel(name=name, tags=["base"], source_type="python", json_schema=schema["json_schema"]))
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
for persona_file in list_persona_files():
text = open(persona_file, "r").read()
@@ -89,6 +118,10 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
# make sure humans/personas added
add_default_humans_and_personas(user_id=user_id, ms=ms)
# make sure base functions added
# TODO: pull from functions instead
add_default_tools(user_id=user_id, ms=ms)
# add default presets
for preset_name in preset_options:
if ms.get_preset(user_id=user_id, name=preset_name) is not None:

View File

@@ -137,6 +137,8 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
Get a list of all API keys for a user
"""
try:
if server.ms.get_user(user_id=user_id) is None:
raise HTTPException(status_code=404, detail=f"User does not exist")
tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
processed_tokens = [t.token for t in tokens]
except HTTPException:

View File

@@ -26,6 +26,7 @@ class MessageRoleType(str, Enum):
class UserMessageRequest(BaseModel):
message: str = Field(..., description="The message content to be processed by the agent.")
name: str = Field(default="user", description="Name of the message request sender")
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
timestamp: Optional[datetime] = Field(

View File

@@ -39,6 +39,7 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
request: CreateHumanRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
# TODO: disallow duplicate names for humans
interface.clear()
new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
human_id = new_human.id

View File

@@ -40,6 +40,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
request: CreatePersonaRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
# TODO: disallow duplicate names for personas
interface.clear()
new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
persona_id = new_persona.id

View File

@@ -1,14 +1,15 @@
import uuid
from functools import partial
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import Preset # TODO remove
from memgpt.models.pydantic_models import PresetModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.prompts import gpt_system
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@@ -32,17 +33,18 @@ class ListPresetsResponse(BaseModel):
class CreatePresetsRequest(BaseModel):
# TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
name: str = Field(..., description="The name of the preset.")
id: Optional[Union[uuid.UUID, str]] = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
id: Optional[str] = Field(None, description="The unique identifier of the preset.")
# user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
system: Optional[str] = Field(None, description="The system prompt of the preset.") # TODO: make optional and allow defaults
persona: Optional[str] = Field(default=None, description="The persona of the preset.")
human: Optional[str] = Field(default=None, description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# TODO
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
class CreatePresetResponse(BaseModel):
@@ -52,6 +54,20 @@ class CreatePresetResponse(BaseModel):
def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/presets/{preset_name}", tags=["presets"], response_model=PresetModel)
async def get_preset(
preset_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Get a preset."""
try:
preset = server.get_preset(user_id=user_id, preset_name=preset_name)
return preset
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
async def list_presets(
user_id: uuid.UUID = Depends(get_current_user_with_server),
@@ -77,18 +93,52 @@ def setup_presets_index_router(server: SyncServer, interface: QueuingInterface,
try:
if isinstance(request.id, str):
request.id = uuid.UUID(request.id)
# new_preset = PresetModel(
# check if preset already exists
# TODO: move this into a server function to create a preset
if server.ms.get_preset(name=request.name, user_id=user_id):
raise HTTPException(status_code=400, detail=f"Preset with name {request.name} already exists.")
# For system/human/persona - if {system/human-personal}_name is None but the text is provied, then create a new data entry
if not request.system_name and request.system:
# new system provided without name identity
system_name = f"system_{request.name}_{str(uuid.uuid4())}"
system = request.system
# TODO: insert into system table
else:
system_name = request.system_name if request.system_name else DEFAULT_PRESET
system = request.system if request.system else gpt_system.get_system_text(system_name)
if not request.human_name and request.human:
# new human provided without name identity
human_name = f"human_{request.name}_{str(uuid.uuid4())}"
human = request.human
server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id))
else:
human_name = request.human_name if request.human_name else DEFAULT_HUMAN
human = request.human if request.human else get_human_text(human_name)
if not request.persona_name and request.persona:
# new persona provided without name identity
persona_name = f"persona_{request.name}_{str(uuid.uuid4())}"
persona = request.persona
server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id))
else:
persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA
persona = request.persona if request.persona else get_persona_text(persona_name)
# create preset
new_preset = Preset(
user_id=user_id,
id=request.id,
id=request.id if request.id else uuid.uuid4(),
name=request.name,
description=request.description,
system=request.system,
persona=request.persona,
human=request.human,
system=system,
persona=persona,
persona_name=persona_name,
human=human,
human_name=human_name,
functions_schema=request.functions_schema,
persona_name=request.persona_name,
human_name=request.human_name,
)
preset = server.create_preset(preset=new_preset)

View File

@@ -1,6 +1,6 @@
from typing import List, Literal, Optional
from fastapi import APIRouter, Body
from fastapi import APIRouter, Body, HTTPException
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import ToolModel
@@ -28,6 +28,33 @@ class CreateToolResponse(BaseModel):
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
# get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(
tool_name: str,
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
"""
Delete a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name)
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(tool_name: str):
"""
Get a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
tool = server.ms.get_tool(tool_name=tool_name)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
@@ -41,7 +68,7 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
tools = server.ms.list_tools()
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=CreateToolResponse)
@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: CreateToolRequest = Body(...),
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
@@ -49,20 +76,26 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, pa
"""
Create a new tool (dummy route)
"""
from memgpt.functions.functions import write_function
from memgpt.functions.functions import load_function_file, write_function
# check if function already exists
if server.ms.get_tool(request.name):
raise ValueError(f"Tool with name {request.name} already exists.")
# write function to ~/.memgt/functions directory
write_function(request.name, request.name, request.source_code)
file_path = write_function(request.name, request.name, request.source_code)
# TODO: Use load_function_file to load function schema
schema = load_function_file(file_path)
assert len(list(schema.keys())) == 1, "Function schema must have exactly one key"
json_schema = list(schema.values())[0]["json_schema"]
print("adding tool", request.name, request.tags, request.source_code)
tool = ToolModel(name=request.name, json_schema={}, tags=request.tags, source_code=request.source_code)
tool = ToolModel(name=request.name, json_schema=json_schema, tags=request.tags, source_code=request.source_code)
tool.id
server.ms.add_tool(tool)
# TODO: insert tool information into DB as ToolModel
return CreateToolResponse(tool=ToolModel(name=request.name, json_schema={}, tags=[], source_code=request.source_code))
return server.ms.get_tool(request.name)
return router

View File

@@ -45,7 +45,6 @@ from memgpt.models.pydantic_models import (
SourceModel,
ToolModel,
)
from memgpt.settings import settings
logger = logging.getLogger(__name__)
@@ -214,12 +213,12 @@ class SyncServer(LockingServer):
assert self.config.human is not None, "Human must be set in the config"
# Update storage URI to match passed in settings
# TODO: very hack, fix in the future
for memory_type in ("archival", "recall", "metadata"):
if settings.memgpt_pg_uri_no_default:
# override with env
setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
self.config.save()
# (NOTE: no longer needed since envs being used, I think)
# for memory_type in ("archival", "recall", "metadata"):
# if settings.memgpt_pg_uri:
# # override with env
# setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
# self.config.save()
# TODO figure out how to handle credentials for the server
self.credentials = MemGPTCredentials.load()
@@ -1278,6 +1277,7 @@ class SyncServer(LockingServer):
def api_key_to_user(self, api_key: str) -> uuid.UUID:
"""Decode an API key to a user"""
user = self.ms.get_user_from_api_key(api_key=api_key)
print("got user", api_key, user.id)
if user is None:
raise HTTPException(status_code=403, detail="Invalid credentials")
else:

View File

@@ -58,19 +58,17 @@ def test_admin_client(admin_client):
# list keys
user1_keys = admin_client.get_keys(user_id)
assert len(user1_keys.api_key_list) == 2, f"Expected 2 keys, got {user1_keys}"
assert create_key1_response.api_key in user1_keys.api_key_list, f"Expected {create_key1_response.api_key} in {user1_keys.api_key_list}"
assert (
create_user1_response.api_key in user1_keys.api_key_list
), f"Expected {create_user1_response.api_key} in {user1_keys.api_key_list}"
assert len(user1_keys) == 2, f"Expected 2 keys, got {user1_keys}"
assert create_key1_response.api_key in user1_keys, f"Expected {create_key1_response.api_key} in {user1_keys}"
assert create_user1_response.api_key in user1_keys, f"Expected {create_user1_response.api_key} in {user1_keys}"
# delete key
delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
assert delete_key1_response.api_key_deleted == create_key1_response.api_key
assert len(admin_client.get_keys(user_id).api_key_list) == 1
assert len(admin_client.get_keys(user_id)) == 1
delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
assert delete_key2_response.api_key_deleted == create_key2_response.api_key
assert len(admin_client.get_keys(create_user_2_response.user_id).api_key_list) == 1
assert len(admin_client.get_keys(create_user_2_response.user_id)) == 1
# delete users
delete_user1_response = admin_client.delete_user(user_id)

View File

@@ -320,41 +320,41 @@ def test_sources(client, agent):
def test_presets(client, agent):
_reset_config()
new_preset = Preset(
# user_id=client.user_id,
name="pytest_test_preset",
description="DUMMY_DESCRIPTION",
system="DUMMY_SYSTEM",
persona="DUMMY_PERSONA",
persona_name="DUMMY_PERSONA_NAME",
human="DUMMY_HUMAN",
human_name="DUMMY_HUMAN_NAME",
functions_schema=[
{
"name": "send_message",
"json_schema": {
"name": "send_message",
"description": "Sends a message to the human user.",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
},
"required": ["message"],
},
},
"tags": ["memgpt-base"],
"source_type": "python",
"source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
}
],
)
# new_preset = Preset(
# # user_id=client.user_id,
# name="pytest_test_preset",
# description="DUMMY_DESCRIPTION",
# system="DUMMY_SYSTEM",
# persona="DUMMY_PERSONA",
# persona_name="DUMMY_PERSONA_NAME",
# human="DUMMY_HUMAN",
# human_name="DUMMY_HUMAN_NAME",
# functions_schema=[
# {
# "name": "send_message",
# "json_schema": {
# "name": "send_message",
# "description": "Sends a message to the human user.",
# "parameters": {
# "type": "object",
# "properties": {
# "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
# },
# "required": ["message"],
# },
# },
# "tags": ["memgpt-base"],
# "source_type": "python",
# "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
# }
# ],
# )
# List all presets and make sure the preset is NOT in the list
all_presets = client.list_presets()
assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
## List all presets and make sure the preset is NOT in the list
# all_presets = client.list_presets()
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# Create a preset
client.create_preset(preset=new_preset)
new_preset = client.create_preset(name="pytest_test_preset")
# List all presets and make sure the preset is in the list
all_presets = client.list_presets()

View File

@@ -22,7 +22,7 @@ def server():
wipe_config()
wipe_memgpt_home()
db_url = settings.pg_db
db_url = settings.memgpt_pg_uri
if os.getenv("OPENAI_API_KEY"):
config = TestMGPTConfig(