Cli bug fixes (loading human/persona text, azure setup, local setup) (#222)
* mark depricated API section * add readme * add readme * add readme * add readme * add readme * add readme * add readme * add readme * add readme * CLI bug fixes for azure * check azure before running * Update README.md * Update README.md * bug fix with persona loading * revert readme * remove print
This commit is contained in:
@@ -26,6 +26,10 @@ from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.agent import AgentAsync
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.openai_tools import (
|
||||
configure_azure_support,
|
||||
check_azure_embeddings,
|
||||
)
|
||||
|
||||
|
||||
def run(
|
||||
@@ -135,8 +139,8 @@ def run(
|
||||
agent_config.preset,
|
||||
agent_config,
|
||||
agent_config.model,
|
||||
agent_config.persona,
|
||||
agent_config.human,
|
||||
utils.get_persona_text(agent_config.persona),
|
||||
utils.get_human_text(agent_config.human),
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
@@ -144,5 +148,10 @@ def run(
|
||||
# start event loop
|
||||
from memgpt.main import run_agent_loop
|
||||
|
||||
# setup azure if using
|
||||
# TODO: cleanup this code
|
||||
if config.model_endpoint == "azure":
|
||||
configure_azure_support()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify
|
||||
|
||||
@@ -37,10 +37,10 @@ def configure():
|
||||
use_azure_deployment_ids = False
|
||||
if use_azure:
|
||||
# search for key in enviornment
|
||||
azure_key = os.getenv("AZURE_API_KEY")
|
||||
azure_endpoint = (os.getenv("AZURE_ENDPOINT"),)
|
||||
azure_version = (os.getenv("AZURE_VERSION"),)
|
||||
azure_deployment = (os.getenv("AZURE_OPENAI_DEPLOYMENT"),)
|
||||
azure_key = os.getenv("AZURE_OPENAI_KEY")
|
||||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
azure_version = os.getenv("AZURE_OPENAI_VERSION")
|
||||
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
|
||||
|
||||
if all([azure_key, azure_endpoint, azure_version]):
|
||||
@@ -66,7 +66,7 @@ def configure():
|
||||
endpoint_options = []
|
||||
if os.getenv("OPENAI_API_BASE") is not None:
|
||||
endpoint_options.append(os.getenv("OPENAI_API_BASE"))
|
||||
if os.getenv("AZURE_ENDPOINT") is not None:
|
||||
if use_azure:
|
||||
endpoint_options += ["azure"]
|
||||
if use_openai:
|
||||
endpoint_options += ["openai"]
|
||||
|
||||
@@ -110,8 +110,10 @@ class MemGPTConfig:
|
||||
azure_key = config.get("azure", "key")
|
||||
azure_endpoint = config.get("azure", "endpoint")
|
||||
azure_version = config.get("azure", "version")
|
||||
azure_deployment = config.get("azure", "deployment")
|
||||
azure_embedding_deployment = config.get("azure", "embedding_deployment")
|
||||
azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
|
||||
azure_embedding_deployment = (
|
||||
config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
|
||||
)
|
||||
|
||||
embedding_model = config.get("embedding", "model")
|
||||
embedding_dim = config.getint("embedding", "dim")
|
||||
@@ -167,8 +169,9 @@ class MemGPTConfig:
|
||||
config.set("azure", "key", self.azure_key)
|
||||
config.set("azure", "endpoint", self.azure_endpoint)
|
||||
config.set("azure", "version", self.azure_version)
|
||||
config.set("azure", "deployment", self.azure_deployment)
|
||||
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
|
||||
if self.azure_deployment:
|
||||
config.set("azure", "deployment", self.azure_deployment)
|
||||
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
|
||||
|
||||
# embeddings
|
||||
config.add_section("embedding")
|
||||
|
||||
@@ -20,6 +20,8 @@ from memgpt.constants import MEMGPT_DIR
|
||||
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
from memgpt.embeddings import embedding_model
|
||||
|
||||
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
@@ -398,13 +400,11 @@ def get_index(name, docs):
|
||||
|
||||
# read embedding confirguration
|
||||
# TODO: in the future, make an IngestData class that loads the config once
|
||||
# config = MemGPTConfig.load()
|
||||
# chunk_size = config.embedding_chunk_size
|
||||
# model = config.embedding_model # TODO: actually use this
|
||||
# dim = config.embedding_dim # TODO: actually use this
|
||||
# embed_model = OpenAIEmbedding()
|
||||
# service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size)
|
||||
# set_global_service_context(service_context)
|
||||
config = MemGPTConfig.load()
|
||||
embed_model = embedding_model(config)
|
||||
chunk_size = config.embedding_chunk_size
|
||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size)
|
||||
set_global_service_context(service_context)
|
||||
|
||||
# index documents
|
||||
index = VectorStoreIndex.from_documents(docs)
|
||||
@@ -481,3 +481,27 @@ def list_persona_files():
|
||||
user_added = os.listdir(user_dir)
|
||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||
return memgpt_defaults + user_added
|
||||
|
||||
|
||||
def get_human_text(name: str):
|
||||
for file_path in list_human_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
return open(file_path, "r").read().strip()
|
||||
raise ValueError(f"Human {name} not found")
|
||||
|
||||
|
||||
def get_persona_text(name: str):
|
||||
for file_path in list_persona_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
return open(file_path, "r").read().strip()
|
||||
|
||||
raise ValueError(f"Persona {name} not found")
|
||||
|
||||
|
||||
def get_human_text(name: str):
|
||||
for file_path in list_human_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
return open(file_path, "r").read().strip()
|
||||
|
||||
Reference in New Issue
Block a user