feat: Enable adding files (#1864)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -2,6 +2,9 @@
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
|
||||
|
||||
openapi_letta.json
|
||||
openapi_openai.json
|
||||
|
||||
### Eclipse ###
|
||||
.metadata
|
||||
bin/
|
||||
|
||||
@@ -70,7 +70,7 @@ schema_models = [
|
||||
"Message",
|
||||
"Passage",
|
||||
"AgentState",
|
||||
"Document",
|
||||
"File",
|
||||
"Source",
|
||||
"LLMConfig",
|
||||
"EmbeddingConfig",
|
||||
|
||||
@@ -270,7 +270,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from letta.data_sources.connectors import DataConnector \n",
|
||||
"from letta.schemas.document import Document\n",
|
||||
"from letta.schemas.file import FileMetadata\n",
|
||||
"from llama_index.core import Document as LlamaIndexDocument\n",
|
||||
"from llama_index.core import SummaryIndex\n",
|
||||
"from llama_index.readers.web import SimpleWebPageReader\n",
|
||||
|
||||
@@ -7,9 +7,9 @@ from letta.client.client import LocalClient, RESTClient, create_client
|
||||
# imports for easier access
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
@@ -28,7 +28,7 @@ from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.base import Base
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
|
||||
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
|
||||
|
||||
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
|
||||
from letta.schemas.message import Message
|
||||
@@ -141,7 +141,7 @@ class PassageModel(Base):
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
text = Column(String)
|
||||
doc_id = Column(String)
|
||||
file_id = Column(String)
|
||||
agent_id = Column(String)
|
||||
source_id = Column(String)
|
||||
|
||||
@@ -160,7 +160,7 @@ class PassageModel(Base):
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
Index("passage_idx_user", user_id, agent_id, doc_id),
|
||||
Index("passage_idx_user", user_id, agent_id, file_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
@@ -170,7 +170,7 @@ class PassageModel(Base):
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
embedding_config=self.embedding_config,
|
||||
doc_id=self.doc_id,
|
||||
file_id=self.file_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
source_id=self.source_id,
|
||||
@@ -365,12 +365,17 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
self.uri = self.config.archival_storage_uri
|
||||
self.db_model = PassageModel
|
||||
if self.config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
|
||||
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.uri = self.config.recall_storage_uri
|
||||
self.db_model = MessageModel
|
||||
if self.config.recall_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
|
||||
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.FILES:
|
||||
self.uri = self.config.metadata_storage_uri
|
||||
self.db_model = FileMetadataModel
|
||||
if self.config.metadata_storage_uri is None:
|
||||
raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
@@ -487,8 +492,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
# TODO: eventually implement URI option
|
||||
self.path = self.config.recall_storage_path
|
||||
if self.path is None:
|
||||
raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
|
||||
raise ValueError(f"Must specify recall_storage_path in config.")
|
||||
self.db_model = MessageModel
|
||||
elif table_type == TableType.FILES:
|
||||
self.path = self.config.metadata_storage_path
|
||||
if self.path is None:
|
||||
raise ValueError(f"Must specify metadata_storage_path in config.")
|
||||
self.db_model = FileMetadataModel
|
||||
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
id: uuid.UUID
|
||||
user_id: str
|
||||
text: str
|
||||
doc_id: str
|
||||
file_id: str
|
||||
agent_id: str
|
||||
data_source: str
|
||||
embedding: Vector(config.default_embedding_config.embedding_dim)
|
||||
@@ -37,7 +37,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
doc_id=self.doc_id,
|
||||
file_id=self.file_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
data_source=self.data_source,
|
||||
|
||||
@@ -26,7 +26,7 @@ class MilvusStorageConnector(StorageConnector):
|
||||
raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
|
||||
|
||||
# need to be converted to strings
|
||||
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
|
||||
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
|
||||
|
||||
def _create_collection(self):
|
||||
schema = MilvusClient.create_schema(
|
||||
|
||||
@@ -38,7 +38,7 @@ class QdrantStorageConnector(StorageConnector):
|
||||
distance=models.Distance.COSINE,
|
||||
),
|
||||
)
|
||||
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
|
||||
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
|
||||
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
|
||||
from qdrant_client import grpc
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.utils import printd
|
||||
@@ -22,7 +22,7 @@ class TableType:
|
||||
ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id}
|
||||
RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id}
|
||||
PASSAGES = "passages" # TODO
|
||||
DOCUMENTS = "documents" # TODO
|
||||
FILES = "files"
|
||||
|
||||
|
||||
# table names used by Letta
|
||||
@@ -33,17 +33,17 @@ ARCHIVAL_TABLE_NAME = "letta_archival_memory_agent" # agent memory
|
||||
|
||||
# external data source tables
|
||||
PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source)
|
||||
DOCUMENT_TABLE_NAME = "letta_documents" # original documents (from source)
|
||||
FILE_TABLE_NAME = "letta_files" # original files (from source)
|
||||
|
||||
|
||||
class StorageConnector:
|
||||
"""Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
|
||||
"""Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory"""
|
||||
|
||||
type: Type[BaseModel]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
|
||||
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
|
||||
config: LettaConfig,
|
||||
user_id,
|
||||
agent_id=None,
|
||||
@@ -59,9 +59,9 @@ class StorageConnector:
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.type = Message
|
||||
self.table_name = RECALL_TABLE_NAME
|
||||
elif table_type == TableType.DOCUMENTS:
|
||||
self.type = Document
|
||||
self.table_name == DOCUMENT_TABLE_NAME
|
||||
elif table_type == TableType.FILES:
|
||||
self.type = FileMetadata
|
||||
self.table_name = FILE_TABLE_NAME
|
||||
elif table_type == TableType.PASSAGES:
|
||||
self.type = Passage
|
||||
self.table_name = PASSAGE_TABLE_NAME
|
||||
@@ -74,7 +74,7 @@ class StorageConnector:
|
||||
# agent-specific table
|
||||
assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
|
||||
self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
|
||||
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS:
|
||||
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
|
||||
# setup base filters for user-specific tables
|
||||
assert agent_id is None, "Agent ID must not be provided for user-specific tables"
|
||||
self.filters = {"user_id": self.user_id}
|
||||
@@ -83,7 +83,7 @@ class StorageConnector:
|
||||
|
||||
@staticmethod
|
||||
def get_storage_connector(
|
||||
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
|
||||
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
|
||||
config: LettaConfig,
|
||||
user_id,
|
||||
agent_id=None,
|
||||
@@ -92,6 +92,8 @@ class StorageConnector:
|
||||
storage_type = config.archival_storage_type
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
storage_type = config.recall_storage_type
|
||||
elif table_type == TableType.FILES:
|
||||
storage_type = config.metadata_storage_type
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ def load_vector_database(
|
||||
# document_store=None,
|
||||
# passage_store=passage_storage,
|
||||
# )
|
||||
# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||
# print(f"Loaded {num_passages} passages and {num_documents} files from {name}")
|
||||
# except Exception as e:
|
||||
# typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
|
||||
# ms.delete_source(source_id=source.id)
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# new schemas
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_request import LettaRequest
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
@@ -232,6 +233,9 @@ class AbstractClient(object):
|
||||
def list_attached_sources(self, agent_id: str) -> List[Source]:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
|
||||
raise NotImplementedError
|
||||
|
||||
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1016,6 +1020,12 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to get job: {response.text}")
|
||||
return Job(**response.json())
|
||||
|
||||
def delete_job(self, job_id: str) -> Job:
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete job: {response.text}")
|
||||
return Job(**response.json())
|
||||
|
||||
def list_jobs(self):
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
|
||||
return [Job(**job) for job in response.json()]
|
||||
@@ -1088,6 +1098,30 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to list attached sources: {response.text}")
|
||||
return [Source(**source) for source in response.json()]
|
||||
|
||||
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
|
||||
"""
|
||||
List files from source with pagination support.
|
||||
|
||||
Args:
|
||||
source_id (str): ID of the source
|
||||
limit (int): Number of files to return
|
||||
cursor (Optional[str]): Pagination cursor for fetching the next page
|
||||
|
||||
Returns:
|
||||
List[FileMetadata]: List of files
|
||||
"""
|
||||
# Prepare query parameters for pagination
|
||||
params = {"limit": limit, "cursor": cursor}
|
||||
|
||||
# Make the request to the FastAPI endpoint
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files", headers=self.headers, params=params)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}")
|
||||
|
||||
# Parse the JSON response
|
||||
return [FileMetadata(**metadata) for metadata in response.json()]
|
||||
|
||||
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
|
||||
"""
|
||||
Update a source
|
||||
@@ -2162,6 +2196,9 @@ class LocalClient(AbstractClient):
|
||||
def get_job(self, job_id: str):
|
||||
return self.server.get_job(job_id=job_id)
|
||||
|
||||
def delete_job(self, job_id: str):
|
||||
return self.server.delete_job(job_id)
|
||||
|
||||
def list_jobs(self):
|
||||
return self.server.list_jobs(user_id=self.user_id)
|
||||
|
||||
@@ -2261,6 +2298,20 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
return self.server.list_attached_sources(agent_id=agent_id)
|
||||
|
||||
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
|
||||
"""
|
||||
List files from source.
|
||||
|
||||
Args:
|
||||
source_id (str): ID of the source
|
||||
limit (int): The # of items to return
|
||||
cursor (str): The cursor for fetching the next page
|
||||
|
||||
Returns:
|
||||
files (List[FileMetadata]): List of files
|
||||
"""
|
||||
return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
|
||||
|
||||
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
|
||||
"""
|
||||
Update a source
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
|
||||
import typer
|
||||
from llama_index.core import Document as LlamaIndexDocument
|
||||
|
||||
from letta.agent_store.storage import StorageConnector
|
||||
from letta.data_sources.connectors_helper import (
|
||||
assert_all_files_exist_locally,
|
||||
extract_metadata_from_files,
|
||||
get_filenames_in_dir,
|
||||
)
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source
|
||||
from letta.utils import create_uuid_from_string
|
||||
@@ -13,23 +17,23 @@ from letta.utils import create_uuid_from_string
|
||||
|
||||
class DataConnector:
|
||||
"""
|
||||
Base class for data connectors that can be extended to generate documents and passages from a custom data source.
|
||||
Base class for data connectors that can be extended to generate files and passages from a custom data source.
|
||||
"""
|
||||
|
||||
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
|
||||
def find_files(self, source: Source) -> Iterator[FileMetadata]:
|
||||
"""
|
||||
Generate document text and metadata from a data source.
|
||||
Generate file metadata from a data source.
|
||||
|
||||
Returns:
|
||||
documents (Iterator[Tuple[str, Dict]]): Generate a tuple of string text and metadata dictionary for each document.
|
||||
files (Iterator[FileMetadata]): Generate file metadata for each file found.
|
||||
"""
|
||||
|
||||
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
|
||||
def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
|
||||
"""
|
||||
Generate passage text and metadata from a list of documents.
|
||||
Generate passage text and metadata from a list of files.
|
||||
|
||||
Args:
|
||||
documents (List[Document]): List of documents to generate passages from.
|
||||
file (FileMetadata): The document to generate passages from.
|
||||
chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024.
|
||||
|
||||
Returns:
|
||||
@@ -41,33 +45,25 @@ def load_data(
|
||||
connector: DataConnector,
|
||||
source: Source,
|
||||
passage_store: StorageConnector,
|
||||
document_store: Optional[StorageConnector] = None,
|
||||
file_metadata_store: StorageConnector,
|
||||
):
|
||||
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
|
||||
"""Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id."""
|
||||
embedding_config = source.embedding_config
|
||||
|
||||
# embedding model
|
||||
embed_model = embedding_model(embedding_config)
|
||||
|
||||
# insert passages/documents
|
||||
# insert passages/file
|
||||
passages = []
|
||||
embedding_to_document_name = {}
|
||||
passage_count = 0
|
||||
document_count = 0
|
||||
for document_text, document_metadata in connector.generate_documents():
|
||||
# insert document into storage
|
||||
document = Document(
|
||||
text=document_text,
|
||||
metadata_=document_metadata,
|
||||
source_id=source.id,
|
||||
user_id=source.user_id,
|
||||
)
|
||||
document_count += 1
|
||||
if document_store:
|
||||
document_store.insert(document)
|
||||
file_count = 0
|
||||
for file_metadata in connector.find_files(source):
|
||||
file_count += 1
|
||||
file_metadata_store.insert(file_metadata)
|
||||
|
||||
# generate passages
|
||||
for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
|
||||
for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
|
||||
# for some reason, llama index parsers sometimes return empty strings
|
||||
if len(passage_text) == 0:
|
||||
typer.secho(
|
||||
@@ -89,7 +85,7 @@ def load_data(
|
||||
passage = Passage(
|
||||
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
|
||||
text=passage_text,
|
||||
doc_id=document.id,
|
||||
file_id=file_metadata.id,
|
||||
source_id=source.id,
|
||||
metadata_=passage_metadata,
|
||||
user_id=source.user_id,
|
||||
@@ -98,16 +94,16 @@ def load_data(
|
||||
)
|
||||
|
||||
hashable_embedding = tuple(passage.embedding)
|
||||
document_name = document.metadata_.get("file_path", document.id)
|
||||
file_name = file_metadata.file_name
|
||||
if hashable_embedding in embedding_to_document_name:
|
||||
typer.secho(
|
||||
f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
|
||||
f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
|
||||
fg=typer.colors.YELLOW,
|
||||
)
|
||||
continue
|
||||
|
||||
passages.append(passage)
|
||||
embedding_to_document_name[hashable_embedding] = document_name
|
||||
embedding_to_document_name[hashable_embedding] = file_name
|
||||
if len(passages) >= 100:
|
||||
# insert passages into passage store
|
||||
passage_store.insert_many(passages)
|
||||
@@ -120,7 +116,7 @@ def load_data(
|
||||
passage_store.insert_many(passages)
|
||||
passage_count += len(passages)
|
||||
|
||||
return passage_count, document_count
|
||||
return passage_count, file_count
|
||||
|
||||
|
||||
class DirectoryConnector(DataConnector):
|
||||
@@ -143,105 +139,109 @@ class DirectoryConnector(DataConnector):
|
||||
if self.recursive == True:
|
||||
assert self.input_directory is not None, "Must provide input directory if recursive is True."
|
||||
|
||||
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
def find_files(self, source: Source) -> Iterator[FileMetadata]:
|
||||
if self.input_directory is not None:
|
||||
reader = SimpleDirectoryReader(
|
||||
files = get_filenames_in_dir(
|
||||
input_dir=self.input_directory,
|
||||
recursive=self.recursive,
|
||||
required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
|
||||
exclude=["*png", "*jpg", "*jpeg"],
|
||||
)
|
||||
else:
|
||||
assert self.input_files is not None, "Must provide input files if input_dir is None"
|
||||
reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files])
|
||||
files = self.input_files
|
||||
|
||||
llama_index_docs = reader.load_data(show_progress=True)
|
||||
for llama_index_doc in llama_index_docs:
|
||||
# TODO: add additional metadata?
|
||||
# doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata)
|
||||
# docs.append(doc)
|
||||
yield llama_index_doc.text, llama_index_doc.metadata
|
||||
# Check that file paths are valid
|
||||
assert_all_files_exist_locally(files)
|
||||
|
||||
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
|
||||
# use llama index to run embeddings code
|
||||
# from llama_index.core.node_parser import SentenceSplitter
|
||||
for metadata in extract_metadata_from_files(files):
|
||||
yield FileMetadata(
|
||||
user_id=source.user_id,
|
||||
source_id=source.id,
|
||||
file_name=metadata.get("file_name"),
|
||||
file_path=metadata.get("file_path"),
|
||||
file_type=metadata.get("file_type"),
|
||||
file_size=metadata.get("file_size"),
|
||||
file_creation_date=metadata.get("file_creation_date"),
|
||||
file_last_modified_date=metadata.get("file_last_modified_date"),
|
||||
)
|
||||
|
||||
def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import TokenTextSplitter
|
||||
|
||||
parser = TokenTextSplitter(chunk_size=chunk_size)
|
||||
for document in documents:
|
||||
llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)]
|
||||
nodes = parser.get_nodes_from_documents(llama_index_docs)
|
||||
documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data()
|
||||
nodes = parser.get_nodes_from_documents(documents)
|
||||
for node in nodes:
|
||||
# passage = Passage(
|
||||
# text=node.text,
|
||||
# doc_id=document.id,
|
||||
# )
|
||||
yield node.text, None
|
||||
|
||||
|
||||
class WebConnector(DirectoryConnector):
|
||||
def __init__(self, urls: List[str] = None, html_to_text: bool = True):
|
||||
self.urls = urls
|
||||
self.html_to_text = html_to_text
|
||||
|
||||
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
|
||||
from llama_index.readers.web import SimpleWebPageReader
|
||||
|
||||
documents = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
|
||||
for document in documents:
|
||||
yield document.text, {"url": document.id_}
|
||||
|
||||
|
||||
class VectorDBConnector(DataConnector):
|
||||
# NOTE: this class has not been properly tested, so is unlikely to work
|
||||
# TODO: allow loading multiple tables (1:1 mapping between Document and Table)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
uri: str,
|
||||
table_name: str,
|
||||
text_column: str,
|
||||
embedding_column: str,
|
||||
embedding_dim: int,
|
||||
):
|
||||
self.name = name
|
||||
self.uri = uri
|
||||
self.table_name = table_name
|
||||
self.text_column = text_column
|
||||
self.embedding_column = embedding_column
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
# connect to db table
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
self.engine = create_engine(uri)
|
||||
|
||||
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
|
||||
yield self.table_name, None
|
||||
|
||||
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import Inspector, MetaData, Table, select
|
||||
|
||||
metadata = MetaData()
|
||||
# Create an inspector to inspect the database
|
||||
inspector = Inspector.from_engine(self.engine)
|
||||
table_names = inspector.get_table_names()
|
||||
assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
|
||||
|
||||
table = Table(self.table_name, metadata, autoload_with=self.engine)
|
||||
|
||||
# Prepare a select statement
|
||||
select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
|
||||
|
||||
# Execute the query and fetch the results
|
||||
# TODO: paginate results
|
||||
with self.engine.connect() as connection:
|
||||
result = connection.execute(select_statement).fetchall()
|
||||
|
||||
for text, embedding in result:
|
||||
# assume that embeddings are the same model as in config
|
||||
# TODO: don't re-compute embedding
|
||||
yield text, {"embedding": embedding}
|
||||
"""
|
||||
The below isn't used anywhere, it isn't tested, and pretty much should be deleted.
|
||||
- Matt
|
||||
"""
|
||||
# class WebConnector(DirectoryConnector):
|
||||
# def __init__(self, urls: List[str] = None, html_to_text: bool = True):
|
||||
# self.urls = urls
|
||||
# self.html_to_text = html_to_text
|
||||
#
|
||||
# def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
|
||||
# from llama_index.readers.web import SimpleWebPageReader
|
||||
#
|
||||
# files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
|
||||
# for document in files:
|
||||
# yield document.text, {"url": document.id_}
|
||||
#
|
||||
#
|
||||
# class VectorDBConnector(DataConnector):
|
||||
# # NOTE: this class has not been properly tested, so is unlikely to work
|
||||
# # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table)
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# name: str,
|
||||
# uri: str,
|
||||
# table_name: str,
|
||||
# text_column: str,
|
||||
# embedding_column: str,
|
||||
# embedding_dim: int,
|
||||
# ):
|
||||
# self.name = name
|
||||
# self.uri = uri
|
||||
# self.table_name = table_name
|
||||
# self.text_column = text_column
|
||||
# self.embedding_column = embedding_column
|
||||
# self.embedding_dim = embedding_dim
|
||||
#
|
||||
# # connect to db table
|
||||
# from sqlalchemy import create_engine
|
||||
#
|
||||
# self.engine = create_engine(uri)
|
||||
#
|
||||
# def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
|
||||
# yield self.table_name, None
|
||||
#
|
||||
# def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
|
||||
# from pgvector.sqlalchemy import Vector
|
||||
# from sqlalchemy import Inspector, MetaData, Table, select
|
||||
#
|
||||
# metadata = MetaData()
|
||||
# # Create an inspector to inspect the database
|
||||
# inspector = Inspector.from_engine(self.engine)
|
||||
# table_names = inspector.get_table_names()
|
||||
# assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
|
||||
#
|
||||
# table = Table(self.table_name, metadata, autoload_with=self.engine)
|
||||
#
|
||||
# # Prepare a select statement
|
||||
# select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
|
||||
#
|
||||
# # Execute the query and fetch the results
|
||||
# # TODO: paginate results
|
||||
# with self.engine.connect() as connection:
|
||||
# result = connection.execute(select_statement).fetchall()
|
||||
#
|
||||
# for text, embedding in result:
|
||||
# # assume that embeddings are the same model as in config
|
||||
# # TODO: don't re-compute embedding
|
||||
# yield text, {"embedding": embedding}
|
||||
|
||||
97
letta/data_sources/connectors_helper.py
Normal file
97
letta/data_sources/connectors_helper.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def extract_file_metadata(file_path) -> dict:
|
||||
"""Extracts metadata from a single file."""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(file_path)
|
||||
|
||||
file_metadata = {
|
||||
"file_name": os.path.basename(file_path),
|
||||
"file_path": file_path,
|
||||
"file_type": mimetypes.guess_type(file_path)[0] or "unknown",
|
||||
"file_size": os.path.getsize(file_path),
|
||||
"file_creation_date": datetime.fromtimestamp(os.path.getctime(file_path)).strftime("%Y-%m-%d"),
|
||||
"file_last_modified_date": datetime.fromtimestamp(os.path.getmtime(file_path)).strftime("%Y-%m-%d"),
|
||||
}
|
||||
return file_metadata
|
||||
|
||||
|
||||
def extract_metadata_from_files(file_list):
|
||||
"""Extracts metadata for a list of files."""
|
||||
metadata = []
|
||||
for file_path in file_list:
|
||||
file_metadata = extract_file_metadata(file_path)
|
||||
if file_metadata:
|
||||
metadata.append(file_metadata)
|
||||
return metadata
|
||||
|
||||
|
||||
def get_filenames_in_dir(
    input_dir: str, recursive: bool = True, required_exts: Optional[List[str]] = None, exclude: Optional[List[str]] = None
):
    """
    Collect file paths under ``input_dir``, honoring extension and exclusion filters.

    Ensures that required_exts and exclude do not overlap.

    Args:
        input_dir (str): The directory to scan for files.
        recursive (bool): When True, descend into subdirectories.
        required_exts (list): File extensions to include (e.g., ['pdf', 'txt']).
            If None or empty, matches any file extension.
        exclude (list): File patterns to exclude (e.g., ['*png', '*jpg']).

    Returns:
        list: Matching file paths (``pathlib.Path`` objects).

    Raises:
        ValueError: If required_exts and exclude share any entry.
    """
    required_exts = required_exts or []
    exclude = exclude or []

    # Refuse contradictory filters up front.
    conflicting = set(required_exts) & set(exclude)
    if conflicting:
        raise ValueError(f"Extensions in required_exts and exclude overlap: {conflicting}")

    def is_excluded(file_name):
        """True when file_name matches at least one exclusion pattern."""
        return any(Path(file_name).match(pattern) for pattern in exclude)

    glob_pattern = "**/*" if recursive else "*"
    matches = []
    for candidate in Path(input_dir).glob(glob_pattern):
        if not candidate.is_file() or is_excluded(candidate.name):
            continue
        # An empty required_exts list means "accept every extension".
        if not required_exts or candidate.suffix.lstrip(".") in required_exts:
            matches.append(candidate)

    return matches
