feat: pass embedding handle for source create (#1307)

This commit is contained in:
Sarah Wooders
2025-03-16 15:42:11 -07:00
committed by GitHub
parent edab956d0b
commit 93a480b1f4
2 changed files with 21 additions and 3 deletions

View File

@@ -50,7 +50,12 @@ class SourceCreate(BaseSource):
# required
name: str = Field(..., description="The name of the source.")
# TODO: @matt, make this required after shub makes the FE changes
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
embedding: Optional[str] = Field(None, description="The handle for the embedding config used by the source.")
embedding_chunk_size: Optional[int] = Field(None, description="The chunk size of the embedding.")
# TODO: remove (legacy config)
embedding_config: Optional[EmbeddingConfig] = Field(None, description="(Legacy) The embedding configuration used by the source.")
# optional
description: Optional[str] = Field(None, description="The description of the source.")

View File

@@ -76,8 +76,20 @@ def create_source(
Create a new data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = Source(**source_create.model_dump())
if not source_create.embedding_config:
if not source_create.embedding:
# TODO: raise an HTTPException (e.g. 400) instead of ValueError so the client gets a proper error response
raise ValueError("Must specify either embedding or embedding_config in request")
source_create.embedding_config = server.get_embedding_config_from_handle(
handle=source_create.embedding,
embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
)
source = Source(
name=source_create.name,
embedding_config=source_create.embedding_config,
description=source_create.description,
metadata=source_create.metadata,
)
return server.source_manager.create_source(source=source, actor=actor)
@@ -91,6 +103,7 @@ def modify_source(
"""
Update the name or documentation of an existing data source.
"""
# TODO: allow updating the handle/embedding config
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")