feat: Move Source to ORM model (#1979)
This commit is contained in:
@@ -238,7 +238,7 @@ class AbstractClient(object):
|
||||
def delete_file_from_source(self, source_id: str, file_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_source(self, name: str) -> Source:
|
||||
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_source(self, source_id: str):
|
||||
@@ -1188,7 +1188,7 @@ class RESTClient(AbstractClient):
|
||||
if response.status_code not in [200, 204]:
|
||||
raise ValueError(f"Failed to delete tool: {response.text}")
|
||||
|
||||
def create_source(self, name: str) -> Source:
|
||||
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
|
||||
"""
|
||||
Create a source
|
||||
|
||||
@@ -1198,7 +1198,8 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
source (Source): Created source
|
||||
"""
|
||||
payload = {"name": name}
|
||||
source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config)
|
||||
payload = source_create.model_dump()
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers)
|
||||
response_json = response.json()
|
||||
return Source(**response_json)
|
||||
@@ -1253,7 +1254,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
source (Source): Updated source
|
||||
"""
|
||||
request = SourceUpdate(id=source_id, name=name)
|
||||
request = SourceUpdate(name=name)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update source: {response.text}")
|
||||
@@ -2453,7 +2454,7 @@ class LocalClient(AbstractClient):
|
||||
def list_active_jobs(self):
|
||||
return self.server.list_active_jobs(user_id=self.user_id)
|
||||
|
||||
def create_source(self, name: str) -> Source:
|
||||
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source:
|
||||
"""
|
||||
Create a source
|
||||
|
||||
@@ -2463,8 +2464,10 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
source (Source): Created source
|
||||
"""
|
||||
request = SourceCreate(name=name)
|
||||
return self.server.create_source(request=request, user_id=self.user_id)
|
||||
source = Source(
|
||||
name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id
|
||||
)
|
||||
return self.server.source_manager.create_source(source=source, actor=self.user)
|
||||
|
||||
def delete_source(self, source_id: str):
|
||||
"""
|
||||
@@ -2475,7 +2478,7 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
|
||||
# TODO: delete source data
|
||||
self.server.delete_source(source_id=source_id, user_id=self.user_id)
|
||||
self.server.delete_source(source_id=source_id, actor=self.user)
|
||||
|
||||
def get_source(self, source_id: str) -> Source:
|
||||
"""
|
||||
@@ -2487,7 +2490,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
source (Source): Source
|
||||
"""
|
||||
return self.server.get_source(source_id=source_id, user_id=self.user_id)
|
||||
return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
|
||||
|
||||
def get_source_id(self, source_name: str) -> str:
|
||||
"""
|
||||
@@ -2499,7 +2502,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
source_id (str): ID of the source
|
||||
"""
|
||||
return self.server.get_source_id(source_name=source_name, user_id=self.user_id)
|
||||
return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id
|
||||
|
||||
def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
|
||||
"""
|
||||
@@ -2532,7 +2535,7 @@ class LocalClient(AbstractClient):
|
||||
sources (List[Source]): List of sources
|
||||
"""
|
||||
|
||||
return self.server.list_all_sources(user_id=self.user_id)
|
||||
return self.server.list_all_sources(actor=self.user)
|
||||
|
||||
def list_attached_sources(self, agent_id: str) -> List[Source]:
|
||||
"""
|
||||
@@ -2572,8 +2575,8 @@ class LocalClient(AbstractClient):
|
||||
source (Source): Updated source
|
||||
"""
|
||||
# TODO should the arg here just be "source_update: Source"?
|
||||
request = SourceUpdate(id=source_id, name=name)
|
||||
return self.server.update_source(request=request, user_id=self.user_id)
|
||||
request = SourceUpdate(name=name)
|
||||
return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user)
|
||||
|
||||
# archival memory
|
||||
|
||||
|
||||
Reference in New Issue
Block a user