|
||||
|
||||
|
||||
def assert_all_files_exist_locally(file_paths: List[str]) -> bool:
    """
    Verify that every path in ``file_paths`` exists on the local filesystem.

    Args:
        file_paths (List[str]): List of file paths to check.

    Returns:
        bool: True when every path exists.

    Raises:
        FileNotFoundError: Carrying the list of paths that were not found.
    """
    not_found = [p for p in file_paths if not Path(p).exists()]
    if not_found:
        raise FileNotFoundError(not_found)
    return True
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
desc,
|
||||
@@ -24,6 +25,7 @@ from letta.schemas.api_key import APIKey
|
||||
from letta.schemas.block import Block, Human, Persona
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
@@ -38,6 +40,41 @@ from letta.settings import settings
|
||||
from letta.utils import enforce_types, get_utc_time, printd
|
||||
|
||||
|
||||
class FileMetadataModel(Base):
    """SQLAlchemy declarative model for the ``files`` table.

    One row of metadata per uploaded file; mirrors the ``FileMetadata``
    pydantic schema and converts back to it via :meth:`to_record`.
    """

    __tablename__ = "files"
    # extend_existing allows re-declaring the table if it is already registered.
    __table_args__ = {"extend_existing": True}

    # Primary key string (ids are generated by the FileMetadata schema).
    id = Column(String, primary_key=True, nullable=False)
    user_id = Column(String, nullable=False)
    # TODO: Investigate why this breaks during table creation due to FK
    # source_id = Column(String, ForeignKey("sources.id"), nullable=False)
    source_id = Column(String, nullable=False)
    file_name = Column(String, nullable=True)
    file_path = Column(String, nullable=True)
    # MIME type, stored as a plain string.
    file_type = Column(String, nullable=True)
    # Size in bytes.
    file_size = Column(Integer, nullable=True)
    # File dates are stored as "YYYY-MM-DD" strings, not DateTime columns.
    file_creation_date = Column(String, nullable=True)
    file_last_modified_date = Column(String, nullable=True)
    # Row creation timestamp, assigned server-side by the database.
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        """Short debug representation showing the identifying fields."""
        return f"<FileMetadata(id='{self.id}', source_id='{self.source_id}', file_name='{self.file_name}')>"

    def to_record(self):
        """Convert this ORM row into a ``FileMetadata`` pydantic object."""
        return FileMetadata(
            id=self.id,
            user_id=self.user_id,
            source_id=self.source_id,
            file_name=self.file_name,
            file_path=self.file_path,
            file_type=self.file_type,
            file_size=self.file_size,
            file_creation_date=self.file_creation_date,
            file_last_modified_date=self.file_last_modified_date,
            created_at=self.created_at,
        )
