feat: Store embeddings padded to size 4096 to allow DB storage of varying size embeddings (#852)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import List, Tuple
|
||||
|
||||
from box import Box
|
||||
|
||||
from memgpt.data_types import AgentState, Message
|
||||
from memgpt.data_types import AgentState, Message, EmbeddingConfig
|
||||
from memgpt.models import chat_completion_response
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
|
||||
@@ -897,3 +897,10 @@ class Agent(object):
|
||||
state=updated_state,
|
||||
)
|
||||
return self.agent_state
|
||||
|
||||
def migrate_embedding(self, embedding_config: EmbeddingConfig):
    """Placeholder for moving this agent onto a different embedding configuration.

    Nothing is migrated yet: the call always raises NotImplementedError.
    """
    # TODO: archival memory
    # TODO: recall memory
    raise NotImplementedError
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime
|
||||
from sqlalchemy import func, or_, and_
|
||||
from sqlalchemy import desc, asc
|
||||
@@ -22,6 +21,7 @@ from memgpt.config import MemGPTConfig
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.utils import printd
|
||||
from memgpt.constants import MAX_EMBEDDING_DIM
|
||||
from memgpt.data_types import Record, Message, Passage, ToolCall
|
||||
from memgpt.metadata import MetadataStore
|
||||
|
||||
@@ -109,22 +109,6 @@ def get_db_model(
|
||||
agent_id: Optional[uuid.UUID] = None,
|
||||
dialect="postgresql",
|
||||
):
|
||||
# get embedding dimension info
|
||||
# TODO: Need to remove this and just pass in AgentState/User instead
|
||||
ms = MetadataStore(config)
|
||||
if agent_id and ms.get_agent(agent_id):
|
||||
agent = ms.get_agent(agent_id)
|
||||
embedding_dim = agent.embedding_config.embedding_dim
|
||||
else:
|
||||
user = ms.get_user(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f"User {user_id} not found")
|
||||
embedding_dim = config.default_embedding_config.embedding_dim
|
||||
|
||||
# this cannot be the case if we are making an agent-specific table
|
||||
assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found"
|
||||
assert table_type != TableType.ARCHIVAL_MEMORY, f"Agent {agent_id} not found"
|
||||
|
||||
# Define a helper function to create or get the model class
|
||||
def create_or_get_model(class_name, base_model, table_name):
|
||||
if class_name in globals():
|
||||
@@ -156,7 +140,9 @@ def get_db_model(
|
||||
else:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(embedding_dim))
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
embedding_dim = Column(BIGINT)
|
||||
embedding_model = Column(String)
|
||||
|
||||
metadata_ = Column(MutableJson)
|
||||
|
||||
@@ -167,6 +153,8 @@ def get_db_model(
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
embedding_dim=self.embedding_dim,
|
||||
embedding_model=self.embedding_model,
|
||||
doc_id=self.doc_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
@@ -216,7 +204,9 @@ def get_db_model(
|
||||
else:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(embedding_dim))
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
embedding_dim = Column(BIGINT)
|
||||
embedding_model = Column(String)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
@@ -235,6 +225,8 @@ def get_db_model(
|
||||
tool_calls=self.tool_calls,
|
||||
tool_call_id=self.tool_call_id,
|
||||
embedding=self.embedding,
|
||||
embedding_dim=self.embedding_dim,
|
||||
embedding_model=self.embedding_model,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
@@ -440,6 +432,7 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
for c in self.db_model.__table__.columns:
|
||||
if c.name == "embedding":
|
||||
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
|
||||
|
||||
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
|
||||
|
||||
session_maker = sessionmaker(bind=self.engine)
|
||||
|
||||
@@ -68,6 +68,28 @@ def set_config_with_dict(new_config: dict) -> bool:
|
||||
else:
|
||||
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
|
||||
|
||||
# update embedding config
|
||||
for k, v in vars(old_config.default_embedding_config).items():
|
||||
if k in new_config:
|
||||
if v != new_config[k]:
|
||||
printd(f"Replacing config {k}: {v} -> {new_config[k]}")
|
||||
modified = True
|
||||
# old_config[k] = new_config[k]
|
||||
setattr(old_config.default_embedding_config, k, new_config[k])
|
||||
else:
|
||||
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
|
||||
|
||||
# update llm config
|
||||
for k, v in vars(old_config.default_llm_config).items():
|
||||
if k in new_config:
|
||||
if v != new_config[k]:
|
||||
printd(f"Replacing config {k}: {v} -> {new_config[k]}")
|
||||
modified = True
|
||||
# old_config[k] = new_config[k]
|
||||
setattr(old_config.default_llm_config, k, new_config[k])
|
||||
else:
|
||||
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
|
||||
|
||||
if modified:
|
||||
printd(f"Saving new config file.")
|
||||
old_config.save()
|
||||
@@ -416,11 +438,14 @@ def run(
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
user = ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
print("Creating user", user_id)
|
||||
ms.create_user(User(id=user_id))
|
||||
user = ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("existing user", user, user_id)
|
||||
|
||||
# override with command line arguments
|
||||
if debug:
|
||||
|
||||
@@ -208,11 +208,11 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
other_option_str = "[enter model name manually]"
|
||||
|
||||
# Check if the model we have set already is even in the list (informs our default)
|
||||
valid_model = config.model in hardcoded_model_options
|
||||
valid_model = config.default_llm_config.model in hardcoded_model_options
|
||||
model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):",
|
||||
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
|
||||
default=config.model if valid_model else hardcoded_model_options[0],
|
||||
default=config.default_llm_config.model if valid_model else hardcoded_model_options[0],
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
@@ -409,6 +409,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
embedding_endpoint_type = "openai"
|
||||
embedding_endpoint = "https://api.openai.com/v1"
|
||||
embedding_dim = 1536
|
||||
embedding_model = "text-embedding-ada-002"
|
||||
|
||||
elif embedding_provider == "azure":
|
||||
# check for necessary vars
|
||||
@@ -422,6 +423,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
embedding_endpoint_type = "azure"
|
||||
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
|
||||
embedding_dim = 1536
|
||||
embedding_model = "text-embedding-ada-002"
|
||||
|
||||
elif embedding_provider == "hugging-face":
|
||||
# configure hugging face embedding endpoint (https://github.com/huggingface/text-embeddings-inference)
|
||||
@@ -676,7 +678,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
if arg == ListChoice.agents:
|
||||
"""List all agents"""
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
|
||||
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
|
||||
for agent in tqdm(ms.list_agents(user_id=user_id)):
|
||||
source_ids = ms.list_attached_sources(agent_id=agent.id)
|
||||
source_names = [ms.get_source(source_id=source_id).name for source_id in source_ids]
|
||||
@@ -684,6 +686,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
[
|
||||
agent.name,
|
||||
agent.llm_config.model,
|
||||
agent.embedding_config.embedding_model,
|
||||
agent.embedding_config.embedding_dim,
|
||||
agent.persona,
|
||||
agent.human,
|
||||
",".join(source_names),
|
||||
@@ -715,7 +719,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
|
||||
# create table
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Name", "Created At", "Agents"]
|
||||
table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
|
||||
# TODO: eventually look across all storage connections
|
||||
# TODO: add data source stats
|
||||
# TODO: connect to agents
|
||||
@@ -726,7 +730,9 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
agent_ids = ms.list_attached_agents(source_id=source.id)
|
||||
agent_names = [ms.get_agent(agent_id=agent_id).name for agent_id in agent_ids]
|
||||
|
||||
table.add_row([source.name, utils.format_datetime(source.created_at), ",".join(agent_names)])
|
||||
table.add_row(
|
||||
[source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)]
|
||||
)
|
||||
|
||||
print(table)
|
||||
else:
|
||||
|
||||
@@ -10,9 +10,10 @@ memgpt load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
|
||||
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import typer
|
||||
import uuid
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.embeddings import embedding_model, check_and_split_text
|
||||
from memgpt.agent_store.storage import StorageConnector
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.metadata import MetadataStore
|
||||
@@ -91,16 +92,45 @@ def store_docs(name, docs, user_id=None, show_progress=True):
|
||||
if user_id is None: # assume running local with single user
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
|
||||
# ensure doc text is not too long
|
||||
# TODO: replace this to instead split up docs that are too large
|
||||
# (this is a temporary fix to avoid breaking the llama index)
|
||||
for doc in docs:
|
||||
doc.text = check_and_split_text(doc.text, config.default_embedding_config.embedding_model)[0]
|
||||
|
||||
# record data source metadata
|
||||
ms = MetadataStore(config)
|
||||
user = ms.get_user(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f"Cannot find user {user_id} in metadata store. Please run 'memgpt configure'.")
|
||||
data_source = Source(user_id=user.id, name=name, created_at=datetime.now())
|
||||
if not ms.get_source(user_id=user.id, source_name=name):
|
||||
|
||||
# create data source record
|
||||
data_source = Source(
|
||||
user_id=user.id,
|
||||
name=name,
|
||||
created_at=datetime.now(),
|
||||
embedding_model=config.default_embedding_config.embedding_model,
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
)
|
||||
existing_source = ms.get_source(user_id=user.id, source_name=name)
|
||||
if not existing_source:
|
||||
ms.create_source(data_source)
|
||||
else:
|
||||
print(f"Source {name} for user {user.id} already exists")
|
||||
print(f"Source {name} for user {user.id} already exists.")
|
||||
if existing_source.embedding_model != data_source.embedding_model:
|
||||
print(
|
||||
f"Warning: embedding model for existing source {existing_source.embedding_model} does not match default {data_source.embedding_model}"
|
||||
)
|
||||
print("Cannot import data into this source without a compatible embedding endpoint.")
|
||||
print("Please run 'memgpt configure' to update the default embedding settings.")
|
||||
return False
|
||||
if existing_source.embedding_dim != data_source.embedding_dim:
|
||||
print(
|
||||
f"Warning: embedding dimension for existing source {existing_source.embedding_dim} does not match default {data_source.embedding_dim}"
|
||||
)
|
||||
print("Cannot import data into this source without a compatible embedding endpoint.")
|
||||
print("Please run 'memgpt configure' to update the default embedding settings.")
|
||||
return False
|
||||
|
||||
# compute and record passages
|
||||
embed_model = embedding_model(config.default_embedding_config)
|
||||
@@ -132,6 +162,8 @@ def store_docs(name, docs, user_id=None, show_progress=True):
|
||||
data_source=name,
|
||||
embedding=node.embedding,
|
||||
metadata=None,
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
embedding_model=config.default_embedding_config.embedding_model,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -154,20 +186,31 @@ def load_index(
|
||||
embed_dict = loaded_index._vector_store._data.embedding_dict
|
||||
node_dict = loaded_index._docstore.docs
|
||||
|
||||
passages = []
|
||||
for node_id, node in node_dict.items():
|
||||
vector = embed_dict[node_id]
|
||||
node.embedding = vector
|
||||
passages.append(Passage(text=node.text, embedding=vector))
|
||||
|
||||
if len(passages) == 0:
|
||||
raise ValueError(f"No passages found in index {dir}")
|
||||
|
||||
# create storage connector
|
||||
config = MemGPTConfig.load()
|
||||
if user_id is None:
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
|
||||
passages = []
|
||||
for node_id, node in node_dict.items():
|
||||
vector = embed_dict[node_id]
|
||||
node.embedding = vector
|
||||
# assume embedding are the same as config
|
||||
passages.append(
|
||||
Passage(
|
||||
text=node.text,
|
||||
embedding=np.array(vector),
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
embedding_model=config.default_embedding_config.embedding_model,
|
||||
)
|
||||
)
|
||||
assert config.default_embedding_config.embedding_dim == len(
|
||||
vector
|
||||
), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"
|
||||
|
||||
if len(passages) == 0:
|
||||
raise ValueError(f"No passages found in index {dir}")
|
||||
|
||||
insert_passages_into_source(passages, name, user_id, config)
|
||||
except ValueError as e:
|
||||
typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)
|
||||
@@ -309,7 +352,10 @@ def load_vector_database(
|
||||
# Convert to a list of tuples (text, embedding)
|
||||
passages = []
|
||||
for text, embedding in result:
|
||||
passages.append(Passage(text=text, embedding=embedding))
|
||||
# assume that embeddings are the same model as in config
|
||||
passages.append(
|
||||
Passage(text=text, embedding=embedding, embedding_dim=config.embedding_dim, embedding_model=config.embedding_model)
|
||||
)
|
||||
assert config.embedding_dim == len(embedding), f"Expected embedding dimension {config.embedding_dim}, got {len(embedding)}"
|
||||
|
||||
# create storage connector
|
||||
|
||||
@@ -134,8 +134,11 @@ class MemGPTConfig:
|
||||
"embedding_model": get_field(config, "embedding", "embedding_model"),
|
||||
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
|
||||
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
|
||||
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
|
||||
"embedding_chunk_size": get_field(config, "embedding", "embedding_chunk_size"),
|
||||
}
|
||||
# Remove null values
|
||||
llm_config_dict = {k: v for k, v in llm_config_dict.items() if v is not None}
|
||||
embedding_config_dict = {k: v for k, v in embedding_config_dict.items() if v is not None}
|
||||
# Correct the types that aren't strings
|
||||
if llm_config_dict["context_window"] is not None:
|
||||
llm_config_dict["context_window"] = int(llm_config_dict["context_window"])
|
||||
|
||||
@@ -3,6 +3,16 @@ from logging import CRITICAL, ERROR, WARN, WARNING, INFO, DEBUG, NOTSET
|
||||
|
||||
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")
|
||||
|
||||
# embeddings
|
||||
MAX_EMBEDDING_DIM = 4096 # maximum supported embedding size - do NOT change or else DBs will need to be reset
|
||||
|
||||
# tokenizers
|
||||
EMBEDDING_TO_TOKENIZER_MAP = {
|
||||
"text-embedding-ada-002": "cl100k_base",
|
||||
}
|
||||
EMBEDDING_TO_TOKENIZER_DEFAULT = "cl100k_base"
|
||||
|
||||
|
||||
DEFAULT_MEMGPT_MODEL = "gpt-4"
|
||||
DEFAULT_PERSONA = "sam_pov"
|
||||
DEFAULT_HUMAN = "basic"
|
||||
|
||||
@@ -5,7 +5,7 @@ from abc import abstractmethod
|
||||
from typing import Optional, List, Dict
|
||||
import numpy as np
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
|
||||
from memgpt.utils import get_local_time, format_datetime
|
||||
from memgpt.models import chat_completion_response
|
||||
|
||||
@@ -68,6 +68,8 @@ class Message(Record):
|
||||
tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
|
||||
tool_call_id: Optional[str] = None,
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
embedding_dim: Optional[int] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
id: Optional[uuid.UUID] = None,
|
||||
):
|
||||
super().__init__(id)
|
||||
@@ -82,6 +84,20 @@ class Message(Record):
|
||||
self.role = role # role (agent/user/function)
|
||||
self.name = name
|
||||
|
||||
# pad and store embeddings
|
||||
if isinstance(embedding, list):
|
||||
embedding = np.array(embedding)
|
||||
self.embedding = (
|
||||
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
|
||||
)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
if self.embedding is not None:
|
||||
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
|
||||
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
|
||||
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
|
||||
|
||||
# tool (i.e. function) call info (optional)
|
||||
|
||||
# if role == "assistant", this MAY be specified
|
||||
@@ -97,9 +113,6 @@ class Message(Record):
|
||||
assert tool_call_id is None
|
||||
self.tool_call_id = tool_call_id
|
||||
|
||||
# embedding (optional)
|
||||
self.embedding = embedding
|
||||
|
||||
# def __repr__(self):
|
||||
# pass
|
||||
|
||||
@@ -172,7 +185,7 @@ class Message(Record):
|
||||
if "tool_call_id" in openai_message_dict:
|
||||
assert openai_message_dict["tool_call_id"] is None, openai_message_dict
|
||||
|
||||
if "tool_calls" in openai_message_dict and openai_message_dict["tool_calls"] is not None:
|
||||
if "tool_calls" in openai_message_dict:
|
||||
assert openai_message_dict["role"] == "assistant", openai_message_dict
|
||||
|
||||
tool_calls = [
|
||||
@@ -273,6 +286,8 @@ class Passage(Record):
|
||||
text: str,
|
||||
agent_id: Optional[uuid.UUID] = None, # set if contained in agent memory
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
embedding_dim: Optional[int] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
data_source: Optional[str] = None, # None if created by agent
|
||||
doc_id: Optional[uuid.UUID] = None,
|
||||
id: Optional[uuid.UUID] = None,
|
||||
@@ -283,10 +298,23 @@ class Passage(Record):
|
||||
self.agent_id = agent_id
|
||||
self.text = text
|
||||
self.data_source = data_source
|
||||
self.embedding = embedding
|
||||
self.doc_id = doc_id
|
||||
self.metadata = metadata
|
||||
|
||||
# pad and store embeddings
|
||||
if isinstance(embedding, list):
|
||||
embedding = np.array(embedding)
|
||||
self.embedding = (
|
||||
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
|
||||
)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
if self.embedding is not None:
|
||||
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
|
||||
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
|
||||
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
|
||||
|
||||
assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
|
||||
assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
|
||||
assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
|
||||
@@ -328,6 +356,13 @@ class EmbeddingConfig:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embedding_chunk_size = embedding_chunk_size
|
||||
|
||||
# fields cannot be set to None
|
||||
assert self.embedding_endpoint_type
|
||||
assert self.embedding_endpoint
|
||||
assert self.embedding_model
|
||||
assert self.embedding_dim
|
||||
assert self.embedding_chunk_size
|
||||
|
||||
|
||||
class OpenAIEmbeddingConfig(EmbeddingConfig):
|
||||
def __init__(self, openai_key: Optional[str] = None, **kwargs):
|
||||
@@ -434,6 +469,9 @@ class Source:
|
||||
name: str,
|
||||
created_at: Optional[str] = None,
|
||||
id: Optional[uuid.UUID] = None,
|
||||
# embedding info
|
||||
embedding_model: Optional[str] = None,
|
||||
embedding_dim: Optional[int] = None,
|
||||
):
|
||||
if id is None:
|
||||
self.id = uuid.uuid4()
|
||||
@@ -445,3 +483,7 @@ class Source:
|
||||
self.name = name
|
||||
self.user_id = user_id
|
||||
self.created_at = created_at
|
||||
|
||||
# embedding info (optional)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
@@ -2,10 +2,12 @@ import typer
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from memgpt.utils import is_valid_url
|
||||
from memgpt.utils import is_valid_url, printd
|
||||
from memgpt.data_types import EmbeddingConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.constants import MAX_EMBEDDING_DIM, EMBEDDING_TO_TOKENIZER_MAP, EMBEDDING_TO_TOKENIZER_DEFAULT
|
||||
|
||||
from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding
|
||||
from llama_index.bridge.pydantic import PrivateAttr
|
||||
@@ -14,6 +16,34 @@ from llama_index.embeddings.huggingface_utils import format_text
|
||||
import tiktoken
|
||||
|
||||
|
||||
def check_and_split_text(text: str, embedding_model: str) -> List[str]:
    """Return `text` as a one-element list, truncated to the tokenizer's token limit.

    The tokenizer is looked up in EMBEDDING_TO_TOKENIZER_MAP by `embedding_model`;
    models that are not in the map fall back to EMBEDDING_TO_TOKENIZER_DEFAULT and a
    warning is printed. Text that exceeds the limit is currently truncated rather
    than split, so the returned list always holds exactly one string.

    Args:
        text: The text to check.
        embedding_model: Name of the embedding model, used to select the tokenizer.

    Returns:
        A list containing the (possibly truncated) text as its only element.
    """

    if embedding_model in EMBEDDING_TO_TOKENIZER_MAP:
        encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_MAP[embedding_model])
    else:
        print(f"Warning: couldn't find tokenizer for model {embedding_model}, using default tokenizer {EMBEDDING_TO_TOKENIZER_DEFAULT}")
        encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_DEFAULT)

    # tokenize once; the token list is reused for truncation below
    tokens = encoding.encode(text)
    num_tokens = len(tokens)

    # determine max length
    if hasattr(encoding, "max_length"):
        max_length = encoding.max_length
    else:
        # TODO: figure out the real number
        printd(f"Warning: couldn't find max_length for tokenizer {embedding_model}, using default max_length 8191")
        max_length = 8191

    # truncate text if too long
    if num_tokens > max_length:
        # TODO: split this into two pieces of text instead of truncating
        print(f"Warning: text is too long ({num_tokens} tokens), truncating to {max_length} tokens.")
        # Cut in token units with the same tokenizer that produced the count, then
        # decode back to text, so the limit is applied in the units it is measured in.
        text = encoding.decode(tokens[:max_length])

    return [text]
