fix: write temporary file for REST upload file endpoint + return number added passages/documents (#1169)
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
import uuid
|
||||
import tempfile
|
||||
import os
|
||||
import hashlib
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -150,14 +153,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
|
||||
interface.clear()
|
||||
source = server.ms.get_source(source_id=source_id, user_id=user_id)
|
||||
|
||||
# create a directory connector that reads the in-memory file
|
||||
connector = DirectoryConnector(input_files=[file.filename])
|
||||
# write the file to a temporary directory (deleted after the context manager exits)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file_path = os.path.join(tmpdirname, file.filename)
|
||||
with open(file_path, "wb") as buffer:
|
||||
buffer.write(file.file.read())
|
||||
|
||||
# load the data into the source via the connector
|
||||
server.load_data(user_id=user_id, source_name=source.name, connector=connector)
|
||||
# read the file
|
||||
connector = DirectoryConnector(input_files=[file_path])
|
||||
|
||||
# load the data into the source via the connector
|
||||
passage_count, document_count = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
|
||||
|
||||
# TODO: actually return added passages/documents
|
||||
return UploadFileToSourceResponse(source=source, added_passages=0, added_documents=0)
|
||||
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)
|
||||
|
||||
@router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
|
||||
async def list_passages(
|
||||
|
||||
@@ -6,7 +6,7 @@ from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from threading import Lock
|
||||
from typing import Union, Callable, Optional, List
|
||||
from typing import Union, Callable, Optional, List, Tuple
|
||||
import warnings
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -1254,7 +1254,7 @@ class SyncServer(LockingServer):
|
||||
user_id: uuid.UUID,
|
||||
connector: DataConnector,
|
||||
source_name: str,
|
||||
):
|
||||
) -> Tuple[int, int]:
|
||||
"""Load data from a DataConnector into a source for a specified user_id"""
|
||||
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
|
||||
|
||||
@@ -1269,7 +1269,8 @@ class SyncServer(LockingServer):
|
||||
document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
|
||||
|
||||
# load data into the document store
|
||||
load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
|
||||
passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
|
||||
return passage_count, document_count
|
||||
|
||||
def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
|
||||
# attach a data source to an agent
|
||||
|
||||
Reference in New Issue
Block a user