fix: write temporary file for REST upload file endpoint + return number of added passages/documents (#1169)

This commit is contained in:
Sarah Wooders
2024-03-20 15:39:08 -07:00
committed by GitHub
parent 844b153ea3
commit 094b2177de
2 changed files with 18 additions and 8 deletions

View File

@@ -1,4 +1,7 @@
import uuid
import tempfile
import os
import hashlib
from functools import partial
from typing import List, Optional
@@ -150,14 +153,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear()
source = server.ms.get_source(source_id=source_id, user_id=user_id)
# create a directory connector that reads the in-memory file
connector = DirectoryConnector(input_files=[file.filename])
# write the file to a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(file.file.read())
# load the data into the source via the connector
server.load_data(user_id=user_id, source_name=source.name, connector=connector)
# read the file
connector = DirectoryConnector(input_files=[file_path])
# load the data into the source via the connector
passage_count, document_count = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
# TODO: actually return added passages/documents
return UploadFileToSourceResponse(source=source, added_passages=0, added_documents=0)
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)
@router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
async def list_passages(

View File

@@ -6,7 +6,7 @@ from abc import abstractmethod
from datetime import datetime
from functools import wraps
from threading import Lock
from typing import Union, Callable, Optional, List
from typing import Union, Callable, Optional, List, Tuple
import warnings
from fastapi import HTTPException
@@ -1254,7 +1254,7 @@ class SyncServer(LockingServer):
user_id: uuid.UUID,
connector: DataConnector,
source_name: str,
):
) -> Tuple[int, int]:
"""Load data from a DataConnector into a source for a specified user_id"""
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
@@ -1269,7 +1269,8 @@ class SyncServer(LockingServer):
document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
# load data into the document store
load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
return passage_count, document_count
def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
# attach a data source to an agent