feat: Enable adding files (#1864)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-14 10:22:45 -07:00
committed by GitHub
parent 9a44cc3df7
commit 93aacc087e
26 changed files with 565 additions and 223 deletions

3
.gitignore vendored
View File

@@ -2,6 +2,9 @@
# Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
# Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
openapi_letta.json
openapi_openai.json
### Eclipse ###
.metadata
bin/

View File

@@ -70,7 +70,7 @@ schema_models = [
"Message",
"Passage",
"AgentState",
"Document",
"File",
"Source",
"LLMConfig",
"EmbeddingConfig",

View File

@@ -270,7 +270,7 @@
"outputs": [],
"source": [
"from letta.data_sources.connectors import DataConnector \n",
"from letta.schemas.document import Document\n",
"from letta.schemas.file import FileMetadata\n",
"from llama_index.core import Document as LlamaIndexDocument\n",
"from llama_index.core import SummaryIndex\n",
"from llama_index.readers.web import SimpleWebPageReader\n",

View File

@@ -7,9 +7,9 @@ from letta.client.client import LocalClient, RESTClient, create_client
# imports for easier access
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.document import Document
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig

View File

@@ -28,7 +28,7 @@ from letta.agent_store.storage import StorageConnector, TableType
from letta.base import Base
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
@@ -141,7 +141,7 @@ class PassageModel(Base):
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
text = Column(String)
doc_id = Column(String)
file_id = Column(String)
agent_id = Column(String)
source_id = Column(String)
@@ -160,7 +160,7 @@ class PassageModel(Base):
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("passage_idx_user", user_id, agent_id, doc_id),
Index("passage_idx_user", user_id, agent_id, file_id),
def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
@@ -170,7 +170,7 @@ class PassageModel(Base):
text=self.text,
embedding=self.embedding,
embedding_config=self.embedding_config,
doc_id=self.doc_id,
file_id=self.file_id,
user_id=self.user_id,
id=self.id,
source_id=self.source_id,
@@ -365,12 +365,17 @@ class PostgresStorageConnector(SQLStorageConnector):
self.uri = self.config.archival_storage_uri
self.db_model = PassageModel
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
self.db_model = MessageModel
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
elif table_type == TableType.FILES:
self.uri = self.config.metadata_storage_uri
self.db_model = FileMetadataModel
if self.config.metadata_storage_uri is None:
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
@@ -487,8 +492,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
# TODO: eventually implement URI option
self.path = self.config.recall_storage_path
if self.path is None:
raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
raise ValueError(f"Must specify recall_storage_path in config.")
self.db_model = MessageModel
elif table_type == TableType.FILES:
self.path = self.config.metadata_storage_path
if self.path is None:
raise ValueError(f"Must specify metadata_storage_path in config.")
self.db_model = FileMetadataModel
else:
raise ValueError(f"Table type {table_type} not implemented")

View File

@@ -24,7 +24,7 @@ def get_db_model(table_name: str, table_type: TableType):
id: uuid.UUID
user_id: str
text: str
doc_id: str
file_id: str
agent_id: str
data_source: str
embedding: Vector(config.default_embedding_config.embedding_dim)
@@ -37,7 +37,7 @@ def get_db_model(table_name: str, table_type: TableType):
return Passage(
text=self.text,
embedding=self.embedding,
doc_id=self.doc_id,
file_id=self.file_id,
user_id=self.user_id,
id=self.id,
data_source=self.data_source,

View File

@@ -26,7 +26,7 @@ class MilvusStorageConnector(StorageConnector):
raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
# need to be converted to strings
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
def _create_collection(self):
schema = MilvusClient.create_schema(

View File

@@ -38,7 +38,7 @@ class QdrantStorageConnector(StorageConnector):
distance=models.Distance.COSINE,
),
)
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
from qdrant_client import grpc

View File

@@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel
from letta.config import LettaConfig
from letta.schemas.document import Document
from letta.schemas.file import FileMetadata
from letta.schemas.message import Message
from letta.schemas.passage import Passage
from letta.utils import printd
@@ -22,7 +22,7 @@ class TableType:
ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id}
RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id}
PASSAGES = "passages" # TODO
DOCUMENTS = "documents" # TODO
FILES = "files"
# table names used by Letta
@@ -33,17 +33,17 @@ ARCHIVAL_TABLE_NAME = "letta_archival_memory_agent" # agent memory
# external data source tables
PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source)
DOCUMENT_TABLE_NAME = "letta_documents" # original documents (from source)
FILE_TABLE_NAME = "letta_files" # original files (from source)
class StorageConnector:
"""Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
"""Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory"""
type: Type[BaseModel]
def __init__(
self,
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
config: LettaConfig,
user_id,
agent_id=None,
@@ -59,9 +59,9 @@ class StorageConnector:
elif table_type == TableType.RECALL_MEMORY:
self.type = Message
self.table_name = RECALL_TABLE_NAME
elif table_type == TableType.DOCUMENTS:
self.type = Document
self.table_name == DOCUMENT_TABLE_NAME
elif table_type == TableType.FILES:
self.type = FileMetadata
self.table_name = FILE_TABLE_NAME
elif table_type == TableType.PASSAGES:
self.type = Passage
self.table_name = PASSAGE_TABLE_NAME
@@ -74,7 +74,7 @@ class StorageConnector:
# agent-specific table
assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS:
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
# setup base filters for user-specific tables
assert agent_id is None, "Agent ID must not be provided for user-specific tables"
self.filters = {"user_id": self.user_id}
@@ -83,7 +83,7 @@ class StorageConnector:
@staticmethod
def get_storage_connector(
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
config: LettaConfig,
user_id,
agent_id=None,
@@ -92,6 +92,8 @@ class StorageConnector:
storage_type = config.archival_storage_type
elif table_type == TableType.RECALL_MEMORY:
storage_type = config.recall_storage_type
elif table_type == TableType.FILES:
storage_type = config.metadata_storage_type
else:
raise ValueError(f"Table type {table_type} not implemented")

View File

@@ -106,7 +106,7 @@ def load_vector_database(
# document_store=None,
# passage_store=passage_storage,
# )
# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
# print(f"Loaded {num_passages} passages and {num_documents} files from {name}")
# except Exception as e:
# typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
# ms.delete_source(source_id=source.id)

View File

@@ -25,6 +25,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
# new schemas
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_request import LettaRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
@@ -232,6 +233,9 @@ class AbstractClient(object):
def list_attached_sources(self, agent_id: str) -> List[Source]:
    """Abstract: list the sources attached to an agent. Concrete clients must override."""
    raise NotImplementedError
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
    """Abstract: list file metadata for a source, paginated. Concrete clients must override."""
    raise NotImplementedError
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
raise NotImplementedError
@@ -1016,6 +1020,12 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get job: {response.text}")
return Job(**response.json())
def delete_job(self, job_id: str) -> Job:
    """Delete a job by ID via the REST API and return the deleted job record."""
    url = f"{self.base_url}/{self.api_prefix}/jobs/{job_id}"
    response = requests.delete(url, headers=self.headers)
    if response.status_code != 200:
        raise ValueError(f"Failed to delete job: {response.text}")
    return Job(**response.json())
def list_jobs(self):
    """Return all jobs visible to this client as Job objects."""
    response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
    jobs = []
    for job_payload in response.json():
        jobs.append(Job(**job_payload))
    return jobs
@@ -1088,6 +1098,30 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to list attached sources: {response.text}")
return [Source(**source) for source in response.json()]
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
    """
    List files from source with pagination support.

    Args:
        source_id (str): ID of the source
        limit (int): Number of files to return
        cursor (Optional[str]): Pagination cursor for fetching the next page

    Returns:
        List[FileMetadata]: List of files
    """
    # Pagination is handled server-side; forward limit/cursor as query params.
    params = {"limit": limit, "cursor": cursor}
    response = requests.get(
        f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files",
        headers=self.headers,
        params=params,
    )
    if response.status_code != 200:
        raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}")
    # Each element of the JSON body is one FileMetadata record.
    return [FileMetadata(**metadata) for metadata in response.json()]
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
"""
Update a source
@@ -2162,6 +2196,9 @@ class LocalClient(AbstractClient):
def get_job(self, job_id: str):
    """Fetch a job by ID from the local server."""
    return self.server.get_job(job_id=job_id)
