feat: Store embeddings padded to size 4096 to allow DB storage of varying size embeddings (#852)

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders
2024-01-19 16:03:13 -08:00
committed by GitHub
parent 492796ed5f
commit 4039763de5
19 changed files with 464 additions and 84 deletions

View File

@@ -10,7 +10,7 @@ from typing import List, Tuple
from box import Box
from memgpt.data_types import AgentState, Message
from memgpt.data_types import AgentState, Message, EmbeddingConfig
from memgpt.models import chat_completion_response
from memgpt.interface import AgentInterface
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
@@ -897,3 +897,10 @@ class Agent(object):
state=updated_state,
)
return self.agent_state
def migrate_embedding(self, embedding_config: EmbeddingConfig):
    """Switch this agent over to a different embedding configuration.

    Not yet implemented: both archival memory and recall memory would need
    to be re-embedded under the new model before the swap is safe.
    """
    raise NotImplementedError()

View File

@@ -1,5 +1,4 @@
import os
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime
from sqlalchemy import func, or_, and_
from sqlalchemy import desc, asc
@@ -22,6 +21,7 @@ from memgpt.config import MemGPTConfig
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.utils import printd
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.data_types import Record, Message, Passage, ToolCall
from memgpt.metadata import MetadataStore
@@ -109,22 +109,6 @@ def get_db_model(
agent_id: Optional[uuid.UUID] = None,
dialect="postgresql",
):
# get embedding dimension info
# TODO: Need to remove this and just pass in AgentState/User instead
ms = MetadataStore(config)
if agent_id and ms.get_agent(agent_id):
agent = ms.get_agent(agent_id)
embedding_dim = agent.embedding_config.embedding_dim
else:
user = ms.get_user(user_id)
if user is None:
raise ValueError(f"User {user_id} not found")
embedding_dim = config.default_embedding_config.embedding_dim
# this cannot be the case if we are making an agent-specific table
assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found"
assert table_type != TableType.ARCHIVAL_MEMORY, f"Agent {agent_id} not found"
# Define a helper function to create or get the model class
def create_or_get_model(class_name, base_model, table_name):
if class_name in globals():
@@ -156,7 +140,9 @@ def get_db_model(
else:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(embedding_dim))
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
metadata_ = Column(MutableJson)
@@ -167,6 +153,8 @@ def get_db_model(
return Passage(
text=self.text,
embedding=self.embedding,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
doc_id=self.doc_id,
user_id=self.user_id,
id=self.id,
@@ -216,7 +204,9 @@ def get_db_model(
else:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(embedding_dim))
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())
@@ -235,6 +225,8 @@ def get_db_model(
tool_calls=self.tool_calls,
tool_call_id=self.tool_call_id,
embedding=self.embedding,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
created_at=self.created_at,
id=self.id,
)
@@ -440,6 +432,7 @@ class PostgresStorageConnector(SQLStorageConnector):
for c in self.db_model.__table__.columns:
if c.name == "embedding":
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
session_maker = sessionmaker(bind=self.engine)

View File

@@ -68,6 +68,28 @@ def set_config_with_dict(new_config: dict) -> bool:
else:
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
# update embedding config
for k, v in vars(old_config.default_embedding_config).items():
if k in new_config:
if v != new_config[k]:
printd(f"Replacing config {k}: {v} -> {new_config[k]}")
modified = True
# old_config[k] = new_config[k]
setattr(old_config.default_embedding_config, k, new_config[k])
else:
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
# update llm config
for k, v in vars(old_config.default_llm_config).items():
if k in new_config:
if v != new_config[k]:
printd(f"Replacing config {k}: {v} -> {new_config[k]}")
modified = True
# old_config[k] = new_config[k]
setattr(old_config.default_llm_config, k, new_config[k])
else:
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
if modified:
printd(f"Saving new config file.")
old_config.save()
@@ -416,11 +438,14 @@ def run(
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
print("Creating user", user_id)
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
else:
print("existing user", user, user_id)
# override with command line arguments
if debug:

View File

@@ -208,11 +208,11 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
other_option_str = "[enter model name manually]"
# Check if the model we have set already is even in the list (informs our default)
valid_model = config.model in hardcoded_model_options
valid_model = config.default_llm_config.model in hardcoded_model_options
model = questionary.select(
"Select default model (recommended: gpt-4):",
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
default=config.model if valid_model else hardcoded_model_options[0],
default=config.default_llm_config.model if valid_model else hardcoded_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt
@@ -409,6 +409,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
embedding_endpoint_type = "openai"
embedding_endpoint = "https://api.openai.com/v1"
embedding_dim = 1536
embedding_model = "text-embedding-ada-002"
elif embedding_provider == "azure":
# check for necessary vars
@@ -422,6 +423,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
embedding_endpoint_type = "azure"
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
embedding_dim = 1536
embedding_model = "text-embedding-ada-002"
elif embedding_provider == "hugging-face":
# configure hugging face embedding endpoint (https://github.com/huggingface/text-embeddings-inference)
@@ -676,7 +678,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
if arg == ListChoice.agents:
"""List all agents"""
table = PrettyTable()
table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
for agent in tqdm(ms.list_agents(user_id=user_id)):
source_ids = ms.list_attached_sources(agent_id=agent.id)
source_names = [ms.get_source(source_id=source_id).name for source_id in source_ids]
@@ -684,6 +686,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
[
agent.name,
agent.llm_config.model,
agent.embedding_config.embedding_model,
agent.embedding_config.embedding_dim,
agent.persona,
agent.human,
",".join(source_names),
@@ -715,7 +719,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
# create table
table = PrettyTable()
table.field_names = ["Name", "Created At", "Agents"]
table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
# TODO: eventually look across all storage connections
# TODO: add data source stats
# TODO: connect to agents
@@ -726,7 +730,9 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
agent_ids = ms.list_attached_agents(source_id=source.id)
agent_names = [ms.get_agent(agent_id=agent_id).name for agent_id in agent_ids]
table.add_row([source.name, utils.format_datetime(source.created_at), ",".join(agent_names)])
table.add_row(
[source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)]
)
print(table)
else:

View File

@@ -10,9 +10,10 @@ memgpt load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
from typing import List
from tqdm import tqdm
import numpy as np
import typer
import uuid
from memgpt.embeddings import embedding_model
from memgpt.embeddings import embedding_model, check_and_split_text
from memgpt.agent_store.storage import StorageConnector
from memgpt.config import MemGPTConfig
from memgpt.metadata import MetadataStore
@@ -91,16 +92,45 @@ def store_docs(name, docs, user_id=None, show_progress=True):
if user_id is None: # assume running local with single user
user_id = uuid.UUID(config.anon_clientid)
# ensure doc text is not too long
# TODO: replace this to instead split up docs that are too large
# (this is a temporary fix to avoid breaking the llama index)
for doc in docs:
doc.text = check_and_split_text(doc.text, config.default_embedding_config.embedding_model)[0]
# record data source metadata
ms = MetadataStore(config)
user = ms.get_user(user_id)
if user is None:
raise ValueError(f"Cannot find user {user_id} in metadata store. Please run 'memgpt configure'.")
data_source = Source(user_id=user.id, name=name, created_at=datetime.now())
if not ms.get_source(user_id=user.id, source_name=name):
# create data source record
data_source = Source(
user_id=user.id,
name=name,
created_at=datetime.now(),
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
existing_source = ms.get_source(user_id=user.id, source_name=name)
if not existing_source:
ms.create_source(data_source)
else:
print(f"Source {name} for user {user.id} already exists")
print(f"Source {name} for user {user.id} already exists.")
if existing_source.embedding_model != data_source.embedding_model:
print(
f"Warning: embedding model for existing source {existing_source.embedding_model} does not match default {data_source.embedding_model}"
)
print("Cannot import data into this source without a compatible embedding endpoint.")
print("Please run 'memgpt configure' to update the default embedding settings.")
return False
if existing_source.embedding_dim != data_source.embedding_dim:
print(
f"Warning: embedding dimension for existing source {existing_source.embedding_dim} does not match default {data_source.embedding_dim}"
)
print("Cannot import data into this source without a compatible embedding endpoint.")
print("Please run 'memgpt configure' to update the default embedding settings.")
return False
# compute and record passages
embed_model = embedding_model(config.default_embedding_config)
@@ -132,6 +162,8 @@ def store_docs(name, docs, user_id=None, show_progress=True):
data_source=name,
embedding=node.embedding,
metadata=None,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
@@ -154,20 +186,31 @@ def load_index(
embed_dict = loaded_index._vector_store._data.embedding_dict
node_dict = loaded_index._docstore.docs
passages = []
for node_id, node in node_dict.items():
vector = embed_dict[node_id]
node.embedding = vector
passages.append(Passage(text=node.text, embedding=vector))
if len(passages) == 0:
raise ValueError(f"No passages found in index {dir}")
# create storage connector
config = MemGPTConfig.load()
if user_id is None:
user_id = uuid.UUID(config.anon_clientid)
passages = []
for node_id, node in node_dict.items():
vector = embed_dict[node_id]
node.embedding = vector
# assume embeddings are the same as the config
passages.append(
Passage(
text=node.text,
embedding=np.array(vector),
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
assert config.default_embedding_config.embedding_dim == len(
vector
), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"
if len(passages) == 0:
raise ValueError(f"No passages found in index {dir}")
insert_passages_into_source(passages, name, user_id, config)
except ValueError as e:
typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)
@@ -309,7 +352,10 @@ def load_vector_database(
# Convert to a list of tuples (text, embedding)
passages = []
for text, embedding in result:
passages.append(Passage(text=text, embedding=embedding))
# assume that embeddings are the same model as in config
passages.append(
Passage(text=text, embedding=embedding, embedding_dim=config.embedding_dim, embedding_model=config.embedding_model)
)
assert config.embedding_dim == len(embedding), f"Expected embedding dimension {config.embedding_dim}, got {len(embedding)}"
# create storage connector

View File

@@ -134,8 +134,11 @@ class MemGPTConfig:
"embedding_model": get_field(config, "embedding", "embedding_model"),
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
"embedding_chunk_size": get_field(config, "embedding", "embedding_chunk_size"),
}
# Remove null values
llm_config_dict = {k: v for k, v in llm_config_dict.items() if v is not None}
embedding_config_dict = {k: v for k, v in embedding_config_dict.items() if v is not None}
# Correct the types that aren't strings
if llm_config_dict["context_window"] is not None:
llm_config_dict["context_window"] = int(llm_config_dict["context_window"])

View File

@@ -3,6 +3,16 @@ from logging import CRITICAL, ERROR, WARN, WARNING, INFO, DEBUG, NOTSET
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")
# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embedding size - do NOT change or else DBs will need to be reset
# tokenizers
EMBEDDING_TO_TOKENIZER_MAP = {
"text-embedding-ada-002": "cl100k_base",
}
EMBEDDING_TO_TOKENIZER_DEFAULT = "cl100k_base"
DEFAULT_MEMGPT_MODEL = "gpt-4"
DEFAULT_PERSONA = "sam_pov"
DEFAULT_HUMAN = "basic"

View File

@@ -5,7 +5,7 @@ from abc import abstractmethod
from typing import Optional, List, Dict
import numpy as np
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_local_time, format_datetime
from memgpt.models import chat_completion_response
@@ -68,6 +68,8 @@ class Message(Record):
tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
tool_call_id: Optional[str] = None,
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
embedding_model: Optional[str] = None,
id: Optional[uuid.UUID] = None,
):
super().__init__(id)
@@ -82,6 +84,20 @@ class Message(Record):
self.role = role # role (agent/user/function)
self.name = name
# pad and store embeddings
if isinstance(embedding, list):
embedding = np.array(embedding)
self.embedding = (
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model
if self.embedding is not None:
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
# tool (i.e. function) call info (optional)
# if role == "assistant", this MAY be specified
@@ -97,9 +113,6 @@ class Message(Record):
assert tool_call_id is None
self.tool_call_id = tool_call_id
# embedding (optional)
self.embedding = embedding
# def __repr__(self):
# pass
@@ -172,7 +185,7 @@ class Message(Record):
if "tool_call_id" in openai_message_dict:
assert openai_message_dict["tool_call_id"] is None, openai_message_dict
if "tool_calls" in openai_message_dict and openai_message_dict["tool_calls"] is not None:
if "tool_calls" in openai_message_dict:
assert openai_message_dict["role"] == "assistant", openai_message_dict
tool_calls = [
@@ -273,6 +286,8 @@ class Passage(Record):
text: str,
agent_id: Optional[uuid.UUID] = None, # set if contained in agent memory
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
embedding_model: Optional[str] = None,
data_source: Optional[str] = None, # None if created by agent
doc_id: Optional[uuid.UUID] = None,
id: Optional[uuid.UUID] = None,
@@ -283,10 +298,23 @@ class Passage(Record):
self.agent_id = agent_id
self.text = text
self.data_source = data_source
self.embedding = embedding
self.doc_id = doc_id
self.metadata = metadata
# pad and store embeddings
if isinstance(embedding, list):
embedding = np.array(embedding)
self.embedding = (
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model
if self.embedding is not None:
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
@@ -328,6 +356,13 @@ class EmbeddingConfig:
self.embedding_dim = embedding_dim
self.embedding_chunk_size = embedding_chunk_size
# fields cannot be set to None
assert self.embedding_endpoint_type
assert self.embedding_endpoint
assert self.embedding_model
assert self.embedding_dim
assert self.embedding_chunk_size
class OpenAIEmbeddingConfig(EmbeddingConfig):
def __init__(self, openai_key: Optional[str] = None, **kwargs):
@@ -434,6 +469,9 @@ class Source:
name: str,
created_at: Optional[str] = None,
id: Optional[uuid.UUID] = None,
# embedding info
embedding_model: Optional[str] = None,
embedding_dim: Optional[int] = None,
):
if id is None:
self.id = uuid.uuid4()
@@ -445,3 +483,7 @@ class Source:
self.name = name
self.user_id = user_id
self.created_at = created_at
# embedding info (optional)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model

View File

@@ -2,10 +2,12 @@ import typer
import uuid
from typing import Optional, List
import os
import numpy as np
from memgpt.utils import is_valid_url
from memgpt.utils import is_valid_url, printd
from memgpt.data_types import EmbeddingConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import MAX_EMBEDDING_DIM, EMBEDDING_TO_TOKENIZER_MAP, EMBEDDING_TO_TOKENIZER_DEFAULT
from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding
from llama_index.bridge.pydantic import PrivateAttr
@@ -14,6 +16,34 @@ from llama_index.embeddings.huggingface_utils import format_text
import tiktoken
def check_and_split_text(text: str, embedding_model: str) -> List[str]:
    """Split text into chunks of max_length tokens or less"""
    # Resolve which tiktoken tokenizer to use for this embedding model,
    # falling back to the shared default when the model is unknown.
    tokenizer_name = EMBEDDING_TO_TOKENIZER_MAP.get(embedding_model)
    if tokenizer_name is None:
        print(f"Warning: couldn't find tokenizer for model {embedding_model}, using default tokenizer {EMBEDDING_TO_TOKENIZER_DEFAULT}")
        tokenizer_name = EMBEDDING_TO_TOKENIZER_DEFAULT
    encoding = tiktoken.get_encoding(tokenizer_name)

    token_count = len(encoding.encode(text))

    # determine max length — tiktoken encodings do not always expose one,
    # so fall back to a conservative default when the attribute is missing
    max_length = getattr(encoding, "max_length", None)
    if max_length is None:
        # TODO: figure out the real number
        printd(f"Warning: couldn't find max_length for tokenizer {embedding_model}, using default max_length 8191")
        max_length = 8191

    # truncate text if too long
    if token_count > max_length:
        # TODO: split this into two pieces of text instead of truncating
        print(f"Warning: text is too long ({token_count} tokens), truncating to {max_length} tokens.")
        text = format_text(text, embedding_model, max_length=max_length)

    return [text]
class EmbeddingEndpoint(BaseEmbedding):
"""Implementation for OpenAI compatible endpoint"""
@@ -38,7 +68,6 @@ class EmbeddingEndpoint(BaseEmbedding):
self._user = user
self._base_url = base_url
self._timeout = timeout
self._encoding = tiktoken.get_encoding(model)
super().__init__(
model_name=model,
)
@@ -47,10 +76,6 @@ class EmbeddingEndpoint(BaseEmbedding):
def class_name(cls) -> str:
return "EmbeddingEndpoint"
def count_tokens(self, text: str) -> int:
"""Count tokens using the embedding model's tokenizer"""
return len(self._encoding.encode(text))
def _call_api(self, text: str) -> List[float]:
if not is_valid_url(self._base_url):
raise ValueError(
@@ -58,12 +83,6 @@ class EmbeddingEndpoint(BaseEmbedding):
)
import httpx
# If necessary, truncate text to fit in the embedding model's max sequence length (usually 512)
num_tokens = self.count_tokens(text)
max_length = self._encoding.max_length
if num_tokens > max_length:
text = format_text(text, self.model_name, max_length=max_length)
headers = {"Content-Type": "application/json"}
json_data = {"input": text, "model": self.model_name, "user": self._user}
@@ -157,6 +176,14 @@ def default_embedding_model():
return HuggingFaceEmbedding(model_name=model)
def query_embedding(embedding_model, query_text: str):
    """Embed *query_text* and zero-pad the vector to MAX_EMBEDDING_DIM.

    The database stores every embedding at a fixed width, so query vectors
    must be padded the same way before running a similarity search.
    """
    raw = np.array(embedding_model.get_text_embedding(query_text))
    pad_width = (0, MAX_EMBEDDING_DIM - raw.shape[0])
    return np.pad(raw, pad_width, mode="constant").tolist()
def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None):
"""Return LlamaIndex embedding model to use for embeddings"""
@@ -181,11 +208,6 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
api_version=credentials.azure_version,
)
elif endpoint_type == "hugging-face":
try:
return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
except Exception as e:
# TODO: remove, this is just to get passing tests
print(e)
return default_embedding_model()
return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
else:
return default_embedding_model()

View File

@@ -111,6 +111,9 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
elif user_input.lower() == "/attach":
# TODO: check if agent already has it
# TODO: check to ensure source embedding dimensions/model match agents, and disallow attachment if not
# TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources
data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id)
data_source_options = [s.name for s in data_source_options]
if len(data_source_options) == 0:
@@ -120,7 +123,24 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
bold=True,
)
continue
data_source = questionary.select("Select data source", choices=data_source_options).ask()
# determine what sources are valid to be attached to this agent
valid_options = []
invalid_options = []
for source in data_source_options:
if source.embedding_model == memgpt_agent.embedding_model and source.embedding_dim == memgpt_agent.embedding_dim:
valid_options.append(source.name)
else:
invalid_options.append(source.name)
# print warning about invalid sources
typer.secho(
f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}",
fg=typer.colors.YELLOW,
)
# prompt user for data source selection
data_source = questionary.select("Select data source", choices=valid_options).ask()
# attach new data
attach(memgpt_agent.config.name, data_source)

View File

@@ -7,7 +7,7 @@ from memgpt.utils import get_local_time, printd, count_tokens, validate_date_for
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.llm_api_tools import create
from memgpt.data_types import Message, Passage, AgentState
from memgpt.embeddings import embedding_model
from memgpt.embeddings import embedding_model, query_embedding
from llama_index import Document
from llama_index.node_parser import SimpleNodeParser
@@ -372,6 +372,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
# create embedding model
self.embed_model = embedding_model(agent_state.embedding_config)
self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
assert self.embedding_chunk_size, f"Must set {agent_state.embedding_config.embedding_chunk_size}"
# create storage backend
self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
@@ -384,6 +385,8 @@ class EmbeddingArchivalMemory(ArchivalMemory):
agent_id=self.agent_state.id,
text=text,
embedding=embedding,
embedding_dim=self.agent_state.embedding_config.embedding_dim,
embedding_model=self.agent_state.embedding_config.embedding_model,
)
def save(self):
@@ -432,7 +435,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
try:
if query_string not in self.cache:
# self.cache[query_string] = self.retriever.retrieve(query_string)
query_vec = self.embed_model.get_text_embedding(query_string)
query_vec = query_embedding(self.embed_model, query_string)
self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k)
start = int(start if start else 0)