|
||||
|
||||
|
||||
class LLMConfigColumn(TypeDecorator):
|
||||
"""Custom type for storing LLMConfig as JSON"""
|
||||
|
||||
@@ -865,6 +902,27 @@ class MetadataStore:
|
||||
session.add(JobModel(**vars(job)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]):
    """Return up to ``limit`` FileMetadata records belonging to a source.

    Pagination is keyset-based: ``cursor`` is the id of the last file seen on
    the previous page, and results are ordered by id so the next page resumes
    deterministically after it.
    """
    with self.session_maker() as session:
        query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id)

        if cursor:
            # Resume strictly after the supplied file id.
            query = query.filter(FileMetadataModel.id > cursor)

        # Deterministic ordering + limit, then convert rows to schema objects.
        rows = query.order_by(FileMetadataModel.id).limit(limit).all()
        return [row.to_record() for row in rows]
|
||||
|
||||
def delete_job(self, job_id: str):
    """Delete the job row whose id equals ``job_id`` (no error if absent)."""
    with self.session_maker() as session:
        session.query(JobModel).filter(JobModel.id == job_id).delete()
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class DocumentBase(LettaBase):
    """Base class for document schemas"""

    # Generated document ids are prefixed with "doc".
    __id_prefix__ = "doc"
|
||||
|
||||
|
||||
class Document(DocumentBase):
    """Representation of a single document (broken up into `Passage` objects)"""

    id: str = DocumentBase.generate_id_field()
    # Full raw text of the document; passages are chunks derived from it.
    text: str = Field(..., description="The text of the document.")
    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
    # NOTE(review): declared without `...` unlike the fields above — presumably
    # still required since no default is given; confirm intended.
    user_id: str = Field(description="The unique identifier of the user associated with the document.")
    metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")
