From a3c19a70f595c2d68c574053eabfed189108c9f6 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sat, 20 Apr 2024 10:24:51 -0700 Subject: [PATCH] fix: misc bugs (#1276) Co-authored-by: Sarah Wooders --- memgpt/functions/functions.py | 28 ++++++++++++++++++---------- memgpt/metadata.py | 1 - memgpt/models/pydantic_models.py | 6 ++++-- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py index 5be3469d..5df29f64 100644 --- a/memgpt/functions/functions.py +++ b/memgpt/functions/functions.py @@ -1,6 +1,7 @@ import importlib import inspect import os +import warnings import sys from types import ModuleType @@ -78,7 +79,9 @@ def write_function(module_name: str, function_name: str, function_code: str): raise ValueError(error) -def load_all_function_sets(merge: bool = True) -> dict: +def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict: + from memgpt.utils import printd + # functions/examples/*.py scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory @@ -106,7 +109,7 @@ def load_all_function_sets(merge: bool = True) -> dict: if dir_path == USER_FUNCTIONS_DIR: # For user scripts, adjust the module name appropriately module_full_path = os.path.join(dir_path, file) - print(f"Loading user function set from '{module_full_path}'") + printd(f"Loading user function set from '{module_full_path}'") try: spec = importlib.util.spec_from_file_location(module_name, module_full_path) module = importlib.util.module_from_spec(spec) @@ -114,18 +117,18 @@ def load_all_function_sets(merge: bool = True) -> dict: except ModuleNotFoundError as e: # Handle missing module imports missing_package = str(e).split("'")[1] # Extract the name of the missing package - print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!") - print( + printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!") + printd( f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT." ) continue except SyntaxError as e: # Handle syntax errors in the module - print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}") + printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}") continue except Exception as e: # Handle other general exceptions - print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}") + printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}") continue else: # For built-in scripts, use the existing method @@ -135,7 +138,7 @@ def load_all_function_sets(merge: bool = True) -> dict: module = importlib.import_module(full_module_name) except Exception as e: # Handle other general exceptions - print(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}") + printd(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}") continue try: @@ -147,7 +150,7 @@ def load_all_function_sets(merge: bool = True) -> dict: v["tags"] = tags schemas_and_functions[module_name] = function_set except ValueError as e: - print(f"Error loading function set '{module_name}': {e}") + printd(f"Error loading function set '{module_name}': {e}") if merge: # Put all functions from all sets into the same level dict @@ -155,8 +158,13 @@ def load_all_function_sets(merge: bool = True) -> dict: for set_name, function_set in schemas_and_functions.items(): for function_name, function_info in function_set.items(): if function_name in merged_functions: - raise ValueError(f"Duplicate function name '{function_name}' found in function set '{set_name}'") - merged_functions[function_name] = function_info + err_msg = f"Duplicate function name '{function_name}' found in function set '{set_name}'" + if ignore_duplicates: + warnings.warn(err_msg, category=UserWarning, stacklevel=2) + else: + raise ValueError(err_msg) + else: + merged_functions[function_name] = function_info return merged_functions else: # Nested dict where the top level is organized by the function set name diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 34f2c663..2450884d 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -557,7 +557,6 @@ class MetadataStore: def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: with self.session_maker() as session: available_functions = load_all_function_sets() - print(available_functions) results = [ ToolModel( name=k, diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 4d0ac71a..35635a30 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -1,6 +1,8 @@ from typing import List, Optional, Dict, Literal, Type from pydantic import BaseModel, Field, Json, ConfigDict -from enum import StrEnum + +from enum import Enum + import uuid import base64 import numpy as np @@ -134,7 +136,7 @@ class SourceModel(SQLModel, table=True): metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.") -class JobStatus(StrEnum): +class JobStatus(str, Enum): created = "created" running = "running" completed = "completed"