""" Metadata store for user/agent/data_source information"""

import os
from typing import Optional, List, Dict
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig
from memgpt.config import MemGPTConfig

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
from sqlalchemy import Column, BIGINT, String, DateTime
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy_json import mutable_json_type, MutableJson
from sqlalchemy import TypeDecorator, CHAR
import uuid

from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base

Base = declarative_base()


# Custom UUID type
class CommonUUID(TypeDecorator):
    """UUID column type that adapts to the database dialect.

    On the "postgresql" dialect the native ``UUID(as_uuid=True)`` type is used
    and values pass through untouched. On every other dialect the column is a
    ``CHAR`` and values are converted to/from their string form.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the native UUID type on postgresql, ``CHAR`` otherwise."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID(as_uuid=True))
        else:
            return dialect.type_descriptor(CHAR())

    def process_bind_param(self, value, dialect):
        """Convert an outgoing value; ``None`` and postgresql values are passed as-is."""
        if dialect.name == "postgresql" or value is None:
            return value
        else:
            return str(value)  # Convert UUID to string for SQLite

    def process_result_value(self, value, dialect):
        """Convert a loaded value back to ``uuid.UUID`` on non-postgresql dialects."""
        if dialect.name == "postgresql" or value is None:
            return value
        else:
            return uuid.UUID(value)


class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON"""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use the generic ``JSON`` type on every dialect."""
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Store a truthy value as its attribute dict (``vars(value)``); pass falsy values through."""
        if value:
            return vars(value)
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``LLMConfig`` from the stored dict; pass falsy values through."""
        if value:
            return LLMConfig(**value)
        return value


class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON"""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use the generic ``JSON`` type on every dialect."""
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Store a truthy value as its attribute dict (``vars(value)``); pass falsy values through."""
        if value:
            return vars(value)
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``EmbeddingConfig`` from the stored dict; pass falsy values through."""
        if value:
            return EmbeddingConfig(**value)
        return value


class UserModel(Base):
    """Defines data model for storing users (the "users" table)"""

    __tablename__ = "users"
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    default_preset = Column(String)
    default_persona = Column(String)
    default_human = Column(String)
    default_agent = Column(String)

    default_llm_config = Column(LLMConfigColumn)
    default_embedding_config = Column(EmbeddingConfigColumn)

    azure_key = Column(String, nullable=True)
    azure_endpoint = Column(String, nullable=True)
    azure_version = Column(String, nullable=True)
    azure_deployment = Column(String, nullable=True)

    openai_key = Column(String, nullable=True)
    policies_accepted = Column(Boolean, nullable=False, default=False)

    def __repr__(self) -> str:
        # Deliberately limited to the id: the row also holds API keys.
        return f"<User(id='{self.id}')>"

    def to_record(self) -> User:
        """Copy this row's columns into a ``User`` record."""
        return User(
            id=self.id,
            default_preset=self.default_preset,
            default_persona=self.default_persona,
            default_human=self.default_human,
            default_agent=self.default_agent,
            default_llm_config=self.default_llm_config,
            default_embedding_config=self.default_embedding_config,
            azure_key=self.azure_key,
            azure_endpoint=self.azure_endpoint,
            azure_version=self.azure_version,
            azure_deployment=self.azure_deployment,
            openai_key=self.openai_key,
            policies_accepted=self.policies_accepted,
        )


class AgentModel(Base):
    """Defines data model for storing agents (the "agents" table)"""

    __tablename__ = "agents"
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    persona = Column(String)
    human = Column(String)
    preset = Column(String)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # configs
    llm_config = Column(LLMConfigColumn)
    embedding_config = Column(EmbeddingConfigColumn)

    # state
    state = Column(JSON)

    def __repr__(self) -> str:
        return f"<Agent(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> AgentState:
        """Copy this row's columns into an ``AgentState`` record."""
        return AgentState(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            persona=self.persona,
            human=self.human,
            preset=self.preset,
            created_at=self.created_at,
            llm_config=self.llm_config,
            embedding_config=self.embedding_config,
            state=self.state,
        )