|
||||
|
||||
|
||||
class EmbeddingEndpoint(BaseEmbedding):
|
||||
|
||||
"""Implementation for OpenAI compatible endpoint"""
|
||||
@@ -38,7 +68,6 @@ class EmbeddingEndpoint(BaseEmbedding):
|
||||
self._user = user
|
||||
self._base_url = base_url
|
||||
self._timeout = timeout
|
||||
self._encoding = tiktoken.get_encoding(model)
|
||||
super().__init__(
|
||||
model_name=model,
|
||||
)
|
||||
@@ -47,10 +76,6 @@ class EmbeddingEndpoint(BaseEmbedding):
|
||||
def class_name(cls) -> str:
|
||||
return "EmbeddingEndpoint"
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""Count tokens using the embedding model's tokenizer"""
|
||||
return len(self._encoding.encode(text))
|
||||
|
||||
def _call_api(self, text: str) -> List[float]:
|
||||
if not is_valid_url(self._base_url):
|
||||
raise ValueError(
|
||||
@@ -58,12 +83,6 @@ class EmbeddingEndpoint(BaseEmbedding):
|
||||
)
|
||||
import httpx
|
||||
|
||||
# If necessary, truncate text to fit in the embedding model's max sequence length (usually 512)
|
||||
num_tokens = self.count_tokens(text)
|
||||
max_length = self._encoding.max_length
|
||||
if num_tokens > max_length:
|
||||
text = format_text(text, self.model_name, max_length=max_length)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
json_data = {"input": text, "model": self.model_name, "user": self._user}
|
||||
|
||||
@@ -157,6 +176,14 @@ def default_embedding_model():
|
||||
return HuggingFaceEmbedding(model_name=model)
|
||||
|
||||
|
||||
def query_embedding(embedding_model, query_text: str):
    """Embed `query_text` and zero-pad the vector to MAX_EMBEDDING_DIM entries.

    The embedding is obtained from `embedding_model.get_text_embedding`, padded on
    the right with zeros, and returned as a plain Python list.
    """
    raw_vec = np.array(embedding_model.get_text_embedding(query_text))
    pad_width = MAX_EMBEDDING_DIM - raw_vec.shape[0]
    return np.pad(raw_vec, (0, pad_width), mode="constant").tolist()
