enable source desc and allowing editing source name and desc

This commit is contained in:
Jonathan Ward
2024-07-31 12:23:06 -07:00
parent 2b569b3480
commit 6cdc838e2b
2 changed files with 44 additions and 4 deletions

View File

@@ -141,10 +141,36 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear()
try:
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.create_source(name=request.name, user_id=user_id)
source = server.create_source(name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at.timestamp(),
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources/{source_id}", tags=["sources"], response_model=SourceModel)
async def update_source(
source_id: uuid.UUID,
request: CreateSourceRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Update the name or documentation of an existing data source.
"""
interface.clear()
try:
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.update_source(source_id=source_id, name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,

View File

@@ -1375,11 +1375,12 @@ class SyncServer(LockingServer):
token = self.ms.create_api_key(user_id=user_id)
return token
def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields
def create_source(self, name: str, user_id: uuid.UUID, description: str = None) -> Source: # TODO: add other fields
"""Create a new data source"""
source = Source(
name=name,
user_id=user_id,
description=description,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_dim=self.config.default_embedding_config.embedding_dim,
)
@@ -1387,6 +1388,19 @@ class SyncServer(LockingServer):
assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}"
return source
def update_source(self, source_id: uuid.UUID, name: str, user_id: uuid.UUID, description: str = None) -> Source:
"""Updates a data source"""
source = Source(
id=source_id,
name=name,
user_id=user_id,
description=description,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_dim=self.config.default_embedding_config.embedding_dim
)
self.ms.update_source(source)
return source
def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID):
"""Delete a data source"""
source = self.ms.get_source(source_id=source_id, user_id=user_id)
@@ -1475,7 +1489,7 @@ class SyncServer(LockingServer):
sources = [
SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=self.server_embedding_config,