View File

@@ -64,6 +64,7 @@ class LLMConfigColumn(TypeDecorator):
return value
def process_result_value(self, value, dialect):
print("GET VALUE", value)
if value:
return LLMConfig(**value)
return value
@@ -168,6 +169,8 @@ class SourceModel(Base):
user_id = Column(CommonUUID, nullable=False)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
# TODO: add num passages
@@ -175,7 +178,14 @@ class SourceModel(Base):
return f"<Source(passage_id='{self.id}', name='{self.name}')>"
def to_record(self) -> Source:
return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)
return Source(
id=self.id,
user_id=self.user_id,
name=self.name,
created_at=self.created_at,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
)
class AgentSourceMappingModel(Base):
@@ -288,6 +298,7 @@ class MetadataStore:
@enforce_types
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
print("query agents", user_id)
results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]

View File

@@ -150,7 +150,20 @@ def migrate_source(source_name: str):
for node in nodes:
# print(len(node.embedding))
# TODO: make sure embedding config matches embedding size?
passages.append(Passage(user_id=user.id, data_source=source_name, text=node.text, embedding=node.embedding))
if len(node.embedding) != config.default_embedding_config.embedding_dim:
raise ValueError(
f"Cannot migrate source {source_name} due to incompatible embedding dimentions. Please re-load this source with `memgpt load`."
)
passages.append(
Passage(
user_id=user.id,
data_source=source_name,
text=node.text,
embedding=node.embedding,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
assert len(passages) > 0, f"Source {source_name} has no passages"
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
@@ -313,9 +326,18 @@ def migrate_agent(agent_name: str):
nodes = pickle.load(open(archival_filename, "rb"))
passages = []
for node in nodes:
# print(len(node.embedding))
# TODO: make sure embedding size matches embedding config?
passages.append(Passage(user_id=user.id, agent_id=agent_state.id, text=node.text, embedding=node.embedding))
if len(node.embedding) != config.default_embedding_config.embedding_dim:
raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
passages.append(
Passage(
user_id=user.id,
agent_id=agent_state.id,
text=node.text,
embedding=node.embedding,
embedding_dim=agent_state.embedding_config.embedding_dim,
embedding_model=agent_state.embedding_config.embedding_model,
)
)
if len(passages) > 0:
agent.persistence_manager.archival_memory.storage.insert_many(passages)
print(f"Inserted {len(passages)} passages into archival memory")

View File

@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Union
from enum import Enum
import openai
# import openai
from memgpt.constants import JSON_ENSURE_ASCII
@@ -68,7 +68,8 @@ class ApiType(Enum):
elif label.lower() in ("open_ai", "openai"):
return ApiType.OPEN_AI
else:
raise openai.error.InvalidAPIType(
# raise openai.error.InvalidAPIType(
raise Exception(
"The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
)
@@ -361,7 +362,8 @@ class OpenAIObject(dict):
@property
def typed_api_type(self):
return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)
# return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)
return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(ApiType.OPEN_AI)
# This class overrides __setitem__ to throw exceptions on inputs that it
# doesn't like. This can cause problems when we try to copy an object

View File

@@ -0,0 +1,123 @@
import uuid
import os
from memgpt import MemGPT
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
from memgpt.embeddings import embedding_model
from memgpt.agent_store.storage import StorageConnector, TableType
from .utils import wipe_config
import uuid
test_agent_name = f"test_client_{str(uuid.uuid4())}"
test_agent_state = None
client = None
test_agent_state_post_message = None
test_user_id = uuid.uuid4()
def generate_passages(user, agent):
    """Build test Passage records (with embeddings) for *agent*, owned by *user*.

    Returns (passages, original_embeddings) so callers can compare what was
    stored in the database against the raw, unpadded vectors.
    """
    # Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
    sample_texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
    model = embedding_model(agent.embedding_config)

    raw_vectors = []
    records = []
    for sample in sample_texts:
        vector = model.get_text_embedding(sample)
        raw_vectors.append(list(vector))
        records.append(
            Passage(
                user_id=user.id,
                agent_id=agent.id,
                text=sample,
                embedding=vector,
                embedding_dim=agent.embedding_config.embedding_dim,
                embedding_model=agent.embedding_config.embedding_model,
            )
        )
    return records, raw_vectors
def test_create_user():
wipe_config()
# create client
client = MemGPT(quickstart="openai", user_id=test_user_id)
# create user
user = client.server.create_user({"id": test_user_id})
# openai: create agent
openai_agent = client.create_agent(
{
"user_id": test_user_id,
"name": "openai_agent",
}
)
assert (
openai_agent.embedding_config.embedding_endpoint_type == "openai"
), f"openai_agent.embedding_config.embedding_endpoint_type={openai_agent.embedding_config.embedding_endpoint_type}"
# openai: add passages
passages, openai_embeddings = generate_passages(user, openai_agent)
openai_agent_run = client.server._get_or_load_agent(user_id=user.id, agent_id=openai_agent.id)
openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
# hosted: create agent
hosted_agent = client.create_agent(
{
"user_id": test_user_id,
"name": "hosted_agent",
"embedding_config": EmbeddingConfig(
embedding_endpoint_type="hugging-face",
embedding_model="BAAI/bge-large-en-v1.5",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024,
),
}
)
# check to make sure endpoint overridden
assert (
hosted_agent.embedding_config.embedding_endpoint_type == "hugging-face"
), f"hosted_agent.embedding_config.embedding_endpoint_type={hosted_agent.embedding_config.embedding_endpoint_type}"
# hosted: add passages
passages, hosted_embeddings = generate_passages(user, hosted_agent)
hosted_agent_run = client.server._get_or_load_agent(user_id=user.id, agent_id=hosted_agent.id)
hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
# test passage dimensionality
config = MemGPTConfig.load()
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user.id)
storage.filters = {} # clear filters to be able to get all passages
passages = storage.get_all()
for passage in passages:
if passage.agent_id == hosted_agent.id:
assert (
passage.embedding_dim == hosted_agent.embedding_config.embedding_dim
), f"passage.embedding_dim={passage.embedding_dim} != hosted_agent.embedding_config.embedding_dim={hosted_agent.embedding_config.embedding_dim}"
# ensure was in original embeddings
embedding = passage.embedding[: passage.embedding_dim]
assert embedding in hosted_embeddings, f"embedding={embedding} not in hosted_embeddings={hosted_embeddings}"
# make sure all zeros
assert not any(
passage.embedding[passage.embedding_dim :]
), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}"
elif passage.agent_id == openai_agent.id:
assert (
passage.embedding_dim == openai_agent.embedding_config.embedding_dim
), f"passage.embedding_dim={passage.embedding_dim} != openai_agent.embedding_config.embedding_dim={openai_agent.embedding_config.embedding_dim}"
# ensure was in original embeddings
embedding = passage.embedding[: passage.embedding_dim]
assert embedding in openai_embeddings, f"embedding={embedding} not in openai_embeddings={openai_embeddings}"
# make sure all zeros
assert not any(
passage.embedding[passage.embedding_dim :]
), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}"

