feat: Move Source to ORM model (#1979)

This commit is contained in:
Matthew Zhou
2024-11-12 09:57:40 -08:00
committed by GitHub
parent d29a0b2cc7
commit e40e60945a
18 changed files with 509 additions and 236 deletions

View File

@@ -238,7 +238,7 @@ class AbstractClient(object):
def delete_file_from_source(self, source_id: str, file_id: str) -> None:
raise NotImplementedError
def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
raise NotImplementedError
def delete_source(self, source_id: str):
@@ -1188,7 +1188,7 @@ class RESTClient(AbstractClient):
if response.status_code not in [200, 204]:
raise ValueError(f"Failed to delete tool: {response.text}")
def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
"""
Create a source
@@ -1198,7 +1198,8 @@ class RESTClient(AbstractClient):
Returns:
source (Source): Created source
"""
payload = {"name": name}
source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config)
payload = source_create.model_dump()
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers)
response_json = response.json()
return Source(**response_json)
@@ -1253,7 +1254,7 @@ class RESTClient(AbstractClient):
Returns:
source (Source): Updated source
"""
request = SourceUpdate(id=source_id, name=name)
request = SourceUpdate(name=name)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update source: {response.text}")
@@ -2453,7 +2454,7 @@ class LocalClient(AbstractClient):
def list_active_jobs(self):
return self.server.list_active_jobs(user_id=self.user_id)
def create_source(self, name: str) -> Source:
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
"""
Create a source
@@ -2463,8 +2464,10 @@ class LocalClient(AbstractClient):
Returns:
source (Source): Created source
"""
request = SourceCreate(name=name)
return self.server.create_source(request=request, user_id=self.user_id)
source = Source(
name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id
)
return self.server.source_manager.create_source(source=source, actor=self.user)
def delete_source(self, source_id: str):
"""
@@ -2475,7 +2478,7 @@ class LocalClient(AbstractClient):
"""
# TODO: delete source data
self.server.delete_source(source_id=source_id, user_id=self.user_id)
self.server.delete_source(source_id=source_id, actor=self.user)
def get_source(self, source_id: str) -> Source:
"""
@@ -2487,7 +2490,7 @@ class LocalClient(AbstractClient):
Returns:
source (Source): Source
"""
return self.server.get_source(source_id=source_id, user_id=self.user_id)
return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
def get_source_id(self, source_name: str) -> str:
"""
@@ -2499,7 +2502,7 @@ class LocalClient(AbstractClient):
Returns:
source_id (str): ID of the source
"""
return self.server.get_source_id(source_name=source_name, user_id=self.user_id)
return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id
def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
"""
@@ -2532,7 +2535,7 @@ class LocalClient(AbstractClient):
sources (List[Source]): List of sources
"""
return self.server.list_all_sources(user_id=self.user_id)
return self.server.list_all_sources(actor=self.user)
def list_attached_sources(self, agent_id: str) -> List[Source]:
"""
@@ -2572,8 +2575,8 @@ class LocalClient(AbstractClient):
source (Source): Updated source
"""
# TODO should the arg here just be "source_update: Source"?
request = SourceUpdate(id=source_id, name=name)
return self.server.update_source(request=request, user_id=self.user_id)
request = SourceUpdate(name=name)
return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user)
# archival memory