|
||||
|
||||
|
||||
def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None):
|
||||
"""Return LlamaIndex embedding model to use for embeddings"""
|
||||
|
||||
@@ -181,11 +208,6 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
api_version=credentials.azure_version,
|
||||
)
|
||||
elif endpoint_type == "hugging-face":
|
||||
try:
|
||||
return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
|
||||
except Exception as e:
|
||||
# TODO: remove, this is just to get passing tests
|
||||
print(e)
|
||||
return default_embedding_model()
|
||||
return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
|
||||
else:
|
||||
return default_embedding_model()
|
||||
|
||||
@@ -111,6 +111,9 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
elif user_input.lower() == "/attach":
|
||||
# TODO: check if agent already has it
|
||||
|
||||
# TODO: check to ensure source embedding dimensions/model match the agent's, and disallow attachment if not
|
||||
# TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources
|
||||
|
||||
data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id)
|
||||
data_source_options = [s.name for s in data_source_options]
|
||||
if len(data_source_options) == 0:
|
||||
@@ -120,7 +123,24 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
bold=True,
|
||||
)
|
||||
continue
|
||||
data_source = questionary.select("Select data source", choices=data_source_options).ask()
|
||||
|
||||
# determine what sources are valid to be attached to this agent
|
||||
valid_options = []
|
||||
invalid_options = []
|
||||
for source in data_source_options:
|
||||
if source.embedding_model == memgpt_agent.embedding_model and source.embedding_dim == memgpt_agent.embedding_dim:
|
||||
valid_options.append(source.name)
|
||||
else:
|
||||
invalid_options.append(source.name)
|
||||
|
||||
# print warning about invalid sources
|
||||
typer.secho(
|
||||
f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}",
|
||||
fg=typer.colors.YELLOW,
|
||||
)
|
||||
|
||||
# prompt user for data source selection
|
||||
data_source = questionary.select("Select data source", choices=valid_options).ask()
|
||||
|
||||
# attach new data
|
||||
attach(memgpt_agent.config.name, data_source)
|
||||
|
||||
@@ -7,7 +7,7 @@ from memgpt.utils import get_local_time, printd, count_tokens, validate_date_for
|
||||
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from memgpt.llm_api_tools import create
|
||||
from memgpt.data_types import Message, Passage, AgentState
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.embeddings import embedding_model, query_embedding
|
||||
from llama_index import Document
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
|
||||
@@ -372,6 +372,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
# create embedding model
|
||||
self.embed_model = embedding_model(agent_state.embedding_config)
|
||||
self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
assert self.embedding_chunk_size, f"Must set {agent_state.embedding_config.embedding_chunk_size}"
|
||||
|
||||
# create storage backend
|
||||
self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
|
||||
@@ -384,6 +385,8 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
agent_id=self.agent_state.id,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
embedding_dim=self.agent_state.embedding_config.embedding_dim,
|
||||
embedding_model=self.agent_state.embedding_config.embedding_model,
|
||||
)
|
||||
|
||||
def save(self):
|
||||
@@ -432,7 +435,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
try:
|
||||
if query_string not in self.cache:
|
||||
# self.cache[query_string] = self.retriever.retrieve(query_string)
|
||||
query_vec = self.embed_model.get_text_embedding(query_string)
|
||||
query_vec = query_embedding(self.embed_model, query_string)
|
||||
self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k)
|
||||
|
||||
start = int(start if start else 0)
|
||||
|
||||
@@ -64,6 +64,7 @@ class LLMConfigColumn(TypeDecorator):
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
print("GET VALUE", value)
|
||||
if value:
|
||||
return LLMConfig(**value)
|
||||
return value
|
||||
@@ -168,6 +169,8 @@ class SourceModel(Base):
|
||||
user_id = Column(CommonUUID, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
embedding_dim = Column(BIGINT)
|
||||
embedding_model = Column(String)
|
||||
|
||||
# TODO: add num passages
|
||||
|
||||
@@ -175,7 +178,14 @@ class SourceModel(Base):
|
||||
return f"<Source(passage_id='{self.id}', name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> Source:
|
||||
return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)
|
||||
return Source(
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
name=self.name,
|
||||
created_at=self.created_at,
|
||||
embedding_dim=self.embedding_dim,
|
||||
embedding_model=self.embedding_model,
|
||||
)
|
||||
|
||||
|
||||
class AgentSourceMappingModel(Base):
|
||||
@@ -288,6 +298,7 @@ class MetadataStore:
|
||||
|
||||
@enforce_types
|
||||
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
|
||||
print("query agents", user_id)
|
||||
results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
|
||||
@@ -150,7 +150,20 @@ def migrate_source(source_name: str):
|
||||
for node in nodes:
|
||||
# print(len(node.embedding))
|
||||
# TODO: make sure embedding config matches embedding size?
|
||||
passages.append(Passage(user_id=user.id, data_source=source_name, text=node.text, embedding=node.embedding))
|
||||
if len(node.embedding) != config.default_embedding_config.embedding_dim:
|
||||
raise ValueError(
|
||||
f"Cannot migrate source {source_name} due to incompatible embedding dimentions. Please re-load this source with `memgpt load`."
|
||||
)
|
||||
passages.append(
|
||||
Passage(
|
||||
user_id=user.id,
|
||||
data_source=source_name,
|
||||
text=node.text,
|
||||
embedding=node.embedding,
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
embedding_model=config.default_embedding_config.embedding_model,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(passages) > 0, f"Source {source_name} has no passages"
|
||||
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
|
||||
@@ -313,9 +326,18 @@ def migrate_agent(agent_name: str):
|
||||
nodes = pickle.load(open(archival_filename, "rb"))
|
||||
passages = []
|
||||
for node in nodes:
|
||||
# print(len(node.embedding))
|
||||
# TODO: make sure embeding size matches embedding config?
|
||||
passages.append(Passage(user_id=user.id, agent_id=agent_state.id, text=node.text, embedding=node.embedding))
|
||||
if len(node.embedding) != config.default_embedding_config.embedding_dim:
|
||||
raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
|
||||
passages.append(
|
||||
Passage(
|
||||
user_id=user.id,
|
||||
agent_id=agent_state.id,
|
||||
text=node.text,
|
||||
embedding=node.embedding,
|
||||
embedding_dim=agent_state.embedding_config.embedding_dim,
|
||||
embedding_model=agent_state.embedding_config.embedding_model,
|
||||
)
|
||||
)
|
||||
if len(passages) > 0:
|
||||
agent.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
print(f"Inserted {len(passages)} passages into archival memory")
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
# import openai
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
|
||||
@@ -68,7 +68,8 @@ class ApiType(Enum):
|
||||
elif label.lower() in ("open_ai", "openai"):
|
||||
return ApiType.OPEN_AI
|
||||
else:
|
||||
raise openai.error.InvalidAPIType(
|
||||
# raise openai.error.InvalidAPIType(
|
||||
raise Exception(
|
||||
"The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
|
||||
)
|
||||
|
||||
@@ -361,7 +362,8 @@ class OpenAIObject(dict):
|
||||
|
||||
@property
def typed_api_type(self):
    """Return this object's API type as an ``ApiType`` member.

    Returns:
        ``ApiType.from_str(self.api_type)`` when ``self.api_type`` is set
        (truthy); otherwise the default ``ApiType.OPEN_AI``.
    """
    if self.api_type:
        return ApiType.from_str(self.api_type)
    # Return the enum member directly. ``ApiType.from_str`` lowercases its
    # argument as a string label, so handing it an ``ApiType`` member (which
    # has no ``.lower()``) would raise instead of yielding the default. The
    # module-level ``openai.api_type`` is no longer consulted because the
    # ``openai`` import is not available in this module.
    return ApiType.OPEN_AI
|
||||
|
||||
# This class overrides __setitem__ to throw exceptions on inputs that it
|
||||
# doesn't like. This can cause problems when we try to copy an object
|
||||
|
||||
123
tests/test_different_embedding_size.py
Normal file
123
tests/test_different_embedding_size.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from memgpt import MemGPT
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from .utils import wipe_config
|
||||
import uuid
|
||||
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
test_agent_state = None
|
||||
client = None
|
||||
|
||||
test_agent_state_post_message = None
|
||||
test_user_id = uuid.uuid4()
|
||||
|
||||
|
||||
def generate_passages(user, agent):
    """Embed three fixed sample texts using the agent's embedding config.

    Args:
        user: Object whose ``id`` is stamped on every passage as ``user_id``.
        agent: Object providing ``id`` and ``embedding_config``.

    Returns:
        A ``(passages, orig_embeddings)`` tuple: the ``Passage`` objects and,
        in the same order, a plain-``list`` copy of each embedding as it was
        returned by the embedding model.
    """
    # Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
    emb_cfg = agent.embedding_config
    embedder = embedding_model(emb_cfg)

    built = []
    raw_vectors = []
    for sample in ("This is a test passage", "This is another test passage", "Cinderella wept"):
        vector = embedder.get_text_embedding(sample)
        # keep an independent list copy so callers can compare against it later
        raw_vectors.append(list(vector))
        record = Passage(
            user_id=user.id,
            agent_id=agent.id,
            text=sample,
            embedding=vector,
            embedding_dim=emb_cfg.embedding_dim,
            embedding_model=emb_cfg.embedding_model,
        )
        built.append(record)

    return built, raw_vectors
|
||||
|
||||
|
||||
def _check_stored_passage(passage, agent, orig_embeddings, name):
    """Assert one stored passage matches its agent's embedding settings.

    Checks that ``passage.embedding_dim`` equals the agent's configured
    dimension, that the first ``embedding_dim`` values of the stored vector
    are one of ``orig_embeddings``, and that every value after that prefix
    is zero (falsy). ``name`` ("hosted" / "openai") only labels the messages.
    """
    dim = passage.embedding_dim
    expected_dim = agent.embedding_config.embedding_dim
    assert (
        dim == expected_dim
    ), f"passage.embedding_dim={dim} != {name}_agent.embedding_config.embedding_dim={expected_dim}"

    # ensure was in original embeddings
    embedding = passage.embedding[:dim]
    assert embedding in orig_embeddings, f"embedding={embedding} not in {name}_embeddings={orig_embeddings}"

    # make sure all zeros
    padding = passage.embedding[dim:]
    assert not any(padding), f"passage.embedding[passage.embedding_dim:]={padding}"


def test_create_user():
    """Store passages from two agents with different embedding sizes and verify them.

    Creates one agent with the default ("openai") embedding config and one
    with a 1024-dimensional hugging-face config, inserts passages for both,
    then reads every passage back and checks its dimension, its original
    values, and that the remainder of the stored vector is zero.
    """
    wipe_config()

    # create client
    client = MemGPT(quickstart="openai", user_id=test_user_id)

    # create user
    user = client.server.create_user({"id": test_user_id})

    # openai: create agent
    openai_agent = client.create_agent(
        {
            "user_id": test_user_id,
            "name": "openai_agent",
        }
    )
    assert (
        openai_agent.embedding_config.embedding_endpoint_type == "openai"
    ), f"openai_agent.embedding_config.embedding_endpoint_type={openai_agent.embedding_config.embedding_endpoint_type}"

    # openai: add passages
    passages, openai_embeddings = generate_passages(user, openai_agent)
    openai_agent_run = client.server._get_or_load_agent(user_id=user.id, agent_id=openai_agent.id)
    openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)

    # hosted: create agent
    hosted_agent = client.create_agent(
        {
            "user_id": test_user_id,
            "name": "hosted_agent",
            "embedding_config": EmbeddingConfig(
                embedding_endpoint_type="hugging-face",
                embedding_model="BAAI/bge-large-en-v1.5",
                embedding_endpoint="https://embeddings.memgpt.ai",
                embedding_dim=1024,
            ),
        }
    )
    # check to make sure endpoint overridden
    assert (
        hosted_agent.embedding_config.embedding_endpoint_type == "hugging-face"
    ), f"hosted_agent.embedding_config.embedding_endpoint_type={hosted_agent.embedding_config.embedding_endpoint_type}"

    # hosted: add passages
    passages, hosted_embeddings = generate_passages(user, hosted_agent)
    hosted_agent_run = client.server._get_or_load_agent(user_id=user.id, agent_id=hosted_agent.id)
    hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)

    # test passage dimensionality
    config = MemGPTConfig.load()
    storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user.id)
    storage.filters = {}  # clear filters to be able to get all passages

    # Count what was verified per agent so the test cannot pass vacuously
    # when nothing is read back from storage.
    checked = {"hosted": 0, "openai": 0}
    for passage in storage.get_all():
        if passage.agent_id == hosted_agent.id:
            _check_stored_passage(passage, hosted_agent, hosted_embeddings, "hosted")
            checked["hosted"] += 1
        elif passage.agent_id == openai_agent.id:
            _check_stored_passage(passage, openai_agent, openai_embeddings, "openai")
            checked["openai"] += 1

    for name, count in checked.items():
        assert count > 0, f"no stored passages were found for {name}_agent"