def delete_job(self, job_id: str):
    """Delete a job by ID on the local server."""
    return self.server.delete_job(job_id)
def list_jobs(self):
    """List jobs belonging to the current user."""
    return self.server.list_jobs(user_id=self.user_id)
@@ -2261,6 +2298,20 @@ class LocalClient(AbstractClient):
"""
return self.server.list_attached_sources(agent_id=agent_id)
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
    """
    List metadata for the files loaded into a source, one page at a time.

    Args:
        source_id (str): ID of the source
        limit (int): The # of items to return
        cursor (str): The cursor (last file ID of the previous page) for fetching the next page

    Returns:
        files (List[FileMetadata]): List of files
    """
    return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
"""
Update a source

View File

@@ -1,11 +1,15 @@
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Tuple
import typer
from llama_index.core import Document as LlamaIndexDocument
from letta.agent_store.storage import StorageConnector
from letta.data_sources.connectors_helper import (
assert_all_files_exist_locally,
extract_metadata_from_files,
get_filenames_in_dir,
)
from letta.embeddings import embedding_model
from letta.schemas.document import Document
from letta.schemas.file import FileMetadata
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.utils import create_uuid_from_string
@@ -13,23 +17,23 @@ from letta.utils import create_uuid_from_string
class DataConnector:
"""
Base class for data connectors that can be extended to generate documents and passages from a custom data source.
Base class for data connectors that can be extended to generate files and passages from a custom data source.
"""
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
def find_files(self, source: Source) -> Iterator[FileMetadata]:
"""
Generate document text and metadata from a data source.
Generate file metadata from a data source.
Returns:
documents (Iterator[Tuple[str, Dict]]): Generate a tuple of string text and metadata dictionary for each document.
files (Iterator[FileMetadata]): Generate file metadata for each file found.
"""
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
"""
Generate passage text and metadata from a list of documents.
Generate passage text and metadata from a list of files.
Args:
documents (List[Document]): List of documents to generate passages from.
file (FileMetadata): The document to generate passages from.
chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024.
Returns:
@@ -41,33 +45,25 @@ def load_data(
connector: DataConnector,
source: Source,
passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None,
file_metadata_store: StorageConnector,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
"""Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id."""
embedding_config = source.embedding_config
# embedding model
embed_model = embedding_model(embedding_config)
# insert passages/documents
# insert passages/file
passages = []
embedding_to_document_name = {}
passage_count = 0
document_count = 0
for document_text, document_metadata in connector.generate_documents():
# insert document into storage
document = Document(
text=document_text,
metadata_=document_metadata,
source_id=source.id,
user_id=source.user_id,
)
document_count += 1
if document_store:
document_store.insert(document)
file_count = 0
for file_metadata in connector.find_files(source):
file_count += 1
file_metadata_store.insert(file_metadata)
# generate passages
for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
# for some reason, llama index parsers sometimes return empty strings
if len(passage_text) == 0:
typer.secho(
@@ -89,7 +85,7 @@ def load_data(
passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
text=passage_text,
doc_id=document.id,
file_id=file_metadata.id,
source_id=source.id,
metadata_=passage_metadata,
user_id=source.user_id,
@@ -98,16 +94,16 @@ def load_data(
)
hashable_embedding = tuple(passage.embedding)
document_name = document.metadata_.get("file_path", document.id)
file_name = file_metadata.file_name
if hashable_embedding in embedding_to_document_name:
typer.secho(
f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
fg=typer.colors.YELLOW,
)
continue
passages.append(passage)
embedding_to_document_name[hashable_embedding] = document_name
embedding_to_document_name[hashable_embedding] = file_name
if len(passages) >= 100:
# insert passages into passage store
passage_store.insert_many(passages)
@@ -120,7 +116,7 @@ def load_data(
passage_store.insert_many(passages)
passage_count += len(passages)
return passage_count, document_count
return passage_count, file_count
class DirectoryConnector(DataConnector):
@@ -143,105 +139,109 @@ class DirectoryConnector(DataConnector):
if self.recursive == True:
assert self.input_directory is not None, "Must provide input directory if recursive is True."
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
from llama_index.core import SimpleDirectoryReader
def find_files(self, source: Source) -> Iterator[FileMetadata]:
if self.input_directory is not None:
reader = SimpleDirectoryReader(
files = get_filenames_in_dir(
input_dir=self.input_directory,
recursive=self.recursive,
required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
exclude=["*png", "*jpg", "*jpeg"],
)
else:
assert self.input_files is not None, "Must provide input files if input_dir is None"
reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files])
files = self.input_files
llama_index_docs = reader.load_data(show_progress=True)
for llama_index_doc in llama_index_docs:
# TODO: add additional metadata?
# doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata)
# docs.append(doc)
yield llama_index_doc.text, llama_index_doc.metadata
# Check that file paths are valid
assert_all_files_exist_locally(files)
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
# use llama index to run embeddings code
# from llama_index.core.node_parser import SentenceSplitter
for metadata in extract_metadata_from_files(files):
yield FileMetadata(
user_id=source.user_id,
source_id=source.id,
file_name=metadata.get("file_name"),
file_path=metadata.get("file_path"),
file_type=metadata.get("file_type"),
file_size=metadata.get("file_size"),
file_creation_date=metadata.get("file_creation_date"),
file_last_modified_date=metadata.get("file_last_modified_date"),
)
def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import TokenTextSplitter
parser = TokenTextSplitter(chunk_size=chunk_size)
for document in documents:
llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)]
nodes = parser.get_nodes_from_documents(llama_index_docs)
documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data()
nodes = parser.get_nodes_from_documents(documents)
for node in nodes:
# passage = Passage(
# text=node.text,
# doc_id=document.id,
# )
yield node.text, None
class WebConnector(DirectoryConnector):
def __init__(self, urls: List[str] = None, html_to_text: bool = True):
self.urls = urls
self.html_to_text = html_to_text
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
from llama_index.readers.web import SimpleWebPageReader
documents = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
for document in documents:
yield document.text, {"url": document.id_}
class VectorDBConnector(DataConnector):
# NOTE: this class has not been properly tested, so is unlikely to work
# TODO: allow loading multiple tables (1:1 mapping between Document and Table)
def __init__(
self,
name: str,
uri: str,
table_name: str,
text_column: str,
embedding_column: str,
embedding_dim: int,
):
self.name = name
self.uri = uri
self.table_name = table_name
self.text_column = text_column
self.embedding_column = embedding_column
self.embedding_dim = embedding_dim
# connect to db table
from sqlalchemy import create_engine
self.engine = create_engine(uri)
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
yield self.table_name, None
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
from pgvector.sqlalchemy import Vector
from sqlalchemy import Inspector, MetaData, Table, select
metadata = MetaData()
# Create an inspector to inspect the database
inspector = Inspector.from_engine(self.engine)
table_names = inspector.get_table_names()
assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
table = Table(self.table_name, metadata, autoload_with=self.engine)
# Prepare a select statement
select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
# Execute the query and fetch the results
# TODO: paginate results
with self.engine.connect() as connection:
result = connection.execute(select_statement).fetchall()
for text, embedding in result:
# assume that embeddings are the same model as in config
# TODO: don't re-compute embedding
yield text, {"embedding": embedding}
"""
The below isn't used anywhere, it isn't tested, and pretty much should be deleted.
- Matt
"""
# class WebConnector(DirectoryConnector):
# def __init__(self, urls: List[str] = None, html_to_text: bool = True):
# self.urls = urls
# self.html_to_text = html_to_text
#
# def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
# from llama_index.readers.web import SimpleWebPageReader
#
# files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
# for document in files:
# yield document.text, {"url": document.id_}
#
#
# class VectorDBConnector(DataConnector):
# # NOTE: this class has not been properly tested, so is unlikely to work
# # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table)
#
# def __init__(
# self,
# name: str,
# uri: str,
# table_name: str,
# text_column: str,
# embedding_column: str,
# embedding_dim: int,
# ):
# self.name = name
# self.uri = uri
# self.table_name = table_name
# self.text_column = text_column
# self.embedding_column = embedding_column
# self.embedding_dim = embedding_dim
#
# # connect to db table
# from sqlalchemy import create_engine
#
# self.engine = create_engine(uri)
#
# def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
# yield self.table_name, None
#
# def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
# from pgvector.sqlalchemy import Vector
# from sqlalchemy import Inspector, MetaData, Table, select
#
# metadata = MetaData()
# # Create an inspector to inspect the database
# inspector = Inspector.from_engine(self.engine)
# table_names = inspector.get_table_names()
# assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
#
# table = Table(self.table_name, metadata, autoload_with=self.engine)
#
# # Prepare a select statement
# select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
#
# # Execute the query and fetch the results
# # TODO: paginate results
# with self.engine.connect() as connection:
# result = connection.execute(select_statement).fetchall()
#
# for text, embedding in result:
# # assume that embeddings are the same model as in config
# # TODO: don't re-compute embedding
# yield text, {"embedding": embedding}

View File

@@ -0,0 +1,97 @@
import mimetypes
import os
from datetime import datetime
from pathlib import Path
from typing import List, Optional
def extract_file_metadata(file_path) -> dict:
    """Build a metadata dict (name, path, MIME type, size, dates) for one file.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(file_path)

    created = datetime.fromtimestamp(os.path.getctime(file_path))
    modified = datetime.fromtimestamp(os.path.getmtime(file_path))
    return {
        "file_name": os.path.basename(file_path),
        "file_path": file_path,
        # guess_type returns (type, encoding); fall back to "unknown" when unguessable
        "file_type": mimetypes.guess_type(file_path)[0] or "unknown",
        "file_size": os.path.getsize(file_path),
        "file_creation_date": created.strftime("%Y-%m-%d"),
        "file_last_modified_date": modified.strftime("%Y-%m-%d"),
    }
