feat: add DEFAULT_USER_ID and DEFAULT_ORG_ID for local usage (#1768)
This commit is contained in:
@@ -5,7 +5,6 @@ from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
import requests
|
||||
|
||||
import memgpt.utils
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
from memgpt.functions.functions import parse_source_code
|
||||
@@ -42,7 +41,6 @@ from memgpt.schemas.openai.chat_completions import ToolCall
|
||||
from memgpt.schemas.passage import Passage
|
||||
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from memgpt.schemas.user import UserCreate
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
@@ -1353,14 +1351,6 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
self.auto_save = auto_save
|
||||
|
||||
# determine user_id (pulled from local config)
|
||||
config = MemGPTConfig.load()
|
||||
if user_id:
|
||||
self.user_id = user_id
|
||||
else:
|
||||
# TODO: find a neater way to do this
|
||||
self.user_id = config.anon_clientid
|
||||
|
||||
# set logging levels
|
||||
memgpt.utils.DEBUG = debug
|
||||
logging.getLogger().setLevel(logging.CRITICAL)
|
||||
@@ -1368,19 +1358,14 @@ class LocalClient(AbstractClient):
|
||||
self.interface = QueuingInterface(debug=debug)
|
||||
self.server = SyncServer(default_interface_factory=lambda: self.interface)
|
||||
|
||||
# set logging levels
|
||||
memgpt.utils.DEBUG = debug
|
||||
logging.getLogger().setLevel(logging.CRITICAL)
|
||||
# save user_id that `LocalClient` is associated with
|
||||
if user_id:
|
||||
self.user_id = user_id
|
||||
else:
|
||||
# get default user
|
||||
self.user_id = self.server.get_default_user().id
|
||||
|
||||
# create user if does not exist
|
||||
existing_user = self.server.get_user(self.user_id)
|
||||
if not existing_user:
|
||||
self.user = self.server.create_user(UserCreate())
|
||||
self.user_id = self.user.id
|
||||
|
||||
# update config
|
||||
config.anon_clientid = str(self.user_id)
|
||||
config.save()
|
||||
print("USER", self.user_id)
|
||||
|
||||
# agents
|
||||
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
import os
|
||||
from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
|
||||
|
||||
# Defaults
|
||||
DEFAULT_USER_ID = "user-00000000"
|
||||
DEFAULT_ORG_ID = "org-00000000"
|
||||
DEFAULT_USER_NAME = "default"
|
||||
DEFAULT_ORG_NAME = "default"
|
||||
|
||||
# Default directory
|
||||
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
|
||||
@@ -1905,6 +1905,34 @@ class SyncServer(Server):
|
||||
|
||||
self._current_user = user_id
|
||||
|
||||
def get_default_user(self) -> User:
|
||||
|
||||
from memgpt.constants import (
|
||||
DEFAULT_ORG_ID,
|
||||
DEFAULT_ORG_NAME,
|
||||
DEFAULT_USER_ID,
|
||||
DEFAULT_USER_NAME,
|
||||
)
|
||||
|
||||
# check if default org exists
|
||||
default_org = self.ms.get_organization(DEFAULT_ORG_ID)
|
||||
if not default_org:
|
||||
org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
|
||||
self.ms.create_organization(org)
|
||||
|
||||
# check if default user exists
|
||||
default_user = self.get_user(DEFAULT_USER_ID)
|
||||
if not default_user:
|
||||
user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID)
|
||||
self.ms.create_user(user)
|
||||
|
||||
# add default data (TODO: move to org)
|
||||
self.add_default_blocks(user.id)
|
||||
self.add_default_tools(module_name="base", user_id=user.id)
|
||||
|
||||
# check if default org exists
|
||||
return self.get_user(DEFAULT_USER_ID)
|
||||
|
||||
# TODO(ethan) wire back to real method in future ORM PR
|
||||
def get_current_user(self) -> User:
|
||||
"""Returns the currently authed user.
|
||||
@@ -1918,6 +1946,7 @@ class SyncServer(Server):
|
||||
current_user = self.get_user(self._current_user)
|
||||
if not current_user:
|
||||
warnings.warn(f"Provided user '{self._current_user}' not found, using default user")
|
||||
return self.get_default_user()
|
||||
else:
|
||||
return current_user
|
||||
|
||||
|
||||
Reference in New Issue
Block a user