View File

@@ -88,13 +88,13 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
persona=user.default_persona,
human=user.default_human,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
embedding_config=embedding_config,
)
ms.delete_user(user.id)
ms.create_user(user)
ms.create_agent(agent)
user = ms.get_user(user.id)
print("Got user:", user, config.default_embedding_config)
print("Got user:", user, embedding_config)
# setup storage connectors
print("Creating storage connectors...")

View File

@@ -104,7 +104,14 @@ def test_server():
for text in archival_memories:
embedding = embed_model.get_text_embedding(text)
agent.persistence_manager.archival_memory.storage.insert(
Passage(user_id=user.id, agent_id=agent_state.id, text=text, embedding=embedding)
Passage(
user_id=user.id,
agent_id=agent_state.id,
text=text,
embedding=embedding,
embedding_dim=agent.agent_state.embedding_config.embedding_dim,
embedding_model=agent.agent_state.embedding_config.embedding_model,
)
)
# add data into recall memory

View File

@@ -4,13 +4,14 @@ import uuid
import pytest
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.embeddings import embedding_model
from memgpt.embeddings import embedding_model, query_embedding
from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore
from memgpt.data_types import User
from memgpt.constants import MAX_EMBEDDING_DIM
from datetime import datetime, timedelta
@@ -33,10 +34,23 @@ def generate_passages(embed_model):
# embeddings: use openai if env is set, otherwise local
passages = []
for text, _, _, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
embedding = None
embedding, embedding_model, embedding_dim = None, None, None
if embed_model:
embedding = embed_model.get_text_embedding(text)
passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id))
embedding_model = "gpt-4"
embedding_dim = len(embedding)
passages.append(
Passage(
user_id=user_id,
text=text,
agent_id=agent_id,
embedding=embedding,
data_source="test_source",
id=id,
embedding_dim=embedding_dim,
embedding_model=embedding_model,
)
)
return passages
@@ -45,11 +59,24 @@ def generate_messages(embed_model):
"""Generate list of 3 Message objects"""
messages = []
for text, date, role, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
embedding = None
embedding, embedding_model, embedding_dim = None, None, None
if embed_model:
embedding = embed_model.get_text_embedding(text)
embedding_model = "gpt-4"
embedding_dim = len(embedding)
messages.append(
Message(user_id=user_id, text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt-4", embedding=embedding)
Message(
user_id=user_id,
text=text,
agent_id=agent_id,
role=role,
created_at=date,
id=id,
model="gpt-4",
embedding=embedding,
embedding_model=embedding_model,
embedding_dim=embedding_dim,
)
)
print(messages[-1].text)
return messages
@@ -165,6 +192,14 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
else:
raise NotImplementedError(f"Table type {table_type} not implemented")
# check record dimensions
print("TABLE TYPE", table_type, type(records[0]), len(records[0].embedding))
if embed_model:
assert len(records[0].embedding) == MAX_EMBEDDING_DIM, f"Expected {MAX_EMBEDDING_DIM}, got {len(records[0].embedding)}"
assert (
records[0].embedding_dim == embedding_config.embedding_dim
), f"Expected {embedding_config.embedding_dim}, got {records[0].embedding_dim}"
# test: insert
conn.insert(records[0])
assert conn.size() == 1, f"Expected 1 record, got {conn.size()}: {conn.get_all()}"
@@ -208,7 +243,7 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
# test: query (vector)
if table_type == TableType.ARCHIVAL_MEMORY:
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
query_vec = query_embedding(embed_model, query)
res = conn.query(None, query_vec, top_k=2)
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
print("Archival memory results", res)

View File

@@ -15,6 +15,9 @@ def wipe_config():
config_path = MemGPTConfig.config_path
# TODO delete file config_path
os.remove(config_path)
assert not MemGPTConfig.exists(), "Config should not exist after deletion"
else:
print("No config to wipe", MemGPTConfig.config_path)
def wipe_memgpt_home():