def extract_metadata_from_files(file_list):
    """Extract metadata dicts for every file in *file_list*.

    Falsy results are dropped (defensive: extract_file_metadata currently
    always returns a non-empty dict or raises).
    """
    return [meta for meta in (extract_file_metadata(path) for path in file_list) if meta]
def get_filenames_in_dir(
    input_dir: str, recursive: bool = True, required_exts: Optional[List[str]] = None, exclude: Optional[List[str]] = None
):
    """
    Collect file paths under *input_dir*, honoring extension and exclusion filters.

    Args:
        input_dir (str): The directory to scan for files.
        recursive (bool): Whether to scan directories recursively.
        required_exts (list): File extensions (without dot) to include; empty/None matches all.
        exclude (list): Glob-style patterns of file names to skip (e.g. ['*png', '*jpg']).

    Returns:
        list: Matching file paths (pathlib.Path objects).

    Raises:
        ValueError: if required_exts and exclude share an entry.
    """
    required_exts = required_exts or []
    exclude = exclude or []

    # Reject filter sets that contradict each other.
    overlap = set(required_exts) & set(exclude)
    if overlap:
        raise ValueError(f"Extensions in required_exts and exclude overlap: {overlap}")

    def _is_excluded(file_name):
        # Glob-match the bare file name against every exclusion pattern.
        return any(Path(file_name).match(pattern) for pattern in exclude)

    search_pattern = "**/*" if recursive else "*"
    matched = []
    for candidate in Path(input_dir).glob(search_pattern):
        if not candidate.is_file() or _is_excluded(candidate.name):
            continue
        suffix = candidate.suffix.lstrip(".")
        # Empty required_exts means "accept any extension".
        if not required_exts or suffix in required_exts:
            matched.append(candidate)
    return matched
