Cli bug fixes (loading human/persona text, azure setup, local setup) (#222)

* mark deprecated API section

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* add readme

* CLI bug fixes for azure

* check azure before running

* Update README.md

* Update README.md

* bug fix with persona loading

* revert readme

* remove print
This commit is contained in:
Sarah Wooders
2023-10-31 13:51:20 -07:00
committed by GitHub
parent 4fc88f95f1
commit c9225d329e
4 changed files with 54 additions and 18 deletions

View File

@@ -26,6 +26,10 @@ from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.agent import AgentAsync
from memgpt.embeddings import embedding_model
from memgpt.openai_tools import (
configure_azure_support,
check_azure_embeddings,
)
def run(
@@ -135,8 +139,8 @@ def run(
agent_config.preset,
agent_config,
agent_config.model,
agent_config.persona,
agent_config.human,
utils.get_persona_text(agent_config.persona),
utils.get_human_text(agent_config.human),
memgpt.interface,
persistence_manager,
)
@@ -144,5 +148,10 @@ def run(
# start event loop
from memgpt.main import run_agent_loop
# setup azure if using
# TODO: cleanup this code
if config.model_endpoint == "azure":
configure_azure_support()
loop = asyncio.get_event_loop()
loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify

View File

@@ -37,10 +37,10 @@ def configure():
use_azure_deployment_ids = False
if use_azure:
# search for key in environment
azure_key = os.getenv("AZURE_API_KEY")
azure_endpoint = (os.getenv("AZURE_ENDPOINT"),)
azure_version = (os.getenv("AZURE_VERSION"),)
azure_deployment = (os.getenv("AZURE_OPENAI_DEPLOYMENT"),)
azure_key = os.getenv("AZURE_OPENAI_KEY")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_version = os.getenv("AZURE_OPENAI_VERSION")
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
if all([azure_key, azure_endpoint, azure_version]):
@@ -66,7 +66,7 @@ def configure():
endpoint_options = []
if os.getenv("OPENAI_API_BASE") is not None:
endpoint_options.append(os.getenv("OPENAI_API_BASE"))
if os.getenv("AZURE_ENDPOINT") is not None:
if use_azure:
endpoint_options += ["azure"]
if use_openai:
endpoint_options += ["openai"]

View File

@@ -110,8 +110,10 @@ class MemGPTConfig:
azure_key = config.get("azure", "key")
azure_endpoint = config.get("azure", "endpoint")
azure_version = config.get("azure", "version")
azure_deployment = config.get("azure", "deployment")
azure_embedding_deployment = config.get("azure", "embedding_deployment")
azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
azure_embedding_deployment = (
config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
)
embedding_model = config.get("embedding", "model")
embedding_dim = config.getint("embedding", "dim")
@@ -167,8 +169,9 @@ class MemGPTConfig:
config.set("azure", "key", self.azure_key)
config.set("azure", "endpoint", self.azure_endpoint)
config.set("azure", "version", self.azure_version)
config.set("azure", "deployment", self.azure_deployment)
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
if self.azure_deployment:
config.set("azure", "deployment", self.azure_deployment)
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
# embeddings
config.add_section("embedding")

View File

@@ -20,6 +20,8 @@ from memgpt.constants import MEMGPT_DIR
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
from llama_index.embeddings import OpenAIEmbedding
from memgpt.embeddings import embedding_model
def count_tokens(s: str, model: str = "gpt-4") -> int:
encoding = tiktoken.encoding_for_model(model)
@@ -398,13 +400,11 @@ def get_index(name, docs):
# read embedding configuration
# TODO: in the future, make an IngestData class that loads the config once
# config = MemGPTConfig.load()
# chunk_size = config.embedding_chunk_size
# model = config.embedding_model # TODO: actually use this
# dim = config.embedding_dim # TODO: actually use this
# embed_model = OpenAIEmbedding()
# service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size)
# set_global_service_context(service_context)
config = MemGPTConfig.load()
embed_model = embedding_model(config)
chunk_size = config.embedding_chunk_size
service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=chunk_size)
set_global_service_context(service_context)
# index documents
index = VectorStoreIndex.from_documents(docs)
@@ -481,3 +481,27 @@ def list_persona_files():
user_added = os.listdir(user_dir)
user_added = [os.path.join(user_dir, f) for f in user_added]
return memgpt_defaults + user_added
def get_human_text(name: str) -> str:
    """Return the stripped text of the human file matching *name*.

    Searches the paths returned by ``list_human_files()`` and accepts either
    the bare name or an exact ``{name}.txt`` filename, so callers may omit
    the extension.

    Raises:
        ValueError: if no matching human file is found.
    """
    # NOTE(review): this function appears to be defined twice in the file;
    # the duplicate definition should be removed so only one wins at import.
    for file_path in list_human_files():
        file_name = os.path.basename(file_path)
        if file_name in (f"{name}.txt", name):
            # use a context manager so the file handle is always closed
            with open(file_path, "r") as f:
                return f.read().strip()
    raise ValueError(f"Human {name} not found")
def get_persona_text(name: str) -> str:
    """Return the stripped text of the persona file matching *name*.

    Searches the paths returned by ``list_persona_files()`` and accepts
    either the bare name or an exact ``{name}.txt`` filename, so callers
    may omit the extension.

    Raises:
        ValueError: if no matching persona file is found.
    """
    for file_path in list_persona_files():
        file_name = os.path.basename(file_path)
        if file_name in (f"{name}.txt", name):
            # use a context manager so the file handle is always closed
            with open(file_path, "r") as f:
                return f.read().strip()
    raise ValueError(f"Persona {name} not found")
def get_human_text(name: str):
for file_path in list_human_files():
file = os.path.basename(file_path)
if f"{name}.txt" == file or name == file:
return open(file_path, "r").read().strip()