# letta-server/letta/data_sources/connectors.py

from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple

if TYPE_CHECKING:
    from letta.schemas.user import User

import typer

from letta.constants import EMBEDDING_BATCH_SIZE
from letta.data_sources.connectors_helper import assert_all_files_exist_locally, extract_metadata_from_files, get_filenames_in_dir
from letta.schemas.file import FileMetadata
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.services.file_manager import FileManager
from letta.services.passage_manager import PassageManager


class DataConnector:
    """
    Base class for data connectors that can be extended to generate files and passages from a custom data source.
    """

    def find_files(self, source: Source) -> Iterator[FileMetadata]:
        """
        Generate file metadata from a data source.

        Returns:
            files (Iterator[FileMetadata]): Yields file metadata for each file found.
        """

    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        """
        Generate passage text and metadata from a file.

        Args:
            file (FileMetadata): The document to generate passages from.
            chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024.

        Returns:
            passages (Iterator[Tuple[str, Dict]]): Yields a tuple of passage text and metadata dictionary for each passage.
        """


async def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, file_manager: FileManager, actor: "User"):
    """Load data from a connector (generates files and passages) into a specified source_id, associated with a user_id."""
    from letta.llm_api.llm_client import LLMClient
    embedding_config = source.embedding_config

    # insert passages/file
    embedding_to_document_name = {}
    passage_count = 0
    file_count = 0

    # Use the new LLMClient for all embedding requests
    client = LLMClient.create(
        provider_type=embedding_config.embedding_endpoint_type,
        actor=actor,
    )

    for file_metadata in connector.find_files(source):
        file_count += 1
        await file_manager.create_file(file_metadata, actor)

        # generate passages for this file
        texts = []
        metadatas = []
        for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
            # for some reason, llama index parsers sometimes return empty strings
            if len(passage_text) == 0:
                typer.secho(
                    f"Warning: Llama index parser returned empty string, skipping insert of passage with metadata '{passage_metadata}' into VectorDB. You can usually ignore this warning.",
                    fg=typer.colors.YELLOW,
                )
                continue
            texts.append(passage_text)
            metadatas.append(passage_metadata)

            if len(texts) >= EMBEDDING_BATCH_SIZE:
                # Process the batch
                embeddings = await client.request_embeddings(texts, embedding_config)
                passages = []
                for text, embedding, passage_metadata in zip(texts, embeddings, metadatas):
                    passage = Passage(
                        text=text,
                        file_id=file_metadata.id,
                        source_id=source.id,
                        metadata=passage_metadata,
                        organization_id=source.organization_id,
                        embedding_config=source.embedding_config,
                        embedding=embedding,
                    )
                    hashable_embedding = tuple(passage.embedding)
                    file_name = file_metadata.file_name
                    if hashable_embedding in embedding_to_document_name:
                        typer.secho(
                            f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
                            fg=typer.colors.YELLOW,
                        )
                        continue
                    passages.append(passage)
                    embedding_to_document_name[hashable_embedding] = file_name

                # insert passages into passage store
                await passage_manager.create_many_passages_async(passages, actor)
                passage_count += len(passages)

                # Reset for next batch
                texts = []
                metadatas = []

        # Process final remaining texts for this file
        if len(texts) > 0:
            embeddings = await client.request_embeddings(texts, embedding_config)
            passages = []
            for text, embedding, passage_metadata in zip(texts, embeddings, metadatas):
                passage = Passage(
                    text=text,
                    file_id=file_metadata.id,
                    source_id=source.id,
                    metadata=passage_metadata,
                    organization_id=source.organization_id,
                    embedding_config=source.embedding_config,
                    embedding=embedding,
                )
                hashable_embedding = tuple(passage.embedding)
                file_name = file_metadata.file_name
                if hashable_embedding in embedding_to_document_name:
                    typer.secho(
                        f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
                        fg=typer.colors.YELLOW,
                    )
                    continue
                passages.append(passage)
                embedding_to_document_name[hashable_embedding] = file_name

            await passage_manager.create_many_passages_async(passages, actor)
            passage_count += len(passages)

    return passage_count, file_count
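

# A hedged usage sketch of load_data. The wiring below is illustrative: whether
# PassageManager and FileManager are constructible without arguments, and where
# `source` and `actor` come from, depends on the host application:
#
#     connector = DirectoryConnector(input_directory="/tmp/docs", recursive=True)
#     passage_count, file_count = await load_data(
#         connector=connector,
#         source=source,                     # an existing letta Source
#         passage_manager=PassageManager(),  # assumed no-arg construction
#         file_manager=FileManager(),        # assumed no-arg construction
#         actor=actor,                       # the requesting User
#     )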


class DirectoryConnector(DataConnector):
    def __init__(
        self,
        input_files: List[str] | None = None,
        input_directory: str | None = None,
        recursive: bool = False,
        extensions: List[str] | None = None,
    ):
        """
        Connector for reading text data from a directory of files.

        Args:
            input_files (List[str], optional): List of file paths to read. Defaults to None.
            input_directory (str, optional): Directory to read files from. Defaults to None.
            recursive (bool, optional): Whether to read files recursively from the input directory. Defaults to False.
            extensions (List[str], optional): List of file extensions to read. Defaults to None.
        """
        self.connector_type = "directory"
        self.input_files = input_files
        self.input_directory = input_directory
        self.recursive = recursive
        self.extensions = extensions

        if self.recursive:
            assert self.input_directory is not None, "Must provide input directory if recursive is True."

    def find_files(self, source: Source) -> Iterator[FileMetadata]:
        if self.input_directory is not None:
            files = get_filenames_in_dir(
                input_dir=self.input_directory,
                recursive=self.recursive,
                # extensions is already a list of strings; stringifying the list and
                # splitting on commas would mangle it (and turn None into ["None"])
                required_exts=[ext.strip() for ext in self.extensions] if self.extensions else None,
                exclude=["*png", "*jpg", "*jpeg"],
            )
        else:
            files = self.input_files

        # Check that file paths are valid
        assert_all_files_exist_locally(files)

        for metadata in extract_metadata_from_files(files):
            yield FileMetadata(
                source_id=source.id,
                file_name=metadata.get("file_name"),
                file_path=metadata.get("file_path"),
                file_type=metadata.get("file_type"),
                file_size=metadata.get("file_size"),
                file_creation_date=metadata.get("file_creation_date"),
                file_last_modified_date=metadata.get("file_last_modified_date"),
            )

    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        from llama_index.core import SimpleDirectoryReader
        from llama_index.core.node_parser import TokenTextSplitter

        parser = TokenTextSplitter(chunk_size=chunk_size)
        if file.file_type == "application/pdf":
            from llama_index.readers.file import PDFReader

            reader = PDFReader()
            documents = reader.load_data(file=file.file_path)
        else:
            documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data()

        nodes = parser.get_nodes_from_documents(documents)
        for node in nodes:
            # yield an empty dict (not None) to match the declared Iterator[Tuple[str, Dict]] return type
            yield node.text, {}
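

# A hedged example of driving DirectoryConnector directly, outside load_data;
# the directory path and extension list are illustrative, and `source` is
# assumed to be an existing Source:
#
#     connector = DirectoryConnector(input_directory="./docs", recursive=True, extensions=[".md", ".txt"])
#     for file_metadata in connector.find_files(source):
#         for text, metadata in connector.generate_passages(file_metadata, chunk_size=512):
#             print(file_metadata.file_name, len(text), metadata)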