def assert_all_files_exist_locally(file_paths: List[str]) -> bool:
    """
    Verify that every path in *file_paths* exists on the local filesystem.

    Args:
        file_paths (List[str]): List of file paths to check.

    Returns:
        bool: True if all files exist.

    Raises:
        FileNotFoundError: carrying the list of missing paths, if any are absent.
    """
    missing = []
    for candidate in file_paths:
        if not Path(candidate).exists():
            missing.append(candidate)
    if missing:
        raise FileNotFoundError(missing)
    return True

View File

@@ -11,6 +11,7 @@ from sqlalchemy import (
Column,
DateTime,
Index,
Integer,
String,
TypeDecorator,
desc,
@@ -24,6 +25,7 @@ from letta.schemas.api_key import APIKey
from letta.schemas.block import Block, Human, Persona
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
@@ -38,6 +40,41 @@ from letta.settings import settings
from letta.utils import enforce_types, get_utc_time, printd
class FileMetadataModel(Base):
    """SQLAlchemy model for the `files` table: one row of per-file metadata belonging to a source."""

    __tablename__ = "files"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True, nullable=False)
    user_id = Column(String, nullable=False)
    # TODO: Investigate why this breaks during table creation due to FK
    # source_id = Column(String, ForeignKey("sources.id"), nullable=False)
    source_id = Column(String, nullable=False)
    file_name = Column(String, nullable=True)
    file_path = Column(String, nullable=True)
    file_type = Column(String, nullable=True)  # MIME type (may be "unknown" when unguessable)
    file_size = Column(Integer, nullable=True)  # size in bytes
    # NOTE(review): creation/modification dates are stored as formatted strings,
    # not DateTime columns — confirm this is intentional.
    file_creation_date = Column(String, nullable=True)
    file_last_modified_date = Column(String, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        return f"<FileMetadata(id='{self.id}', source_id='{self.source_id}', file_name='{self.file_name}')>"

    def to_record(self):
        # Convert this ORM row into the FileMetadata schema object.
        return FileMetadata(
            id=self.id,
            user_id=self.user_id,
            source_id=self.source_id,
            file_name=self.file_name,
            file_path=self.file_path,
            file_type=self.file_type,
            file_size=self.file_size,
            file_creation_date=self.file_creation_date,
            file_last_modified_date=self.file_last_modified_date,
            created_at=self.created_at,
        )
class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON"""
@@ -865,6 +902,27 @@ class MetadataStore:
session.add(JobModel(**vars(job)))
session.commit()
@enforce_types
def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]):
    """Return up to `limit` FileMetadata records for a source, paginated by file ID.

    Args:
        source_id: ID of the source whose files are listed.
        limit: Maximum number of records to return.
        cursor: ID of the last file from the previous page; only rows with a
            strictly greater ID are returned (assumes IDs sort in a stable
            pagination order — TODO confirm ID scheme is monotonic).

    Returns:
        List of FileMetadata schema objects.
    """
    with self.session_maker() as session:
        # Start with the basic query filtered by source_id
        query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id)

        if cursor:
            # Assuming cursor is the ID of the last file in the previous page
            query = query.filter(FileMetadataModel.id > cursor)

        # Order by ID or other ordering criteria to ensure correct pagination
        query = query.order_by(FileMetadataModel.id)

        # Limit the number of results returned
        results = query.limit(limit).all()

        # Convert the results to the required FileMetadata objects
        files = [r.to_record() for r in results]
        return files
def delete_job(self, job_id: str):
    """Delete the job row with the given ID (no-op if it does not exist)."""
    with self.session_maker() as session:
        session.query(JobModel).filter(JobModel.id == job_id).delete()
        # Bug fix: commit the deletion explicitly. Sibling write paths
        # (e.g. add_job above) call session.commit(); without it the delete
        # is rolled back when the session context closes.
        session.commit()

View File

@@ -1,21 +0,0 @@
from typing import Dict, Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class DocumentBase(LettaBase):
    """Base class for document schemas"""

    # Prefix used by generate_id_field() when minting document IDs.
    __id_prefix__ = "doc"
class Document(DocumentBase):
    """Representation of a single document (broken up into `Passage` objects)"""

    id: str = DocumentBase.generate_id_field()
    text: str = Field(..., description="The text of the document.")
    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
    user_id: str = Field(description="The unique identifier of the user associated with the document.")
    metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")

31
letta/schemas/file.py Normal file
View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
from letta.utils import get_utc_time
class FileMetadataBase(LettaBase):
    """Base class for FileMetadata schemas"""

    # Prefix used by generate_id_field() when minting file IDs.
    __id_prefix__ = "file"
class FileMetadata(FileMetadataBase):
    """Representation of a single FileMetadata"""

    id: str = FileMetadataBase.generate_id_field()
    # NOTE(review): no `...` marker, so user_id is not required here even though
    # the docstring-style description implies it always exists — confirm intent.
    user_id: str = Field(description="The unique identifier of the user associated with the document.")
    # Required link back to the owning source.
    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
    # All per-file attributes are optional: connectors may not know every detail.
    file_name: Optional[str] = Field(None, description="The name of the file.")
    file_path: Optional[str] = Field(None, description="The path to the file.")
    file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
    file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
    # NOTE(review): the two file dates are plain strings while created_at below is
    # a datetime — callers must not assume a parseable/uniform date format here.
    file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
    file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
    # Timestamp of this metadata record itself (UTC), not of the underlying file.
    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")

    class Config:
        # Permit unknown extra fields so connector-specific attributes survive round-trips.
        extra = "allow"

View File

@@ -15,7 +15,7 @@ class JobBase(LettaBase):
class Job(JobBase):
"""
Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding documents).
Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding files).
Parameters:
id (str): The unique identifier of the job.

