feat: add DEFAULT_USER_ID and DEFAULT_ORG_ID for local usage (#1768)

This commit is contained in:
Sarah Wooders
2024-09-22 15:10:54 -07:00
committed by GitHub
parent 0b348f8bd9
commit 9ebbaacc1f
3 changed files with 43 additions and 22 deletions

View File

@@ -5,7 +5,6 @@ from typing import Callable, Dict, Generator, List, Optional, Union
import requests
import memgpt.utils
from memgpt.config import MemGPTConfig
from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.data_sources.connectors import DataConnector
from memgpt.functions.functions import parse_source_code
@@ -42,7 +41,6 @@ from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.schemas.user import UserCreate
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text, get_persona_text
@@ -1353,14 +1351,6 @@ class LocalClient(AbstractClient):
"""
self.auto_save = auto_save
# determine user_id (pulled from local config)
config = MemGPTConfig.load()
if user_id:
self.user_id = user_id
else:
# TODO: find a neater way to do this
self.user_id = config.anon_clientid
# set logging levels
memgpt.utils.DEBUG = debug
logging.getLogger().setLevel(logging.CRITICAL)
@@ -1368,19 +1358,14 @@ class LocalClient(AbstractClient):
self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface_factory=lambda: self.interface)
# set logging levels
memgpt.utils.DEBUG = debug
logging.getLogger().setLevel(logging.CRITICAL)
# save user_id that `LocalClient` is associated with
if user_id:
self.user_id = user_id
else:
# get default user
self.user_id = self.server.get_default_user().id
# create user if does not exist
existing_user = self.server.get_user(self.user_id)
if not existing_user:
self.user = self.server.create_user(UserCreate())
self.user_id = self.user.id
# update config
config.anon_clientid = str(self.user_id)
config.save()
print("USER", self.user_id)
# agents

View File

@@ -1,6 +1,13 @@
import os
from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
# Defaults
DEFAULT_USER_ID = "user-00000000"
DEFAULT_ORG_ID = "org-00000000"
DEFAULT_USER_NAME = "default"
DEFAULT_ORG_NAME = "default"
# Default directory
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")
# String in the error message for when the context window is too large

View File

@@ -1905,6 +1905,34 @@ class SyncServer(Server):
self._current_user = user_id
def get_default_user(self) -> User:
from memgpt.constants import (
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
DEFAULT_USER_ID,
DEFAULT_USER_NAME,
)
# check if default org exists
default_org = self.ms.get_organization(DEFAULT_ORG_ID)
if not default_org:
org = Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
self.ms.create_organization(org)
# check if default user exists
default_user = self.get_user(DEFAULT_USER_ID)
if not default_user:
user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID)
self.ms.create_user(user)
# add default data (TODO: move to org)
self.add_default_blocks(user.id)
self.add_default_tools(module_name="base", user_id=user.id)
# check if default org exists
return self.get_user(DEFAULT_USER_ID)
# TODO(ethan) wire back to real method in future ORM PR
def get_current_user(self) -> User:
"""Returns the currently authed user.
@@ -1918,6 +1946,7 @@ class SyncServer(Server):
current_user = self.get_user(self._current_user)
if not current_user:
warnings.warn(f"Provided user '{self._current_user}' not found, using default user")
return self.get_default_user()
else:
return current_user