diff --git a/memgpt/agent.py b/memgpt/agent.py index ca5d379a..59bf9cbe 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -10,7 +10,7 @@ from typing import List, Tuple from box import Box -from memgpt.data_types import AgentState, Message +from memgpt.data_types import AgentState, Message, EmbeddingConfig from memgpt.models import chat_completion_response from memgpt.interface import AgentInterface from memgpt.persistence_manager import PersistenceManager, LocalStateManager @@ -897,3 +897,10 @@ class Agent(object): state=updated_state, ) return self.agent_state + + def migrate_embedding(self, embedding_config: EmbeddingConfig): + """Migrate the agent to a new embedding""" + # TODO: archival memory + + # TODO: recall memory + raise NotImplementedError() diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index e5a2795d..ba60e2c2 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -1,5 +1,4 @@ import os - from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime from sqlalchemy import func, or_, and_ from sqlalchemy import desc, asc @@ -22,6 +21,7 @@ from memgpt.config import MemGPTConfig from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.config import MemGPTConfig from memgpt.utils import printd +from memgpt.constants import MAX_EMBEDDING_DIM from memgpt.data_types import Record, Message, Passage, ToolCall from memgpt.metadata import MetadataStore @@ -109,22 +109,6 @@ def get_db_model( agent_id: Optional[uuid.UUID] = None, dialect="postgresql", ): - # get embedding dimention info - # TODO: Need to remove this and just pass in AgentState/User instead - ms = MetadataStore(config) - if agent_id and ms.get_agent(agent_id): - agent = ms.get_agent(agent_id) - embedding_dim = agent.embedding_config.embedding_dim - else: - user = ms.get_user(user_id) - if user is None: - raise ValueError(f"User {user_id} not found") - embedding_dim = 
config.default_embedding_config.embedding_dim - - # this cannot be the case if we are making an agent-specific table - assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found" - assert table_type != TableType.ARCHIVAL_MEMORY, f"Agent {agent_id} not found" - # Define a helper function to create or get the model class def create_or_get_model(class_name, base_model, table_name): if class_name in globals(): @@ -156,7 +140,9 @@ def get_db_model( else: from pgvector.sqlalchemy import Vector - embedding = mapped_column(Vector(embedding_dim)) + embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) + embedding_dim = Column(BIGINT) + embedding_model = Column(String) metadata_ = Column(MutableJson) @@ -167,6 +153,8 @@ def get_db_model( return Passage( text=self.text, embedding=self.embedding, + embedding_dim=self.embedding_dim, + embedding_model=self.embedding_model, doc_id=self.doc_id, user_id=self.user_id, id=self.id, @@ -216,7 +204,9 @@ def get_db_model( else: from pgvector.sqlalchemy import Vector - embedding = mapped_column(Vector(embedding_dim)) + embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) + embedding_dim = Column(BIGINT) + embedding_model = Column(String) # Add a datetime column, with default value as the current time created_at = Column(DateTime(timezone=True), server_default=func.now()) @@ -235,6 +225,8 @@ def get_db_model( tool_calls=self.tool_calls, tool_call_id=self.tool_call_id, embedding=self.embedding, + embedding_dim=self.embedding_dim, + embedding_model=self.embedding_model, created_at=self.created_at, id=self.id, ) @@ -440,6 +432,7 @@ class PostgresStorageConnector(SQLStorageConnector): for c in self.db_model.__table__.columns: if c.name == "embedding": assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}" + Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist session_maker = sessionmaker(bind=self.engine) diff --git a/memgpt/cli/cli.py 
b/memgpt/cli/cli.py index 0045cd43..4e4a06ce 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -68,6 +68,28 @@ def set_config_with_dict(new_config: dict) -> bool: else: printd(f"Skipping new config {k}: {v} == {new_config[k]}") + # update embedding config + for k, v in vars(old_config.default_embedding_config).items(): + if k in new_config: + if v != new_config[k]: + printd(f"Replacing config {k}: {v} -> {new_config[k]}") + modified = True + # old_config[k] = new_config[k] + setattr(old_config.default_embedding_config, k, new_config[k]) + else: + printd(f"Skipping new config {k}: {v} == {new_config[k]}") + + # update llm config + for k, v in vars(old_config.default_llm_config).items(): + if k in new_config: + if v != new_config[k]: + printd(f"Replacing config {k}: {v} -> {new_config[k]}") + modified = True + # old_config[k] = new_config[k] + setattr(old_config.default_llm_config, k, new_config[k]) + else: + printd(f"Skipping new config {k}: {v} == {new_config[k]}") + if modified: printd(f"Saving new config file.") old_config.save() @@ -416,11 +438,14 @@ def run( user_id = uuid.UUID(config.anon_clientid) user = ms.get_user(user_id=user_id) if user is None: + print("Creating user", user_id) ms.create_user(User(id=user_id)) user = ms.get_user(user_id=user_id) if user is None: typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED) sys.exit(1) + else: + print("existing user", user, user_id) # override with command line arguments if debug: diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 85a7f7f3..b68e10c9 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -208,11 +208,11 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ other_option_str = "[enter model name manually]" # Check if the model we have set already is even in the list (informs our default) - valid_model = config.model in hardcoded_model_options + valid_model = config.default_llm_config.model in 
hardcoded_model_options model = questionary.select( "Select default model (recommended: gpt-4):", choices=hardcoded_model_options + [see_all_option_str, other_option_str], - default=config.model if valid_model else hardcoded_model_options[0], + default=config.default_llm_config.model if valid_model else hardcoded_model_options[0], ).ask() if model is None: raise KeyboardInterrupt @@ -409,6 +409,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden embedding_endpoint_type = "openai" embedding_endpoint = "https://api.openai.com/v1" embedding_dim = 1536 + embedding_model = "text-embedding-ada-002" elif embedding_provider == "azure": # check for necessary vars @@ -422,6 +423,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden embedding_endpoint_type = "azure" embedding_endpoint = azure_creds["azure_embedding_endpoint"] embedding_dim = 1536 + embedding_model = "text-embedding-ada-002" elif embedding_provider == "hugging-face": # configure hugging face embedding endpoint (https://github.com/huggingface/text-embeddings-inference) @@ -676,7 +678,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]): if arg == ListChoice.agents: """List all agents""" table = PrettyTable() - table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"] + table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"] for agent in tqdm(ms.list_agents(user_id=user_id)): source_ids = ms.list_attached_sources(agent_id=agent.id) source_names = [ms.get_source(source_id=source_id).name for source_id in source_ids] @@ -684,6 +686,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]): [ agent.name, agent.llm_config.model, + agent.embedding_config.embedding_model, + agent.embedding_config.embedding_dim, agent.persona, agent.human, ",".join(source_names), @@ -715,7 +719,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]): # create 
table table = PrettyTable() - table.field_names = ["Name", "Created At", "Agents"] + table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"] # TODO: eventually look accross all storage connections # TODO: add data source stats # TODO: connect to agents @@ -726,7 +730,9 @@ def list(arg: Annotated[ListChoice, typer.Argument]): agent_ids = ms.list_attached_agents(source_id=source.id) agent_names = [ms.get_agent(agent_id=agent_id).name for agent_id in agent_ids] - table.add_row([source.name, utils.format_datetime(source.created_at), ",".join(agent_names)]) + table.add_row( + [source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)] + ) print(table) else: diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 302ba9ee..a473422e 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -10,9 +10,10 @@ memgpt load --name [ADDITIONAL ARGS] from typing import List from tqdm import tqdm +import numpy as np import typer import uuid -from memgpt.embeddings import embedding_model +from memgpt.embeddings import embedding_model, check_and_split_text from memgpt.agent_store.storage import StorageConnector from memgpt.config import MemGPTConfig from memgpt.metadata import MetadataStore @@ -91,16 +92,45 @@ def store_docs(name, docs, user_id=None, show_progress=True): if user_id is None: # assume running local with single user user_id = uuid.UUID(config.anon_clientid) + # ensure doc text is not too long + # TODO: replace this to instead split up docs that are too large + # (this is a temporary fix to avoid breaking the llama index) + for doc in docs: + doc.text = check_and_split_text(doc.text, config.default_embedding_config.embedding_model)[0] + # record data source metadata ms = MetadataStore(config) user = ms.get_user(user_id) if user is None: raise ValueError(f"Cannot find user {user_id} in metadata store. 
Please run 'memgpt configure'.") - data_source = Source(user_id=user.id, name=name, created_at=datetime.now()) - if not ms.get_source(user_id=user.id, source_name=name): + + # create data source record + data_source = Source( + user_id=user.id, + name=name, + created_at=datetime.now(), + embedding_model=config.default_embedding_config.embedding_model, + embedding_dim=config.default_embedding_config.embedding_dim, + ) + existing_source = ms.get_source(user_id=user.id, source_name=name) + if not existing_source: ms.create_source(data_source) else: - print(f"Source {name} for user {user.id} already exists") + print(f"Source {name} for user {user.id} already exists.") + if existing_source.embedding_model != data_source.embedding_model: + print( + f"Warning: embedding model for existing source {existing_source.embedding_model} does not match default {data_source.embedding_model}" + ) + print("Cannot import data into this source without a compatible embedding endpoint.") + print("Please run 'memgpt configure' to update the default embedding settings.") + return False + if existing_source.embedding_dim != data_source.embedding_dim: + print( + f"Warning: embedding dimension for existing source {existing_source.embedding_dim} does not match default {data_source.embedding_dim}" + ) + print("Cannot import data into this source without a compatible embedding endpoint.") + print("Please run 'memgpt configure' to update the default embedding settings.") + return False # compute and record passages embed_model = embedding_model(config.default_embedding_config) @@ -132,6 +162,8 @@ def store_docs(name, docs, user_id=None, show_progress=True): data_source=name, embedding=node.embedding, metadata=None, + embedding_dim=config.default_embedding_config.embedding_dim, + embedding_model=config.default_embedding_config.embedding_model, ) ) @@ -154,20 +186,31 @@ def load_index( embed_dict = loaded_index._vector_store._data.embedding_dict node_dict = loaded_index._docstore.docs - passages = 
[] - for node_id, node in node_dict.items(): - vector = embed_dict[node_id] - node.embedding = vector - passages.append(Passage(text=node.text, embedding=vector)) - - if len(passages) == 0: - raise ValueError(f"No passages found in index {dir}") - # create storage connector config = MemGPTConfig.load() if user_id is None: user_id = uuid.UUID(config.anon_clientid) + passages = [] + for node_id, node in node_dict.items(): + vector = embed_dict[node_id] + node.embedding = vector + # assume embedding are the same as config + passages.append( + Passage( + text=node.text, + embedding=np.array(vector), + embedding_dim=config.default_embedding_config.embedding_dim, + embedding_model=config.default_embedding_config.embedding_model, + ) + ) + assert config.default_embedding_config.embedding_dim == len( + vector + ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}" + + if len(passages) == 0: + raise ValueError(f"No passages found in index {dir}") + insert_passages_into_source(passages, name, user_id, config) except ValueError as e: typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED) @@ -309,7 +352,10 @@ def load_vector_database( # Convert to a list of tuples (text, embedding) passages = [] for text, embedding in result: - passages.append(Passage(text=text, embedding=embedding)) + # assume that embeddings are the same model as in config + passages.append( + Passage(text=text, embedding=embedding, embedding_dim=config.embedding_dim, embedding_model=config.embedding_model) + ) assert config.embedding_dim == len(embedding), f"Expected embedding dimension {config.embedding_dim}, got {len(embedding)}" # create storage connector diff --git a/memgpt/config.py b/memgpt/config.py index a4c24c44..6972c609 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -134,8 +134,11 @@ class MemGPTConfig: "embedding_model": get_field(config, "embedding", "embedding_model"), "embedding_endpoint_type": 
get_field(config, "embedding", "embedding_endpoint_type"), "embedding_dim": get_field(config, "embedding", "embedding_dim"), - "embedding_chunk_size": get_field(config, "embedding", "chunk_size"), + "embedding_chunk_size": get_field(config, "embedding", "embedding_chunk_size"), } + # Remove null values + llm_config_dict = {k: v for k, v in llm_config_dict.items() if v is not None} + embedding_config_dict = {k: v for k, v in embedding_config_dict.items() if v is not None} # Correct the types that aren't strings if llm_config_dict["context_window"] is not None: llm_config_dict["context_window"] = int(llm_config_dict["context_window"]) diff --git a/memgpt/constants.py b/memgpt/constants.py index f313b105..b16bc40f 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -3,6 +3,16 @@ from logging import CRITICAL, ERROR, WARN, WARNING, INFO, DEBUG, NOTSET MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt") +# embeddings +MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset + +# tokenizers +EMBEDDING_TO_TOKENIZER_MAP = { + "text-embedding-ada-002": "cl100k_base", +} +EMBEDDING_TO_TOKENIZER_DEFAULT = "cl100k_base" + + DEFAULT_MEMGPT_MODEL = "gpt-4" DEFAULT_PERSONA = "sam_pov" DEFAULT_HUMAN = "basic" diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 9fd6c6bd..762944bd 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -5,7 +5,7 @@ from abc import abstractmethod from typing import Optional, List, Dict import numpy as np -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS +from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM from memgpt.utils import get_local_time, format_datetime from memgpt.models import chat_completion_response @@ -68,6 +68,8 @@ class Message(Record): tool_calls: Optional[List[ToolCall]] = None, # list of tool calls 
requested tool_call_id: Optional[str] = None, embedding: Optional[np.ndarray] = None, + embedding_dim: Optional[int] = None, + embedding_model: Optional[str] = None, id: Optional[uuid.UUID] = None, ): super().__init__(id) @@ -82,6 +84,20 @@ class Message(Record): self.role = role # role (agent/user/function) self.name = name + # pad and store embeddings + if isinstance(embedding, list): + embedding = np.array(embedding) + self.embedding = ( + np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None + ) + self.embedding_dim = embedding_dim + self.embedding_model = embedding_model + + if self.embedding is not None: + assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding" + assert self.embedding_model, f"Must specify embedding_model if providing an embedding" + assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}" + # tool (i.e. function) call info (optional) # if role == "assistant", this MAY be specified @@ -97,9 +113,6 @@ class Message(Record): assert tool_call_id is None self.tool_call_id = tool_call_id - # embedding (optional) - self.embedding = embedding - # def __repr__(self): # pass @@ -172,7 +185,7 @@ class Message(Record): if "tool_call_id" in openai_message_dict: assert openai_message_dict["tool_call_id"] is None, openai_message_dict - if "tool_calls" in openai_message_dict and openai_message_dict["tool_calls"] is not None: + if "tool_calls" in openai_message_dict: assert openai_message_dict["role"] == "assistant", openai_message_dict tool_calls = [ @@ -273,6 +286,8 @@ class Passage(Record): text: str, agent_id: Optional[uuid.UUID] = None, # set if contained in agent memory embedding: Optional[np.ndarray] = None, + embedding_dim: Optional[int] = None, + embedding_model: Optional[str] = None, data_source: Optional[str] = None, # None if created by agent doc_id: Optional[uuid.UUID] = None, id: Optional[uuid.UUID] = None, 
@@ -283,10 +298,23 @@ class Passage(Record): self.agent_id = agent_id self.text = text self.data_source = data_source - self.embedding = embedding self.doc_id = doc_id self.metadata = metadata + # pad and store embeddings + if isinstance(embedding, list): + embedding = np.array(embedding) + self.embedding = ( + np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None + ) + self.embedding_dim = embedding_dim + self.embedding_model = embedding_model + + if self.embedding is not None: + assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding" + assert self.embedding_model, f"Must specify embedding_model if providing an embedding" + assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}" + assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type" assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type" assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type" @@ -328,6 +356,13 @@ class EmbeddingConfig: self.embedding_dim = embedding_dim self.embedding_chunk_size = embedding_chunk_size + # fields cannot be set to None + assert self.embedding_endpoint_type + assert self.embedding_endpoint + assert self.embedding_model + assert self.embedding_dim + assert self.embedding_chunk_size + class OpenAIEmbeddingConfig(EmbeddingConfig): def __init__(self, openai_key: Optional[str] = None, **kwargs): @@ -434,6 +469,9 @@ class Source: name: str, created_at: Optional[str] = None, id: Optional[uuid.UUID] = None, + # embedding info + embedding_model: Optional[str] = None, + embedding_dim: Optional[int] = None, ): if id is None: self.id = uuid.uuid4() @@ -445,3 +483,7 @@ class Source: self.name = name self.user_id = user_id self.created_at = created_at + + # embedding info (optional) + self.embedding_dim = embedding_dim + 
self.embedding_model = embedding_model diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py index e717e063..a9c89e9c 100644 --- a/memgpt/embeddings.py +++ b/memgpt/embeddings.py @@ -2,10 +2,12 @@ import typer import uuid from typing import Optional, List import os +import numpy as np -from memgpt.utils import is_valid_url +from memgpt.utils import is_valid_url, printd from memgpt.data_types import EmbeddingConfig from memgpt.credentials import MemGPTCredentials +from memgpt.constants import MAX_EMBEDDING_DIM, EMBEDDING_TO_TOKENIZER_MAP, EMBEDDING_TO_TOKENIZER_DEFAULT from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding from llama_index.bridge.pydantic import PrivateAttr @@ -14,6 +16,34 @@ from llama_index.embeddings.huggingface_utils import format_text import tiktoken +def check_and_split_text(text: str, embedding_model: str) -> List[str]: + """Split text into chunks of max_length tokens or less""" + + if embedding_model in EMBEDDING_TO_TOKENIZER_MAP: + encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_MAP[embedding_model]) + else: + print(f"Warning: couldn't find tokenizer for model {embedding_model}, using default tokenizer {EMBEDDING_TO_TOKENIZER_DEFAULT}") + encoding = tiktoken.get_encoding(EMBEDDING_TO_TOKENIZER_DEFAULT) + + num_tokens = len(encoding.encode(text)) + + # determine max length + if hasattr(encoding, "max_length"): + max_length = encoding.max_length + else: + # TODO: figure out the real number + printd(f"Warning: couldn't find max_length for tokenizer {embedding_model}, using default max_length 8191") + max_length = 8191 + + # truncate text if too long + if num_tokens > max_length: + # TODO: split this into two pieces of text instead of truncating + print(f"Warning: text is too long ({num_tokens} tokens), truncating to {max_length} tokens.") + text = format_text(text, embedding_model, max_length=max_length) + + return [text] + + class EmbeddingEndpoint(BaseEmbedding): """Implementation for OpenAI compatible 
endpoint""" @@ -38,7 +68,6 @@ class EmbeddingEndpoint(BaseEmbedding): self._user = user self._base_url = base_url self._timeout = timeout - self._encoding = tiktoken.get_encoding(model) super().__init__( model_name=model, ) @@ -47,10 +76,6 @@ class EmbeddingEndpoint(BaseEmbedding): def class_name(cls) -> str: return "EmbeddingEndpoint" - def count_tokens(self, text: str) -> int: - """Count tokens using the embedding model's tokenizer""" - return len(self._encoding.encode(text)) - def _call_api(self, text: str) -> List[float]: if not is_valid_url(self._base_url): raise ValueError( @@ -58,12 +83,6 @@ class EmbeddingEndpoint(BaseEmbedding): ) import httpx - # If necessary, truncate text to fit in the embedding model's max sequence length (usually 512) - num_tokens = self.count_tokens(text) - max_length = self._encoding.max_length - if num_tokens > max_length: - text = format_text(text, self.model_name, max_length=max_length) - headers = {"Content-Type": "application/json"} json_data = {"input": text, "model": self.model_name, "user": self._user} @@ -157,6 +176,14 @@ def default_embedding_model(): return HuggingFaceEmbedding(model_name=model) +def query_embedding(embedding_model, query_text: str): + """Generate padded embedding for querying database""" + query_vec = embedding_model.get_text_embedding(query_text) + query_vec = np.array(query_vec) + query_vec = np.pad(query_vec, (0, MAX_EMBEDDING_DIM - query_vec.shape[0]), mode="constant").tolist() + return query_vec + + def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None): """Return LlamaIndex embedding model to use for embeddings""" @@ -181,11 +208,6 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None api_version=credentials.azure_version, ) elif endpoint_type == "hugging-face": - try: - return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id) - except Exception as e: - # TODO: remove, this is just to get 
passing tests - print(e) - return default_embedding_model() + return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id) else: return default_embedding_model() diff --git a/memgpt/main.py b/memgpt/main.py index 9d9f41e5..aa6558e6 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -111,6 +111,9 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, elif user_input.lower() == "/attach": # TODO: check if agent already has it + # TODO: check to ensure source embedding dimentions/model match agents, and disallow attachment if not + # TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources + data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id) data_source_options = [s.name for s in data_source_options] if len(data_source_options) == 0: @@ -120,7 +123,24 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, bold=True, ) continue - data_source = questionary.select("Select data source", choices=data_source_options).ask() + + # determine what sources are valid to be attached to this agent + valid_options = [] + invalid_options = [] + for source in data_source_options: + if source.embedding_model == memgpt_agent.embedding_model and source.embedding_dim == memgpt_agent.embedding_dim: + valid_options.append(source.name) + else: + invalid_options.append(source.name) + + # print warning about invalid sources + typer.secho( + f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}", + fg=typer.colors.YELLOW, + ) + + # prompt user for data source selection + data_source = questionary.select("Select data source", choices=valid_options).ask() # attach new data attach(memgpt_agent.config.name, data_source) diff --git a/memgpt/memory.py b/memgpt/memory.py index a17e9a8b..2dbdebd0 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ 
-7,7 +7,7 @@ from memgpt.utils import get_local_time, printd, count_tokens, validate_date_for from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from memgpt.llm_api_tools import create from memgpt.data_types import Message, Passage, AgentState -from memgpt.embeddings import embedding_model +from memgpt.embeddings import embedding_model, query_embedding from llama_index import Document from llama_index.node_parser import SimpleNodeParser @@ -372,6 +372,7 @@ class EmbeddingArchivalMemory(ArchivalMemory): # create embedding model self.embed_model = embedding_model(agent_state.embedding_config) self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size + assert self.embedding_chunk_size, f"Must set {agent_state.embedding_config.embedding_chunk_size}" # create storage backend self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id) @@ -384,6 +385,8 @@ class EmbeddingArchivalMemory(ArchivalMemory): agent_id=self.agent_state.id, text=text, embedding=embedding, + embedding_dim=self.agent_state.embedding_config.embedding_dim, + embedding_model=self.agent_state.embedding_config.embedding_model, ) def save(self): @@ -432,7 +435,7 @@ class EmbeddingArchivalMemory(ArchivalMemory): try: if query_string not in self.cache: # self.cache[query_string] = self.retriever.retrieve(query_string) - query_vec = self.embed_model.get_text_embedding(query_string) + query_vec = query_embedding(self.embed_model, query_string) self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k) start = int(start if start else 0) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 1af4ed70..3e420330 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -64,6 +64,7 @@ class LLMConfigColumn(TypeDecorator): return value def process_result_value(self, value, dialect): + print("GET VALUE", value) if value: return LLMConfig(**value) return value @@ -168,6 +169,8 @@ class 
SourceModel(Base): user_id = Column(CommonUUID, nullable=False) name = Column(String, nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now()) + embedding_dim = Column(BIGINT) + embedding_model = Column(String) # TODO: add num passages @@ -175,7 +178,14 @@ class SourceModel(Base): return f"" def to_record(self) -> Source: - return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at) + return Source( + id=self.id, + user_id=self.user_id, + name=self.name, + created_at=self.created_at, + embedding_dim=self.embedding_dim, + embedding_model=self.embedding_model, + ) class AgentSourceMappingModel(Base): @@ -288,6 +298,7 @@ class MetadataStore: @enforce_types def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: + print("query agents", user_id) results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all() return [r.to_record() for r in results] diff --git a/memgpt/migrate.py b/memgpt/migrate.py index 873ff8ee..da70830e 100644 --- a/memgpt/migrate.py +++ b/memgpt/migrate.py @@ -150,7 +150,20 @@ def migrate_source(source_name: str): for node in nodes: # print(len(node.embedding)) # TODO: make sure embedding config matches embedding size? - passages.append(Passage(user_id=user.id, data_source=source_name, text=node.text, embedding=node.embedding)) + if len(node.embedding) != config.default_embedding_config.embedding_dim: + raise ValueError( + f"Cannot migrate source {source_name} due to incompatible embedding dimentions. Please re-load this source with `memgpt load`." 
+ ) + passages.append( + Passage( + user_id=user.id, + data_source=source_name, + text=node.text, + embedding=node.embedding, + embedding_dim=config.default_embedding_config.embedding_dim, + embedding_model=config.default_embedding_config.embedding_model, + ) + ) assert len(passages) > 0, f"Source {source_name} has no passages" conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id) @@ -313,9 +326,18 @@ def migrate_agent(agent_name: str): nodes = pickle.load(open(archival_filename, "rb")) passages = [] for node in nodes: - # print(len(node.embedding)) - # TODO: make sure embeding size matches embedding config? - passages.append(Passage(user_id=user.id, agent_id=agent_state.id, text=node.text, embedding=node.embedding)) + if len(node.embedding) != config.default_embedding_config.embedding_dim: + raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.") + passages.append( + Passage( + user_id=user.id, + agent_id=agent_state.id, + text=node.text, + embedding=node.embedding, + embedding_dim=agent_state.embedding_config.embedding_dim, + embedding_model=agent_state.embedding_config.embedding_model, + ) + ) if len(passages) > 0: agent.persistence_manager.archival_memory.storage.insert_many(passages) print(f"Inserted {len(passages)} passages into archival memory") diff --git a/memgpt/openai_backcompat/openai_object.py b/memgpt/openai_backcompat/openai_object.py index 38e24bcf..fcc63836 100644 --- a/memgpt/openai_backcompat/openai_object.py +++ b/memgpt/openai_backcompat/openai_object.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple, Union from enum import Enum -import openai +# import openai from memgpt.constants import JSON_ENSURE_ASCII @@ -68,7 +68,8 @@ class ApiType(Enum): elif label.lower() in ("open_ai", "openai"): return ApiType.OPEN_AI else: - raise openai.error.InvalidAPIType( + # raise openai.error.InvalidAPIType( + raise Exception( "The API type provided in invalid. 
Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'" ) @@ -361,7 +362,8 @@ class OpenAIObject(dict): @property def typed_api_type(self): - return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type) + # return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type) + return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(ApiType.OPEN_AI) # This class overrides __setitem__ to throw exceptions on inputs that it # doesn't like. This can cause problems when we try to copy an object diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py new file mode 100644 index 00000000..ef0b1c52 --- /dev/null +++ b/tests/test_different_embedding_size.py @@ -0,0 +1,123 @@ +import uuid +import os + +from memgpt import MemGPT +from memgpt.config import MemGPTConfig +from memgpt import constants +from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage +from memgpt.embeddings import embedding_model +from memgpt.agent_store.storage import StorageConnector, TableType +from .utils import wipe_config +import uuid + + +test_agent_name = f"test_client_{str(uuid.uuid4())}" +test_agent_state = None +client = None + +test_agent_state_post_message = None +test_user_id = uuid.uuid4() + + +def generate_passages(user, agent): + # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. 
+ texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] + embed_model = embedding_model(agent.embedding_config) + orig_embeddings = [] + passages = [] + for text in texts: + embedding = embed_model.get_text_embedding(text) + orig_embeddings.append(list(embedding)) + passages.append( + Passage( + user_id=user.id, + agent_id=agent.id, + text=text, + embedding=embedding, + embedding_dim=agent.embedding_config.embedding_dim, + embedding_model=agent.embedding_config.embedding_model, + ) + ) + return passages, orig_embeddings + + +def test_create_user(): + wipe_config() + + # create client + client = MemGPT(quickstart="openai", user_id=test_user_id) + + # create user + user = client.server.create_user({"id": test_user_id}) + + # openai: create agent + openai_agent = client.create_agent( + { + "user_id": test_user_id, + "name": "openai_agent", + } + ) + assert ( + openai_agent.embedding_config.embedding_endpoint_type == "openai" + ), f"openai_agent.embedding_config.embedding_endpoint_type={openai_agent.embedding_config.embedding_endpoint_type}" + + # openai: add passages + passages, openai_embeddings = generate_passages(user, openai_agent) + openai_agent_run = client.server._get_or_load_agent(user_id=user.id, agent_id=openai_agent.id) + openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) + + # hosted: create agent + hosted_agent = client.create_agent( + { + "user_id": test_user_id, + "name": "hosted_agent", + "embedding_config": EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_dim=1024, + ), + } + ) + # check to make sure endpoint overridden + assert ( + hosted_agent.embedding_config.embedding_endpoint_type == "hugging-face" + ), f"hosted_agent.embedding_config.embedding_endpoint_type={hosted_agent.embedding_config.embedding_endpoint_type}" + + # hosted: add passages + passages, 
hosted_embeddings = generate_passages(user, hosted_agent) + hosted_agent_run = client.server._get_or_load_agent(user_id=user.id, agent_id=hosted_agent.id) + hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) + + # test passage dimensionality + config = MemGPTConfig.load() + storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user.id) + storage.filters = {} # clear filters to be able to get all passages + passages = storage.get_all() + for passage in passages: + if passage.agent_id == hosted_agent.id: + assert ( + passage.embedding_dim == hosted_agent.embedding_config.embedding_dim + ), f"passage.embedding_dim={passage.embedding_dim} != hosted_agent.embedding_config.embedding_dim={hosted_agent.embedding_config.embedding_dim}" + + # ensure was in original embeddings + embedding = passage.embedding[: passage.embedding_dim] + assert embedding in hosted_embeddings, f"embedding={embedding} not in hosted_embeddings={hosted_embeddings}" + + # make sure all zeros + assert not any( + passage.embedding[passage.embedding_dim :] + ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}" + elif passage.agent_id == openai_agent.id: + assert ( + passage.embedding_dim == openai_agent.embedding_config.embedding_dim + ), f"passage.embedding_dim={passage.embedding_dim} != openai_agent.embedding_config.embedding_dim={openai_agent.embedding_config.embedding_dim}" + + # ensure was in original embeddings + embedding = passage.embedding[: passage.embedding_dim] + assert embedding in openai_embeddings, f"embedding={embedding} not in openai_embeddings={openai_embeddings}" + + # make sure all zeros + assert not any( + passage.embedding[passage.embedding_dim :] + ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}" diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 11578dba..33456bf0 100644 --- a/tests/test_load_archival.py +++ 
b/tests/test_load_archival.py @@ -88,13 +88,13 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c persona=user.default_persona, human=user.default_human, llm_config=config.default_llm_config, - embedding_config=config.default_embedding_config, + embedding_config=embedding_config, ) ms.delete_user(user.id) ms.create_user(user) ms.create_agent(agent) user = ms.get_user(user.id) - print("Got user:", user, config.default_embedding_config) + print("Got user:", user, embedding_config) # setup storage connectors print("Creating storage connectors...") diff --git a/tests/test_server.py b/tests/test_server.py index 963d53dc..e5071cf0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -104,7 +104,14 @@ def test_server(): for text in archival_memories: embedding = embed_model.get_text_embedding(text) agent.persistence_manager.archival_memory.storage.insert( - Passage(user_id=user.id, agent_id=agent_state.id, text=text, embedding=embedding) + Passage( + user_id=user.id, + agent_id=agent_state.id, + text=text, + embedding=embedding, + embedding_dim=agent.agent_state.embedding_config.embedding_dim, + embedding_model=agent.agent_state.embedding_config.embedding_model, + ) ) # add data into recall memory diff --git a/tests/test_storage.py b/tests/test_storage.py index 21faf176..ae4acfc2 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -4,13 +4,14 @@ import uuid import pytest from memgpt.agent_store.storage import StorageConnector, TableType -from memgpt.embeddings import embedding_model +from memgpt.embeddings import embedding_model, query_embedding from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.metadata import MetadataStore from memgpt.data_types import User +from memgpt.constants import 
MAX_EMBEDDING_DIM from datetime import datetime, timedelta @@ -33,10 +34,23 @@ def generate_passages(embed_model): # embeddings: use openai if env is set, otherwise local passages = [] for text, _, _, agent_id, id in zip(texts, dates, roles, agent_ids, ids): - embedding = None + embedding, embedding_model, embedding_dim = None, None, None if embed_model: embedding = embed_model.get_text_embedding(text) - passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id)) + embedding_model = "gpt-4" + embedding_dim = len(embedding) + passages.append( + Passage( + user_id=user_id, + text=text, + agent_id=agent_id, + embedding=embedding, + data_source="test_source", + id=id, + embedding_dim=embedding_dim, + embedding_model=embedding_model, + ) + ) return passages @@ -45,11 +59,24 @@ def generate_messages(embed_model): """Generate list of 3 Message objects""" messages = [] for text, date, role, agent_id, id in zip(texts, dates, roles, agent_ids, ids): - embedding = None + embedding, embedding_model, embedding_dim = None, None, None if embed_model: embedding = embed_model.get_text_embedding(text) + embedding_model = "gpt-4" + embedding_dim = len(embedding) messages.append( - Message(user_id=user_id, text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt-4", embedding=embedding) + Message( + user_id=user_id, + text=text, + agent_id=agent_id, + role=role, + created_at=date, + id=id, + model="gpt-4", + embedding=embedding, + embedding_model=embedding_model, + embedding_dim=embedding_dim, + ) ) print(messages[-1].text) return messages @@ -165,6 +192,14 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models else: raise NotImplementedError(f"Table type {table_type} not implemented") + # check record dimensions + print("TABLE TYPE", table_type, type(records[0]), len(records[0].embedding)) + if embed_model: + assert len(records[0].embedding) == MAX_EMBEDDING_DIM, 
f"Expected {MAX_EMBEDDING_DIM}, got {len(records[0].embedding)}" + assert ( + records[0].embedding_dim == embedding_config.embedding_dim + ), f"Expected {embedding_config.embedding_dim}, got {records[0].embedding_dim}" + # test: insert conn.insert(records[0]) assert conn.size() == 1, f"Expected 1 record, got {conn.size()}: {conn.get_all()}" @@ -208,7 +243,7 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models # test: query (vector) if table_type == TableType.ARCHIVAL_MEMORY: query = "why was she crying" - query_vec = embed_model.get_text_embedding(query) + query_vec = query_embedding(embed_model, query) res = conn.query(None, query_vec, top_k=2) assert len(res) == 2, f"Expected 2 results, got {len(res)}" print("Archival memory results", res) diff --git a/tests/utils.py b/tests/utils.py index 37f8bb87..1de59779 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,9 @@ def wipe_config(): config_path = MemGPTConfig.config_path # TODO delete file config_path os.remove(config_path) + assert not MemGPTConfig.exists(), "Config should not exist after deletion" + else: + print("No config to wipe", MemGPTConfig.config_path) def wipe_memgpt_home():