View File

@@ -19,8 +19,8 @@ class PassageBase(LettaBase):
# origin data source
source_id: Optional[str] = Field(None, description="The data source of the passage.")
# document association
doc_id: Optional[str] = Field(None, description="The unique identifier of the document associated with the passage.")
# file association
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
@@ -36,7 +36,7 @@ class Passage(PassageBase):
user_id (str): The unique identifier of the user associated with the passage.
agent_id (str): The unique identifier of the agent associated with the passage.
source_id (str): The data source of the passage.
doc_id (str): The unique identifier of the document associated with the passage.
file_id (str): The unique identifier of the file associated with the passage.
"""
id: str = PassageBase.generate_id_field()

View File

@@ -28,7 +28,7 @@ class SourceCreate(BaseSource):
class Source(BaseSource):
"""
Representation of a source, which is a collection of documents and passages.
Representation of a source, which is a collection of files and passages.
Parameters:
id (str): The ID of the source
@@ -59,4 +59,4 @@ class UploadFileToSourceRequest(BaseModel):
class UploadFileToSourceResponse(BaseModel):
source: Source = Field(..., description="The source the file was uploaded to.")
added_passages: int = Field(..., description="The number of passages added to the source.")
added_documents: int = Field(..., description="The number of documents added to the source.")
added_documents: int = Field(..., description="The number of files added to the source.")