|
||||
@@ -88,13 +88,13 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
|
||||
persona=user.default_persona,
|
||||
human=user.default_human,
|
||||
llm_config=config.default_llm_config,
|
||||
embedding_config=config.default_embedding_config,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
ms.delete_user(user.id)
|
||||
ms.create_user(user)
|
||||
ms.create_agent(agent)
|
||||
user = ms.get_user(user.id)
|
||||
print("Got user:", user, config.default_embedding_config)
|
||||
print("Got user:", user, embedding_config)
|
||||
|
||||
# setup storage connectors
|
||||
print("Creating storage connectors...")
|
||||
|
||||
@@ -104,7 +104,14 @@ def test_server():
|
||||
for text in archival_memories:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
agent.persistence_manager.archival_memory.storage.insert(
|
||||
Passage(user_id=user.id, agent_id=agent_state.id, text=text, embedding=embedding)
|
||||
Passage(
|
||||
user_id=user.id,
|
||||
agent_id=agent_state.id,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
embedding_dim=agent.agent_state.embedding_config.embedding_dim,
|
||||
embedding_model=agent.agent_state.embedding_config.embedding_model,
|
||||
)
|
||||
)
|
||||
|
||||
# add data into recall memory
|
||||
|
||||
@@ -4,13 +4,14 @@ import uuid
|
||||
import pytest
|
||||
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.embeddings import embedding_model, query_embedding
|
||||
from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.data_types import User
|
||||
from memgpt.constants import MAX_EMBEDDING_DIM
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@@ -33,10 +34,23 @@ def generate_passages(embed_model):
|
||||
# embeddings: use openai if env is set, otherwise local
|
||||
passages = []
|
||||
for text, _, _, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
|
||||
embedding = None
|
||||
embedding, embedding_model, embedding_dim = None, None, None
|
||||
if embed_model:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id))
|
||||
embedding_model = "gpt-4"
|
||||
embedding_dim = len(embedding)
|
||||
passages.append(
|
||||
Passage(
|
||||
user_id=user_id,
|
||||
text=text,
|
||||
agent_id=agent_id,
|
||||
embedding=embedding,
|
||||
data_source="test_source",
|
||||
id=id,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
)
|
||||
return passages
|
||||
|
||||
|
||||
@@ -45,11 +59,24 @@ def generate_messages(embed_model):
|
||||
"""Generate list of 3 Message objects"""
|
||||
messages = []
|
||||
for text, date, role, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
|
||||
embedding = None
|
||||
embedding, embedding_model, embedding_dim = None, None, None
|
||||
if embed_model:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
embedding_model = "gpt-4"
|
||||
embedding_dim = len(embedding)
|
||||
messages.append(
|
||||
Message(user_id=user_id, text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt-4", embedding=embedding)
|
||||
Message(
|
||||
user_id=user_id,
|
||||
text=text,
|
||||
agent_id=agent_id,
|
||||
role=role,
|
||||
created_at=date,
|
||||
id=id,
|
||||
model="gpt-4",
|
||||
embedding=embedding,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dim=embedding_dim,
|
||||
)
|
||||
)
|
||||
print(messages[-1].text)
|
||||
return messages
|
||||
@@ -165,6 +192,14 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
|
||||
else:
|
||||
raise NotImplementedError(f"Table type {table_type} not implemented")
|
||||
|
||||
# check record dimensions
|
||||
print("TABLE TYPE", table_type, type(records[0]), len(records[0].embedding))
|
||||
if embed_model:
|
||||
assert len(records[0].embedding) == MAX_EMBEDDING_DIM, f"Expected {MAX_EMBEDDING_DIM}, got {len(records[0].embedding)}"
|
||||
assert (
|
||||
records[0].embedding_dim == embedding_config.embedding_dim
|
||||
), f"Expected {embedding_config.embedding_dim}, got {records[0].embedding_dim}"
|
||||
|
||||
# test: insert
|
||||
conn.insert(records[0])
|
||||
assert conn.size() == 1, f"Expected 1 record, got {conn.size()}: {conn.get_all()}"
|
||||
@@ -208,7 +243,7 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
|
||||
# test: query (vector)
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
query_vec = query_embedding(embed_model, query)
|
||||
res = conn.query(None, query_vec, top_k=2)
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
print("Archival memory results", res)
|
||||
|
||||
@@ -15,6 +15,9 @@ def wipe_config():
|
||||
config_path = MemGPTConfig.config_path
|
||||
# TODO delete file config_path
|
||||
os.remove(config_path)
|
||||
assert not MemGPTConfig.exists(), "Config should not exist after deletion"
|
||||
else:
|
||||
print("No config to wipe", MemGPTConfig.config_path)
|
||||
|
||||
|
||||
def wipe_memgpt_home():
|
||||
|
||||
Reference in New Issue
Block a user