diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py index 25334672..d8ff6911 100644 --- a/memgpt/server/rest_api/sources/index.py +++ b/memgpt/server/rest_api/sources/index.py @@ -1,4 +1,7 @@ import uuid +import tempfile +import os +import hashlib from functools import partial from typing import List, Optional @@ -150,14 +153,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, interface.clear() source = server.ms.get_source(source_id=source_id, user_id=user_id) - # create a directory connector that reads the in-memory file - connector = DirectoryConnector(input_files=[file.filename]) + # write the file to a temporary directory (deleted after the context manager exits) + with tempfile.TemporaryDirectory() as tmpdirname: + file_path = os.path.join(tmpdirname, file.filename) + with open(file_path, "wb") as buffer: + buffer.write(file.file.read()) - # load the data into the source via the connector - server.load_data(user_id=user_id, source_name=source.name, connector=connector) + # read the file + connector = DirectoryConnector(input_files=[file_path]) + + # load the data into the source via the connector + passage_count, document_count = server.load_data(user_id=user_id, source_name=source.name, connector=connector) # TODO: actually return added passages/documents - return UploadFileToSourceResponse(source=source, added_passages=0, added_documents=0) + return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count) @router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse) async def list_passages( diff --git a/memgpt/server/server.py b/memgpt/server/server.py index ba371d21..148718a6 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -6,7 +6,7 @@ from abc import abstractmethod from datetime import datetime from functools import wraps from threading import Lock -from typing import Union, Callable, Optional, List +from typing import Union, Callable, Optional, List, Tuple import warnings from fastapi import HTTPException @@ -1254,7 +1254,7 @@ class SyncServer(LockingServer): user_id: uuid.UUID, connector: DataConnector, source_name: str, - ): + ) -> Tuple[int, int]: """Load data from a DataConnector into a source for a specified user_id""" # TODO: this should be implemented as a batch job or at least async, since it may take a long time @@ -1269,7 +1269,8 @@ class SyncServer(LockingServer): document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) # load data into the document store - load_data(connector, source, self.config.default_embedding_config, passage_store, document_store) + passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store) + return passage_count, document_count def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str): # attach a data source to an agent