class SourceModel(Base):
    """Defines data model for storing data sources (the "sources" table)"""

    __tablename__ = "sources"
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # TODO: add num passages

    def __repr__(self) -> str:
        return f"<Source(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Source:
        """Copy this row's columns into a ``Source`` record."""
        return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)


class AgentSourceMappingModel(Base):
    """Stores mapping between agent -> source"""

    __tablename__ = "agent_source_mapping"

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    agent_id = Column(CommonUUID, nullable=False)
    source_id = Column(CommonUUID, nullable=False)

    def __repr__(self) -> str:
        return (
            f"<AgentSourceMapping(user_id='{self.user_id}', "
            f"agent_id='{self.agent_id}', source_id='{self.source_id}')>"
        )


class MetadataStore:
    """CRUD access to the user, agent, source and agent->source mapping tables.

    A single SQLAlchemy session is created in ``__init__`` and reused by every
    method; each mutating method commits before returning.
    """

    def __init__(self, config: MemGPTConfig):
        """Connect to the configured database and create the metadata tables.

        Args:
            config: Supplies ``metadata_storage_type`` ("postgres" or "sqlite")
                plus ``metadata_storage_uri`` (postgres) or
                ``metadata_storage_path`` (sqlite; the database file is
                ``sqlite.db`` inside that directory).

        Raises:
            ValueError: If ``metadata_storage_type`` is neither "postgres" nor "sqlite".
        """
        # TODO: get DB URI or path
        if config.metadata_storage_type == "postgres":
            self.uri = config.metadata_storage_uri
        elif config.metadata_storage_type == "sqlite":
            path = os.path.join(config.metadata_storage_path, "sqlite.db")
            self.uri = f"sqlite:///{path}"
        else:
            raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}")

        # TODO: check to see if table(s) need to be created or not
        self.engine = create_engine(self.uri)
        Base.metadata.create_all(
            self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__]
        )
        session_maker = sessionmaker(bind=self.engine)
        self.session = session_maker()

    @enforce_types
    def create_agent(self, agent: AgentState):
        """Insert a new agent row.

        Raises:
            ValueError: If the same user already has an agent with this name.
        """
        # insert into agent table
        # make sure agent.name does not already exist for user user_id
        if self.session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
            raise ValueError(f"Agent with name {agent.name} already exists")
        self.session.add(AgentModel(**vars(agent)))
        self.session.commit()

    @enforce_types
    def create_source(self, source: Source):
        """Insert a new source row.

        Raises:
            ValueError: If the same user already has a source with this name.
        """
        # make sure source.name does not already exist for user
        if (
            self.session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count()
            > 0
        ):
            raise ValueError(f"Source with name {source.name} already exists")
        self.session.add(SourceModel(**vars(source)))
        self.session.commit()

    @enforce_types
    def create_user(self, user: User):
        """Insert a new user row.

        Raises:
            ValueError: If a user with the same id already exists.
        """
        if self.session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
            raise ValueError(f"User with id {user.id} already exists")
        self.session.add(UserModel(**vars(user)))
        self.session.commit()

    @enforce_types
    def update_agent(self, agent: AgentState):
        """Overwrite the agent row matching ``agent.id`` with the attributes of ``agent``."""
        self.session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
        self.session.commit()

    @enforce_types
    def update_user(self, user: User):
        """Overwrite the user row matching ``user.id`` with the attributes of ``user``."""
        self.session.query(UserModel).filter(UserModel.id == user.id).update(vars(user))
        self.session.commit()

    @enforce_types
    def update_source(self, source: Source):
        """Overwrite the source row matching ``source.id`` with the attributes of ``source``."""
        self.session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))
        self.session.commit()

    @enforce_types
    def delete_agent(self, agent_id: uuid.UUID):
        """Delete the agent row with the given id (no-op if it does not exist)."""
        self.session.query(AgentModel).filter(AgentModel.id == agent_id).delete()
        self.session.commit()

    @enforce_types
    def delete_source(self, source_id: uuid.UUID):
        """Delete a source and every agent->source mapping that references it."""
        # delete from sources table
        self.session.query(SourceModel).filter(SourceModel.id == source_id).delete()

        # delete any mappings
        self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()

        self.session.commit()

    @enforce_types
    def delete_user(self, user_id: uuid.UUID):
        """Delete a user together with the agents, sources and mappings owned by that user."""
        # delete from users table
        self.session.query(UserModel).filter(UserModel.id == user_id).delete()

        # delete associated agents
        self.session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()

        # delete associated sources
        self.session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()

        # delete associated mappings
        self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()

        self.session.commit()

    @enforce_types
    def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
        """Return every agent belonging to ``user_id`` as ``AgentState`` records."""
        results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
        return [r.to_record() for r in results]

    @enforce_types
    def list_sources(self, user_id: uuid.UUID) -> List[Source]:
        """Return every source belonging to ``user_id`` as ``Source`` records."""
        results = self.session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
        return [r.to_record() for r in results]

    @enforce_types
    def get_agent(
        self, agent_id: Optional[uuid.UUID] = None, agent_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
    ) -> Optional[AgentState]:
        """Look up one agent, by id or by (name, user).

        Args:
            agent_id: If given, the lookup uses only this id.
            agent_name: Agent name; required together with ``user_id`` when
                ``agent_id`` is not given.
            user_id: Owner of the agent; used only for the name lookup.

        Returns:
            The matching ``AgentState``, or ``None`` if no row matches.

        Raises:
            AssertionError: If neither ``agent_id`` nor both ``agent_name`` and
                ``user_id`` are supplied, or if more than one row matches.
        """
        if agent_id:
            results = self.session.query(AgentModel).filter(AgentModel.id == agent_id).all()
        else:
            assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
            results = self.session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()

        if not results:
            return None
        assert len(results) == 1, f"Expected 1 result, got {len(results)}"  # should only be one result
        return results[0].to_record()

    @enforce_types
    def get_user(self, user_id: uuid.UUID) -> Optional[User]:
        """Return the user with the given id, or ``None`` if there is none.

        Raises:
            AssertionError: If more than one row matches.
        """
        results = self.session.query(UserModel).filter(UserModel.id == user_id).all()
        if not results:
            return None
        assert len(results) == 1, f"Expected 1 result, got {len(results)}"
        return results[0].to_record()

    @enforce_types
    def get_source(
        self, source_id: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None
    ) -> Optional[Source]:
        """Look up one source, by id or by (name, user).

        Args:
            source_id: If given, the lookup uses only this id.
            user_id: Owner of the source; used only for the name lookup.
            source_name: Source name; required together with ``user_id`` when
                ``source_id`` is not given.

        Returns:
            The matching ``Source``, or ``None`` if no row matches.

        Raises:
            AssertionError: If neither ``source_id`` nor both ``user_id`` and
                ``source_name`` are supplied, or if more than one row matches.
        """
        if source_id:
            results = self.session.query(SourceModel).filter(SourceModel.id == source_id).all()
        else:
            assert user_id is not None and source_name is not None
            results = self.session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()

        if not results:
            return None
        assert len(results) == 1, f"Expected 1 result, got {len(results)}"
        return results[0].to_record()

    # agent source metadata
    @enforce_types
    def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
        """Record that ``source_id`` is attached to ``agent_id`` (no duplicate check is made)."""
        self.session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
        self.session.commit()

    @enforce_types
    def list_attached_sources(self, agent_id: uuid.UUID) -> List[uuid.UUID]:
        """Return the ``source_id`` of every mapping row for ``agent_id``."""
        results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
        return [r.source_id for r in results]

    @enforce_types
    def list_attached_agents(self, source_id: uuid.UUID) -> List[uuid.UUID]:
        """Return the ``agent_id`` of every mapping row for ``source_id``."""
        results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
        return [r.agent_id for r in results]

    @enforce_types
    def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID):
        """Delete every mapping row linking ``agent_id`` to ``source_id``."""
        self.session.query(AgentSourceMappingModel).filter(
            AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
        ).delete()
        self.session.commit()