View File

@@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, Query
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from letta.schemas.job import Job
from letta.server.rest_api.utils import get_letta_server
@@ -54,3 +54,19 @@ def get_job(
"""
return server.get_job(job_id=job_id)
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
def delete_job(
    job_id: str,
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Delete a job by its job_id.
    """
    # Look the job up first so we can 404 on unknown ids and echo the record back.
    existing_job = server.get_job(job_id=job_id)
    if not existing_job:
        raise HTTPException(status_code=404, detail="Job not found")
    server.delete_job(job_id=job_id)
    # Return the (now-deleted) job so the caller can see what was removed.
    return existing_job

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
from letta.schemas.document import Document
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate
@@ -186,19 +186,17 @@ def list_passages(
return passages
@router.get("/{source_id}/documents", response_model=List[Document], operation_id="list_source_documents")
def list_documents(
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_files_from_source")
def list_files_from_source(
source_id: str,
limit: int = Query(1000, description="Number of files to return"),
cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all documents associated with a data source.
List paginated files associated with a data source.
"""
actor = server.get_user_or_default(user_id=user_id)
documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id)
return documents
return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):

View File

@@ -63,11 +63,11 @@ from letta.schemas.block import (
CreatePersona,
UpdateBlock,
)
from letta.schemas.document import Document
from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
@@ -1596,7 +1596,7 @@ class SyncServer(Server):
# job.status = JobStatus.failed
# job.metadata_["error"] = error
# self.ms.update_job(job)
# # TODO: delete any associated passages/documents?
# # TODO: delete any associated passages/files?
# # return failed job
# return job
@@ -1625,11 +1625,10 @@ class SyncServer(Server):
# get the data connectors
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
# TODO: add document store support
document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
file_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
# load data into the document store
passage_count, document_count = load_data(connector, source, passage_store, document_store)
passage_count, document_count = load_data(connector, source, passage_store, file_store)
return passage_count, document_count
def attach_source_to_agent(
@@ -1686,14 +1685,14 @@ class SyncServer(Server):
# list all attached sources to an agent
return self.ms.list_attached_sources(agent_id)
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
    # Paginated listing of the files attached to a data source; `cursor` is the
    # id of the last file from the previous page (delegates to the metadata store).
    return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
    # Stub: passage listing is not implemented yet, so warn and return an empty
    # list as a placeholder rather than raising.
    warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
    return []
def list_data_source_documents(self, user_id: str, source_id: str) -> List[Document]:
    # Stub: document listing is not implemented yet, so warn and return an empty
    # list as a placeholder rather than raising.
    warnings.warn("list_data_source_documents is not yet implemented, returning empty list.", category=UserWarning)
    return []
def list_all_sources(self, user_id: str) -> List[Source]:
"""List all sources (w/ extra metadata) belonging to a user"""
@@ -1707,9 +1706,9 @@ class SyncServer(Server):
passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
num_passages = passage_conn.size({"source_id": source.id})
# TODO: add when documents table implemented
## count number of documents
# document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
# TODO: add when files table implemented
## count number of files
# document_conn = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
# num_documents = document_conn.size({"data_source": source.name})
num_documents = 0

1
tests/data/test.txt Normal file
View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,34 @@
import time
from typing import Union
from letta import LocalClient, RESTClient
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.schemas.source import Source
def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job:
    """Upload ``filename`` into ``source`` and block until the job completes.

    Starts a non-blocking upload job, sanity-checks that it appears in the job
    listings, then polls until the job reaches a terminal state.

    Returns:
        The upload Job (as initially returned by the client).

    Raises:
        ValueError: if the job fails, or does not finish within the timeout.
    """
    # load a file into a source (non-blocking job)
    upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
    print("Upload job", upload_job, upload_job.status, upload_job.metadata_)

    # the job should be visible in the full listing and be the only active job
    active_jobs = client.list_active_jobs()
    jobs = client.list_jobs()
    assert upload_job.id in [j.id for j in jobs]
    assert len(active_jobs) == 1
    assert active_jobs[0].metadata_["source_id"] == source.id

    # wait for job to finish (with timeout)
    timeout = 120
    start_time = time.time()
    while True:
        status = client.get_job(upload_job.id).status
        print(f"\r{status}", end="", flush=True)
        if status == JobStatus.completed:
            break
        # Fail fast on a failed job instead of spinning until the timeout expires.
        if status == JobStatus.failed:
            raise ValueError(f"Job {upload_job.id} failed")
        time.sleep(1)
        if time.time() - start_time > timeout:
            raise ValueError("Job did not finish in time")
    return upload_job

View File

@@ -12,12 +12,13 @@ from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import FunctionCallMessage, InternalMonologue
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -298,6 +299,70 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# print("CONFIG", config_response)
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Cursor-based pagination over a source's files returns each file exactly once."""
    # start from a clean slate: no sources, no jobs
    for existing_source in client.list_sources():
        client.delete_source(existing_source.id)
    for existing_job in client.list_jobs():
        client.delete_job(existing_job.id)

    # create a source
    source = client.create_source(name="test_source")

    # upload two distinct files so there are exactly two pages of size one
    for path in ("tests/data/memgpt_paper.pdf", "tests/data/test.txt"):
        upload_file_using_client(client, source, path)

    # page 1: first file only
    first_page = client.list_files_from_source(source.id, limit=1)
    assert len(first_page) == 1
    assert first_page[0].source_id == source.id

    # page 2: resume from the cursor produced by page 1
    second_page = client.list_files_from_source(source.id, limit=1, cursor=first_page[-1].id)
    assert len(second_page) == 1
    assert second_page[0].source_id == source.id

    # the two pages must contain different files, proving the cursor advanced
    assert first_page[0].file_name != second_page[0].file_name

    # past the last file the listing is exhausted
    third_page = client.list_files_from_source(source.id, limit=1, cursor=second_page[-1].id)
    assert len(third_page) == 0  # Should be empty
def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Loading a single PDF into a source produces exactly one file record."""
    # remove any leftover sources and jobs from earlier tests
    for existing_source in client.list_sources():
        client.delete_source(existing_source.id)
    for existing_job in client.list_jobs():
        client.delete_job(existing_job.id)

    # create a source
    source = client.create_source(name="test_source")

    # load a file into the source (non-blocking job, waited on by the helper)
    upload_file_using_client(client, source, "tests/data/memgpt_paper.pdf")

    # the upload should be condensed into a single file record
    files = client.list_files_from_source(source.id)
    assert len(files) == 1

    # verify the record describes the memgpt paper
    uploaded = files[0]
    assert uploaded.file_name == "memgpt_paper.pdf"
    assert uploaded.source_id == source.id
def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
@@ -305,6 +370,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
for source in client.list_sources():
client.delete_source(source.id)
# clear jobs
for job in client.list_jobs():
client.delete_job(job.id)
# list sources
sources = client.list_sources()
print("listed sources", sources)
@@ -343,28 +412,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# load a file into a source (non-blocking job)
filename = "tests/data/memgpt_paper.pdf"
upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
# view active jobs
active_jobs = client.list_active_jobs()
jobs = client.list_jobs()
print(jobs)
assert upload_job.id in [j.id for j in jobs]
assert len(active_jobs) == 1
assert active_jobs[0].metadata_["source_id"] == source.id
# wait for job to finish (with timeout)
timeout = 120
start_time = time.time()
while True:
status = client.get_job(upload_job.id).status
print(status)
if status == JobStatus.completed:
break
time.sleep(1)
if time.time() - start_time > timeout:
raise ValueError("Job did not finish in time")
upload_job = upload_file_using_client(client, source, filename)
job = client.get_job(upload_job.id)
created_passages = job.metadata_["num_passages"]

View File

@@ -1,5 +1,6 @@
import datetime
import os
from datetime import datetime
from importlib import util
from typing import Dict, Iterator, List, Tuple
@@ -7,7 +8,7 @@ import requests
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector
from letta.schemas.document import Document
from letta.schemas.file import FileMetadata
from letta.settings import TestSettings
from .constants import TIMEOUT
@@ -18,14 +19,27 @@ class DummyDataConnector(DataConnector):
def __init__(self, texts: List[str]):
    # Raw text snippets this dummy connector exposes as "files".
    self.texts = texts
    # Maps generated FileMetadata ids back to their text so generate_passages
    # can recover the content for each file.
    self.file_to_text = {}
def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
def find_files(self, source) -> Iterator[FileMetadata]:
for text in self.texts:
yield text, {"metadata": "dummy"}
file_metadata = FileMetadata(
user_id="",
source_id="",
file_name="",
file_path="",
file_type="",
file_size=0, # Set to 0 as a placeholder
file_creation_date="1970-01-01", # Placeholder date
file_last_modified_date="1970-01-01", # Placeholder date
created_at=datetime.utcnow(),
)
self.file_to_text[file_metadata.id] = text
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]:
for doc in documents:
yield doc.text, doc.metadata_
yield file_metadata
def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]:
yield self.file_to_text[file.id], {}
def wipe_config():