|
||||
31
letta/schemas/file.py
Normal file
31
letta/schemas/file.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.utils import get_utc_time
|
||||
|
||||
|
||||
class FileMetadataBase(LettaBase):
    """Base class for FileMetadata schemas"""

    # Generated file ids are prefixed with "file".
    __id_prefix__ = "file"
|
||||
|
||||
|
||||
class FileMetadata(FileMetadataBase):
    """Representation of a single FileMetadata"""

    id: str = FileMetadataBase.generate_id_field()
    # Fixed: descriptions previously said "document" (copy-paste from the
    # Document schema) even though they describe a file.
    user_id: str = Field(description="The unique identifier of the user associated with the file.")
    source_id: str = Field(..., description="The unique identifier of the source associated with the file.")
    file_name: Optional[str] = Field(None, description="The name of the file.")
    file_path: Optional[str] = Field(None, description="The path to the file.")
    file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
    file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
    file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
    file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")

    class Config:
        # Tolerate extra keys so connector-specific metadata is not rejected.
        extra = "allow"
|
||||
@@ -15,7 +15,7 @@ class JobBase(LettaBase):
|
||||
|
||||
class Job(JobBase):
|
||||
"""
|
||||
Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding documents).
|
||||
Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding files).
|
||||
|
||||
Parameters:
|
||||
id (str): The unique identifier of the job.
|
||||
|
||||
@@ -19,8 +19,8 @@ class PassageBase(LettaBase):
|
||||
# origin data source
|
||||
source_id: Optional[str] = Field(None, description="The data source of the passage.")
|
||||
|
||||
# document association
|
||||
doc_id: Optional[str] = Field(None, description="The unique identifier of the document associated with the passage.")
|
||||
# file association
|
||||
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
|
||||
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class Passage(PassageBase):
|
||||
user_id (str): The unique identifier of the user associated with the passage.
|
||||
agent_id (str): The unique identifier of the agent associated with the passage.
|
||||
source_id (str): The data source of the passage.
|
||||
doc_id (str): The unique identifier of the document associated with the passage.
|
||||
file_id (str): The unique identifier of the file associated with the passage.
|
||||
"""
|
||||
|
||||
id: str = PassageBase.generate_id_field()
|
||||
|
||||
@@ -28,7 +28,7 @@ class SourceCreate(BaseSource):
|
||||
|
||||
class Source(BaseSource):
|
||||
"""
|
||||
Representation of a source, which is a collection of documents and passages.
|
||||
Representation of a source, which is a collection of files and passages.
|
||||
|
||||
Parameters:
|
||||
id (str): The ID of the source
|
||||
@@ -59,4 +59,4 @@ class UploadFileToSourceRequest(BaseModel):
|
||||
class UploadFileToSourceResponse(BaseModel):
    """Response payload returned after a file upload into a source completes."""

    source: Source = Field(..., description="The source the file was uploaded to.")
    added_passages: int = Field(..., description="The number of passages added to the source.")
    # Deduplicated: this field was declared twice with conflicting descriptions.
    added_documents: int = Field(..., description="The number of files added to the source.")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.schemas.job import Job
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -54,3 +54,19 @@ def get_job(
|
||||
"""
|
||||
|
||||
return server.get_job(job_id=job_id)
|
||||
|
||||
|
||||
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
def delete_job(
    job_id: str,
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Delete a job by its job_id.
    """
    # Fetch first so we can 404 on unknown ids and echo the deleted job back.
    existing_job = server.get_job(job_id=job_id)
    if not existing_job:
        raise HTTPException(status_code=404, detail="Job not found")

    server.delete_job(job_id=job_id)
    return existing_job
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
|
||||
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
@@ -186,19 +186,17 @@ def list_passages(
|
||||
return passages
|
||||
|
||||
|
||||
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_files_from_source")
def list_files_from_source(
    source_id: str,
    limit: int = Query(1000, description="Number of files to return"),
    cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    List paginated files associated with a data source.
    """
    # Resolve the caller (falls back to the default user). The listing itself
    # is keyed only by source_id, so the actor is not passed further down.
    server.get_user_or_default(user_id=user_id)

    return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
|
||||
|
||||
|
||||
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
|
||||
|
||||
@@ -63,11 +63,11 @@ from letta.schemas.block import (
|
||||
CreatePersona,
|
||||
UpdateBlock,
|
||||
)
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -1596,7 +1596,7 @@ class SyncServer(Server):
|
||||
# job.status = JobStatus.failed
|
||||
# job.metadata_["error"] = error
|
||||
# self.ms.update_job(job)
|
||||
# # TODO: delete any associated passages/documents?
|
||||
# # TODO: delete any associated passages/files?
|
||||
|
||||
# # return failed job
|
||||
# return job
|
||||
@@ -1625,11 +1625,10 @@ class SyncServer(Server):
|
||||
|
||||
# get the data connectors
|
||||
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
# TODO: add document store support
|
||||
document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
|
||||
file_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
|
||||
|
||||
# load data into the document store
|
||||
passage_count, document_count = load_data(connector, source, passage_store, document_store)
|
||||
passage_count, document_count = load_data(connector, source, passage_store, file_store)
|
||||
return passage_count, document_count
|
||||
|
||||
def attach_source_to_agent(
|
||||
@@ -1686,14 +1685,14 @@ class SyncServer(Server):
|
||||
# list all attached sources to an agent
|
||||
return self.ms.list_attached_sources(agent_id)
|
||||
|
||||
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
    # Paginated listing of file metadata for a data source; delegates to the
    # metadata store (cursor = id of the last file from the previous page).
    return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
|
||||
|
||||
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
    """Placeholder: per-source passage listing is not implemented yet."""
    message = "list_data_source_passages is not yet implemented, returning empty list."
    warnings.warn(message, category=UserWarning)
    return []
|
||||
|
||||
def list_data_source_documents(self, user_id: str, source_id: str) -> List[Document]:
    """Placeholder: per-source document listing is not implemented yet."""
    message = "list_data_source_documents is not yet implemented, returning empty list."
    warnings.warn(message, category=UserWarning)
    return []
|
||||
|
||||
def list_all_sources(self, user_id: str) -> List[Source]:
|
||||
"""List all sources (w/ extra metadata) belonging to a user"""
|
||||
|
||||
@@ -1707,9 +1706,9 @@ class SyncServer(Server):
|
||||
passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
num_passages = passage_conn.size({"source_id": source.id})
|
||||
|
||||
# TODO: add when documents table implemented
|
||||
## count number of documents
|
||||
# document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
|
||||
# TODO: add when files table implemented
|
||||
## count number of files
|
||||
# document_conn = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
|
||||
# num_documents = document_conn.size({"data_source": source.name})
|
||||
num_documents = 0
|
||||
|
||||
|
||||
1
tests/data/test.txt
Normal file
1
tests/data/test.txt
Normal file
@@ -0,0 +1 @@
|
||||
test
|
||||
34
tests/helpers/client_helper.py
Normal file
34
tests/helpers/client_helper.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
from letta import LocalClient, RESTClient
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.source import Source
|
||||
|
||||
|
||||
def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job:
    """Start a non-blocking upload of ``filename`` into ``source`` and wait for the job to finish.

    Returns the upload Job; raises ValueError if it does not complete within
    the timeout.
    """
    job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
    print("Upload job", job, job.status, job.metadata_)

    # The upload must appear in the full job list and be the only active job.
    active_jobs = client.list_active_jobs()
    all_jobs = client.list_jobs()
    assert job.id in [j.id for j in all_jobs]
    assert len(active_jobs) == 1
    assert active_jobs[0].metadata_["source_id"] == source.id

    # Poll for completion, bailing out after two minutes.
    timeout = 120
    started = time.time()
    while True:
        status = client.get_job(job.id).status
        print(f"\r{status}", end="", flush=True)
        if status == JobStatus.completed:
            break
        time.sleep(1)
        if time.time() - started > timeout:
            raise ValueError("Job did not finish in time")

    return job
|
||||
@@ -12,12 +12,13 @@ from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import FunctionCallMessage, InternalMonologue
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
|
||||
@@ -298,6 +299,70 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# print("CONFIG", config_response)
|
||||
|
||||
|
||||
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Two uploaded files should be retrievable one page at a time via the cursor."""
    # Start from a clean slate: drop all existing sources and jobs.
    for existing_source in client.list_sources():
        client.delete_source(existing_source.id)
    for existing_job in client.list_jobs():
        client.delete_job(existing_job.id)

    # Create a fresh source and upload two distinct files into it.
    source = client.create_source(name="test_source")
    upload_file_using_client(client, source, "tests/data/memgpt_paper.pdf")
    upload_file_using_client(client, source, "tests/data/test.txt")

    # Page 1: exactly one file from the source.
    first_page = client.list_files_from_source(source.id, limit=1)
    assert len(first_page) == 1
    assert first_page[0].source_id == source.id

    # Page 2: cursor resumes after page 1's last id.
    second_page = client.list_files_from_source(source.id, limit=1, cursor=first_page[-1].id)
    assert len(second_page) == 1
    assert second_page[0].source_id == source.id

    # The cursor must actually advance: the two pages hold different files.
    assert first_page[0].file_name != second_page[0].file_name

    # Page 3: past the end, so the result is empty.
    third_page = client.list_files_from_source(source.id, limit=1, cursor=second_page[-1].id)
    assert len(third_page) == 0  # Should be empty
|
||||
|
||||
|
||||
def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Uploading a single file should yield exactly one FileMetadata record."""
    # Reset server state: remove leftover sources and jobs.
    for existing_source in client.list_sources():
        client.delete_source(existing_source.id)
    for existing_job in client.list_jobs():
        client.delete_job(existing_job.id)

    # Create a source and upload the paper (waits for the background job).
    source = client.create_source(name="test_source")
    filename = "tests/data/memgpt_paper.pdf"
    upload_file_using_client(client, source, filename)

    files = client.list_files_from_source(source.id)
    assert len(files) == 1  # Should be condensed to one document

    # Verify the uploaded file's metadata.
    uploaded = files[0]
    assert uploaded.file_name == "memgpt_paper.pdf"
    assert uploaded.source_id == source.id
|
||||
|
||||
|
||||
def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
@@ -305,6 +370,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
for source in client.list_sources():
|
||||
client.delete_source(source.id)
|
||||
|
||||
# clear jobs
|
||||
for job in client.list_jobs():
|
||||
client.delete_job(job.id)
|
||||
|
||||
# list sources
|
||||
sources = client.list_sources()
|
||||
print("listed sources", sources)
|
||||
@@ -343,28 +412,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# load a file into a source (non-blocking job)
|
||||
filename = "tests/data/memgpt_paper.pdf"
|
||||
upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
|
||||
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
|
||||
|
||||
# view active jobs
|
||||
active_jobs = client.list_active_jobs()
|
||||
jobs = client.list_jobs()
|
||||
print(jobs)
|
||||
assert upload_job.id in [j.id for j in jobs]
|
||||
assert len(active_jobs) == 1
|
||||
assert active_jobs[0].metadata_["source_id"] == source.id
|
||||
|
||||
# wait for job to finish (with timeout)
|
||||
timeout = 120
|
||||
start_time = time.time()
|
||||
while True:
|
||||
status = client.get_job(upload_job.id).status
|
||||
print(status)
|
||||
if status == JobStatus.completed:
|
||||
break
|
||||
time.sleep(1)
|
||||
if time.time() - start_time > timeout:
|
||||
raise ValueError("Job did not finish in time")
|
||||
upload_job = upload_file_using_client(client, source, filename)
|
||||
job = client.get_job(upload_job.id)
|
||||
created_passages = job.metadata_["num_passages"]
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import os
|
||||
from datetime import datetime
|
||||
from importlib import util
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
|
||||
@@ -7,7 +8,7 @@ import requests
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.settings import TestSettings
|
||||
|
||||
from .constants import TIMEOUT
|
||||
@@ -18,14 +19,27 @@ class DummyDataConnector(DataConnector):
|
||||
|
||||
def __init__(self, texts: List[str]):
    # Raw text snippets this dummy connector exposes as synthetic "files".
    self.texts = texts
    # Maps a generated FileMetadata id -> its originating text, so
    # generate_passages can recover the content for each file.
    self.file_to_text = {}
|
||||
|
||||
def find_files(self, source) -> Iterator[FileMetadata]:
    """Yield one placeholder FileMetadata per stored text snippet.

    Each text is registered in ``self.file_to_text`` under the generated file
    id so that ``generate_passages`` can look it up later. (Replaces the old
    ``generate_documents``; stale interleaved lines from the previous
    implementation removed.)
    """
    for text in self.texts:
        file_metadata = FileMetadata(
            user_id="",
            source_id="",
            file_name="",
            file_path="",
            file_type="",
            file_size=0,  # Set to 0 as a placeholder
            file_creation_date="1970-01-01",  # Placeholder date
            file_last_modified_date="1970-01-01",  # Placeholder date
            created_at=datetime.utcnow(),
        )
        self.file_to_text[file_metadata.id] = text
        yield file_metadata

def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str | Dict]]:
    """Yield the full text registered for ``file`` as a single passage with empty metadata."""
    yield self.file_to_text[file.id], {}
|
||||
|
||||
|
||||
def wipe_config():
|
||||
|
||||
Reference in New Issue
Block a user