fix: misc bugs (#1276)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Charles Packer
2024-04-20 10:24:51 -07:00
committed by GitHub
parent 922fd5b513
commit a3c19a70f5
3 changed files with 22 additions and 13 deletions

View File

@@ -1,6 +1,7 @@
import importlib
import inspect
import os
import warnings
import sys
from types import ModuleType
@@ -78,7 +79,9 @@ def write_function(module_name: str, function_name: str, function_code: str):
raise ValueError(error)
def load_all_function_sets(merge: bool = True) -> dict:
def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict:
from memgpt.utils import printd
# functions/examples/*.py
scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script
function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory
@@ -106,7 +109,7 @@ def load_all_function_sets(merge: bool = True) -> dict:
if dir_path == USER_FUNCTIONS_DIR:
# For user scripts, adjust the module name appropriately
module_full_path = os.path.join(dir_path, file)
print(f"Loading user function set from '{module_full_path}'")
printd(f"Loading user function set from '{module_full_path}'")
try:
spec = importlib.util.spec_from_file_location(module_name, module_full_path)
module = importlib.util.module_from_spec(spec)
@@ -114,18 +117,18 @@ def load_all_function_sets(merge: bool = True) -> dict:
except ModuleNotFoundError as e:
# Handle missing module imports
missing_package = str(e).split("'")[1] # Extract the name of the missing package
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
print(
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
printd(
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
)
continue
except SyntaxError as e:
# Handle syntax errors in the module
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
continue
except Exception as e:
# Handle other general exceptions
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
continue
else:
# For built-in scripts, use the existing method
@@ -135,7 +138,7 @@ def load_all_function_sets(merge: bool = True) -> dict:
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
print(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
printd(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
continue
try:
@@ -147,7 +150,7 @@ def load_all_function_sets(merge: bool = True) -> dict:
v["tags"] = tags
schemas_and_functions[module_name] = function_set
except ValueError as e:
print(f"Error loading function set '{module_name}': {e}")
printd(f"Error loading function set '{module_name}': {e}")
if merge:
# Put all functions from all sets into the same level dict
@@ -155,8 +158,13 @@ def load_all_function_sets(merge: bool = True) -> dict:
for set_name, function_set in schemas_and_functions.items():
for function_name, function_info in function_set.items():
if function_name in merged_functions:
raise ValueError(f"Duplicate function name '{function_name}' found in function set '{set_name}'")
merged_functions[function_name] = function_info
err_msg = f"Duplicate function name '{function_name}' found in function set '{set_name}'"
if ignore_duplicates:
warnings.warn(err_msg, category=UserWarning, stacklevel=2)
else:
raise ValueError(err_msg)
else:
merged_functions[function_name] = function_info
return merged_functions
else:
# Nested dict where the top level is organized by the function set name

View File

@@ -557,7 +557,6 @@ class MetadataStore:
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
with self.session_maker() as session:
available_functions = load_all_function_sets()
print(available_functions)
results = [
ToolModel(
name=k,

View File

@@ -1,6 +1,8 @@
from typing import List, Optional, Dict, Literal, Type
from pydantic import BaseModel, Field, Json, ConfigDict
from enum import StrEnum
from enum import Enum
import uuid
import base64
import numpy as np
@@ -134,7 +136,7 @@ class SourceModel(SQLModel, table=True):
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")
class JobStatus(StrEnum):
class JobStatus(str, Enum):
created = "created"
running = "running"
completed = "completed"