fix: misc bugs (#1276)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
@@ -78,7 +79,9 @@ def write_function(module_name: str, function_name: str, function_code: str):
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def load_all_function_sets(merge: bool = True) -> dict:
|
||||
def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict:
|
||||
from memgpt.utils import printd
|
||||
|
||||
# functions/examples/*.py
|
||||
scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script
|
||||
function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory
|
||||
@@ -106,7 +109,7 @@ def load_all_function_sets(merge: bool = True) -> dict:
|
||||
if dir_path == USER_FUNCTIONS_DIR:
|
||||
# For user scripts, adjust the module name appropriately
|
||||
module_full_path = os.path.join(dir_path, file)
|
||||
print(f"Loading user function set from '{module_full_path}'")
|
||||
printd(f"Loading user function set from '{module_full_path}'")
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_full_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
@@ -114,18 +117,18 @@ def load_all_function_sets(merge: bool = True) -> dict:
|
||||
except ModuleNotFoundError as e:
|
||||
# Handle missing module imports
|
||||
missing_package = str(e).split("'")[1] # Extract the name of the missing package
|
||||
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
|
||||
print(
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
|
||||
printd(
|
||||
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
|
||||
)
|
||||
continue
|
||||
except SyntaxError as e:
|
||||
# Handle syntax errors in the module
|
||||
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
|
||||
continue
|
||||
else:
|
||||
# For built-in scripts, use the existing method
|
||||
@@ -135,7 +138,7 @@ def load_all_function_sets(merge: bool = True) -> dict:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
print(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -147,7 +150,7 @@ def load_all_function_sets(merge: bool = True) -> dict:
|
||||
v["tags"] = tags
|
||||
schemas_and_functions[module_name] = function_set
|
||||
except ValueError as e:
|
||||
print(f"Error loading function set '{module_name}': {e}")
|
||||
printd(f"Error loading function set '{module_name}': {e}")
|
||||
|
||||
if merge:
|
||||
# Put all functions from all sets into the same level dict
|
||||
@@ -155,8 +158,13 @@ def load_all_function_sets(merge: bool = True) -> dict:
|
||||
for set_name, function_set in schemas_and_functions.items():
|
||||
for function_name, function_info in function_set.items():
|
||||
if function_name in merged_functions:
|
||||
raise ValueError(f"Duplicate function name '{function_name}' found in function set '{set_name}'")
|
||||
merged_functions[function_name] = function_info
|
||||
err_msg = f"Duplicate function name '{function_name}' found in function set '{set_name}'"
|
||||
if ignore_duplicates:
|
||||
warnings.warn(err_msg, category=UserWarning, stacklevel=2)
|
||||
else:
|
||||
raise ValueError(err_msg)
|
||||
else:
|
||||
merged_functions[function_name] = function_info
|
||||
return merged_functions
|
||||
else:
|
||||
# Nested dict where the top level is organized by the function set name
|
||||
|
||||
@@ -557,7 +557,6 @@ class MetadataStore:
|
||||
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
available_functions = load_all_function_sets()
|
||||
print(available_functions)
|
||||
results = [
|
||||
ToolModel(
|
||||
name=k,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from typing import List, Optional, Dict, Literal, Type
|
||||
from pydantic import BaseModel, Field, Json, ConfigDict
|
||||
from enum import StrEnum
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import uuid
|
||||
import base64
|
||||
import numpy as np
|
||||
@@ -134,7 +136,7 @@ class SourceModel(SQLModel, table=True):
|
||||
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")
|
||||
|
||||
|
||||
class JobStatus(StrEnum):
|
||||
class JobStatus(str, Enum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
completed = "completed"
|
||||
|
||||
Reference in New Issue
Block a user