371 lines
14 KiB
Python
371 lines
14 KiB
Python
""" Metadata store for user/agent/data_source information"""
|
|
import os
import uuid
from contextlib import contextmanager
from typing import Optional, List, Dict

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
from sqlalchemy import Column, BIGINT, String, DateTime
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy_json import mutable_json_type, MutableJson
from sqlalchemy import TypeDecorator, CHAR
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig
from memgpt.config import MemGPTConfig
|
|
|
|
|
|
# Declarative base class: every ORM model in this module subclasses it, and its
# .metadata is what MetadataStore.__init__ passes to create_all().
Base = declarative_base()
|
|
|
|
|
|
# Custom UUID type
class CommonUUID(TypeDecorator):
    """Portable UUID column type.

    On PostgreSQL the native ``UUID(as_uuid=True)`` type is used and values
    pass through untouched. On any other dialect (e.g. SQLite) the column is a
    plain ``CHAR``: UUIDs are written as strings and parsed back into
    ``uuid.UUID`` when loaded. ``None`` is never converted.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Native UUID storage on Postgres, string storage everywhere else.
        is_postgres = dialect.name == "postgresql"
        column_type = UUID(as_uuid=True) if is_postgres else CHAR()
        return dialect.type_descriptor(column_type)

    def process_bind_param(self, value, dialect):
        # Backends without a UUID type get the string form (e.g. SQLite).
        must_stringify = value is not None and dialect.name != "postgresql"
        return str(value) if must_stringify else value

    def process_result_value(self, value, dialect):
        # Undo process_bind_param: rebuild a uuid.UUID from the stored string.
        must_parse = value is not None and dialect.name != "postgresql"
        return uuid.UUID(value) if must_parse else value
|
|
|
|
|
|
class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON.

    An ``LLMConfig`` is written as the dict of its instance attributes and is
    rebuilt with ``LLMConfig(**data)`` when read. Falsy values (such as
    ``None``) are passed through unchanged in both directions.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Python -> DB: falsy values are stored as-is.
        if not value:
            return value
        return vars(value)

    def process_result_value(self, value, dialect):
        # DB -> Python: falsy values (None, empty dict) are returned as-is.
        if not value:
            return value
        return LLMConfig(**value)
|
|
|
|
|
|
class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON.

    Serializes an ``EmbeddingConfig`` to the dict of its instance attributes
    and deserializes via ``EmbeddingConfig(**data)``. Falsy values (such as
    ``None``) pass through untouched.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        json_type = JSON()
        return dialect.type_descriptor(json_type)

    def process_bind_param(self, value, dialect):
        return vars(value) if value else value

    def process_result_value(self, value, dialect):
        return EmbeddingConfig(**value) if value else value
|
|
|
|
|
|
class UserModel(Base):
    """ORM model for the ``users`` table.

    Holds a user's defaults (preset, persona, human, agent, LLM/embedding
    configs), optional Azure/OpenAI credentials, and whether policies were
    accepted. ``to_record`` converts a row into a ``User``.
    """

    __tablename__ = "users"
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    default_preset = Column(String)
    default_persona = Column(String)
    default_human = Column(String)
    default_agent = Column(String)

    default_llm_config = Column(LLMConfigColumn)
    default_embedding_config = Column(EmbeddingConfigColumn)

    azure_key = Column(String, nullable=True)
    azure_endpoint = Column(String, nullable=True)
    azure_version = Column(String, nullable=True)
    azure_deployment = Column(String, nullable=True)

    openai_key = Column(String, nullable=True)
    policies_accepted = Column(Boolean, nullable=False, default=False)

    def __repr__(self) -> str:
        return "<User(id='{}')>".format(self.id)

    def to_record(self) -> User:
        # Each persisted column maps onto the User field of the same name.
        field_names = (
            "id",
            "default_preset",
            "default_persona",
            "default_human",
            "default_agent",
            "default_llm_config",
            "default_embedding_config",
            "azure_key",
            "azure_endpoint",
            "azure_version",
            "azure_deployment",
            "openai_key",
            "policies_accepted",
        )
        return User(**{name: getattr(self, name) for name in field_names})
|
|
|
|
|
|
class AgentModel(Base):
    """Defines data model for storing agents.

    A row holds the agent's identity (id, owning user_id, name), the
    persona/human/preset it was created with, its LLM and embedding configs,
    and a JSON ``state`` blob. ``to_record`` converts a row into an
    ``AgentState``.
    """

    __tablename__ = "agents"
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    persona = Column(String)
    human = Column(String)
    preset = Column(String)
    # Filled in by the database (now()) when no value is supplied on insert.
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # configs
    llm_config = Column(LLMConfigColumn)
    embedding_config = Column(EmbeddingConfigColumn)

    # state
    state = Column(JSON)

    def __repr__(self) -> str:
        return f"<Agent(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> AgentState:
        return AgentState(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            persona=self.persona,
            human=self.human,
            preset=self.preset,
            created_at=self.created_at,
            llm_config=self.llm_config,
            embedding_config=self.embedding_config,
            state=self.state,
        )
|
|
|
|
|
|
class SourceModel(Base):
    """Defines data model for storing data sources.

    A row records a named source owned by ``user_id``. ``to_record`` converts
    a row into a ``Source``.
    """

    __tablename__ = "sources"
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    # Filled in by the database (now()) when no value is supplied on insert.
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # TODO: add num passages

    def __repr__(self) -> str:
        # `id` is the source's own primary key, not a passage id.
        return f"<Source(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Source:
        return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at)
|
|
|
|
|
|
class AgentSourceMappingModel(Base):
    """Stores mapping between agent -> source.

    Each row records that source ``source_id`` is attached to agent
    ``agent_id``; ``user_id`` is the owner, stored so a user's mappings can be
    deleted in bulk.
    """

    __tablename__ = "agent_source_mapping"
    # Same as the other models in this module: redefining this table on the
    # shared Base.metadata extends it instead of raising "already defined".
    __table_args__ = {"extend_existing": True}

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    agent_id = Column(CommonUUID, nullable=False)
    source_id = Column(CommonUUID, nullable=False)

    def __repr__(self) -> str:
        return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
|
|
|
|
|
|
class MetadataStore:
    """SQL-backed store for users, agents, data sources, and agent -> source mappings.

    The backend is selected by ``config.metadata_storage_type`` ("postgres" or
    "sqlite"). A single SQLAlchemy session is opened in ``__init__`` and shared
    by every method. Write methods therefore run inside ``_transaction``, which
    rolls the session back on failure so one bad write cannot poison every
    later call.
    """

    def __init__(self, config: MemGPTConfig):
        """Open the metadata database and create the model tables.

        Args:
            config: For "postgres", ``metadata_storage_uri`` is the database
                URI. For "sqlite", the database file is
                ``<metadata_storage_path>/sqlite.db``.

        Raises:
            ValueError: if ``config.metadata_storage_type`` is neither
                "postgres" nor "sqlite".
        """
        # TODO: get DB URI or path
        if config.metadata_storage_type == "postgres":
            self.uri = config.metadata_storage_uri
        elif config.metadata_storage_type == "sqlite":
            path = os.path.join(config.metadata_storage_path, "sqlite.db")
            self.uri = f"sqlite:///{path}"
        else:
            raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}")

        # TODO: check to see if table(s) need to be created or not
        self.engine = create_engine(self.uri)
        Base.metadata.create_all(
            self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__]
        )
        session_maker = sessionmaker(bind=self.engine)
        self.session = session_maker()

    @contextmanager
    def _transaction(self):
        """Yield the shared session and commit when the block finishes.

        If anything in the block (or the commit itself) raises, the session is
        rolled back before the exception propagates. Without the rollback, a
        failed flush/commit would leave the shared session unusable for every
        subsequent call.
        """
        try:
            yield self.session
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    @enforce_types
    def create_agent(self, agent: AgentState):
        """Insert a new agent row.

        Raises:
            ValueError: if the agent's user already has an agent with this name.
        """
        with self._transaction() as session:
            # Agent names must be unique per user.
            if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
                raise ValueError(f"Agent with name {agent.name} already exists")
            session.add(AgentModel(**vars(agent)))

    @enforce_types
    def create_source(self, source: Source):
        """Insert a new source row.

        Raises:
            ValueError: if the source's user already has a source with this name.
        """
        with self._transaction() as session:
            # Source names must be unique per user.
            if (
                session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count()
                > 0
            ):
                raise ValueError(f"Source with name {source.name} already exists")
            session.add(SourceModel(**vars(source)))

    @enforce_types
    def create_user(self, user: User):
        """Insert a new user row.

        Raises:
            ValueError: if a user with the same id already exists.
        """
        with self._transaction() as session:
            if session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
                raise ValueError(f"User with id {user.id} already exists")
            session.add(UserModel(**vars(user)))

    @enforce_types
    def update_agent(self, agent: AgentState):
        """Overwrite the stored agent whose id matches ``agent.id`` with all of ``agent``'s attributes."""
        with self._transaction() as session:
            session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))

    @enforce_types
    def update_user(self, user: User):
        """Overwrite the stored user whose id matches ``user.id`` with all of ``user``'s attributes."""
        with self._transaction() as session:
            session.query(UserModel).filter(UserModel.id == user.id).update(vars(user))

    @enforce_types
    def update_source(self, source: Source):
        """Overwrite the stored source whose id matches ``source.id`` with all of ``source``'s attributes."""
        with self._transaction() as session:
            session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source))

    @enforce_types
    def delete_agent(self, agent_id: uuid.UUID):
        """Delete an agent and every source mapping that references it."""
        with self._transaction() as session:
            session.query(AgentModel).filter(AgentModel.id == agent_id).delete()

            # delete any mappings, mirroring delete_source; otherwise
            # list_attached_agents would keep returning the deleted agent
            session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).delete()

    @enforce_types
    def delete_source(self, source_id: uuid.UUID):
        """Delete a source and every agent mapping that references it."""
        with self._transaction() as session:
            # delete from sources table
            session.query(SourceModel).filter(SourceModel.id == source_id).delete()

            # delete any mappings
            session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete()

    @enforce_types
    def delete_user(self, user_id: uuid.UUID):
        """Delete a user together with all of their agents, sources, and mappings."""
        with self._transaction() as session:
            # delete from users table
            session.query(UserModel).filter(UserModel.id == user_id).delete()

            # delete associated agents
            session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()

            # delete associated sources
            session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()

            # delete associated mappings
            session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()

    @enforce_types
    def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
        """Return all agents owned by ``user_id``."""
        results = self.session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
        return [r.to_record() for r in results]

    @enforce_types
    def list_sources(self, user_id: uuid.UUID) -> List[Source]:
        """Return all sources owned by ``user_id``."""
        results = self.session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
        return [r.to_record() for r in results]

    @enforce_types
    def get_agent(
        self, agent_id: Optional[uuid.UUID] = None, agent_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
    ) -> Optional[AgentState]:
        """Fetch one agent by ``agent_id``, or else by ``agent_name`` + ``user_id``.

        Returns:
            The matching ``AgentState``, or ``None`` when nothing matches.
        """
        if agent_id:
            results = self.session.query(AgentModel).filter(AgentModel.id == agent_id).all()
        else:
            assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
            results = self.session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()

        if len(results) == 0:
            return None
        assert len(results) == 1, f"Expected 1 result, got {len(results)}"  # should only be one result
        return results[0].to_record()

    @enforce_types
    def get_user(self, user_id: uuid.UUID) -> Optional[User]:
        """Fetch a user by id, or ``None`` if it does not exist."""
        results = self.session.query(UserModel).filter(UserModel.id == user_id).all()
        if len(results) == 0:
            return None
        assert len(results) == 1, f"Expected 1 result, got {len(results)}"
        return results[0].to_record()

    @enforce_types
    def get_source(
        self, source_id: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None
    ) -> Optional[Source]:
        """Fetch one source by ``source_id``, or else by ``source_name`` + ``user_id``.

        Returns:
            The matching ``Source``, or ``None`` when nothing matches.
        """
        if source_id:
            results = self.session.query(SourceModel).filter(SourceModel.id == source_id).all()
        else:
            assert user_id is not None and source_name is not None, "Must provide either source_id or source_name and user_id"
            results = self.session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all()
        if len(results) == 0:
            return None
        assert len(results) == 1, f"Expected 1 result, got {len(results)}"
        return results[0].to_record()

    # agent source metadata
    @enforce_types
    def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
        """Record that ``source_id`` is attached to ``agent_id`` (owned by ``user_id``).

        NOTE(review): no duplicate check is done, so attaching the same pair
        twice stores two mapping rows.
        """
        with self._transaction() as session:
            session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))

    @enforce_types
    def list_attached_sources(self, agent_id: uuid.UUID) -> List[uuid.UUID]:
        """Return the ids of all sources attached to ``agent_id``."""
        results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
        return [r.source_id for r in results]

    @enforce_types
    def list_attached_agents(self, source_id: uuid.UUID) -> List[uuid.UUID]:
        """Return the ids of all agents that have ``source_id`` attached."""
        results = self.session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
        return [r.agent_id for r in results]

    @enforce_types
    def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID):
        """Remove every mapping between ``agent_id`` and ``source_id``."""
        with self._transaction() as session:
            session.query(AgentSourceMappingModel).filter(
                AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
            ).delete()
|