Source code for pygho.hodata.ParallelPreprocess

import torch
from torch_geometric.data import InMemoryDataset, Data as PygData
from typing import Callable, Optional, Iterable
from multiprocessing import Pool
from pqdm.processes import pqdm
from tqdm import tqdm
from .Wrapper import _repr
import os.path as osp


class ParallelPreprocessDataset(InMemoryDataset):
    '''
    Parallelly transform a PyG dataset.

    This dataset class allows parallel preprocessing of a list of PyGData or
    PyGDataset instances. The transformed graphs are collated and saved once;
    later constructions with the same processed directory reload that file.

    Args:
    - root (str): The directory to save processed data.
    - data_list (Iterable[PygData]): A list of PygData or PygDataset instances.
    - pre_transform (Callable[[PygData], PygData]): A function that maps
      PygData to PygData. It is executed only once for all data and is
      typically a tuple sampler.
    - num_worker (int): The number of processes for parallel preprocessing.
      It can be set to the number of available CPU cores. A value <= 0 runs
      pre_transform sequentially in the current process.
    - processedname (Optional[str]): The name used for the processed-data
      directory. If None, the directory name is derived from `_repr` of
      pre_transform and pre_filter.
    - transform (Optional[Callable[[PygData], PygData]]): A function to
      dynamically transform data during data loading.
    '''

    def __init__(self,
                 root: str,
                 data_list: Iterable[PygData],
                 pre_transform: Callable[[PygData], PygData],
                 num_worker: int,
                 processedname: Optional[str] = None,
                 transform: Optional[Callable[[PygData], PygData]] = None):
        # These attributes are set before super().__init__ because the PyG
        # base constructor may call process() and read processed_dir, both of
        # which depend on them.
        self.tmp_data_list = list(data_list)
        self.num_worker = num_worker
        self.processedname = processedname
        super().__init__(root,
                         pre_transform=pre_transform,
                         transform=transform)
        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path: str):
        """Load the ``(data, slices)`` tuple that :meth:`process` saved to *path*."""
        import inspect

        # The file was written by this class itself, so full unpickling is
        # trusted. torch>=2.6 defaults ``weights_only=True``, which rejects
        # pickled PyG Data objects; older torch versions do not accept the
        # keyword at all, so only pass it when torch.load supports it.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

    @property
    def processed_file_names(self):
        return 'data.pt'

    @property
    def processed_dir(self) -> str:
        # The directory is keyed on the preprocessing configuration (through
        # _repr from .Wrapper) unless an explicit processedname is given.
        if self.processedname is None:
            suffix = f'{_repr(self.pre_transform)}__{_repr(self.pre_filter)}'
        else:
            suffix = self.processedname
        return osp.join(self.root, f'processed__{suffix}')

    def process(self):
        """Apply pre_transform to every input graph and save the collated result.

        Raises:
            RuntimeError: if pre_transform raised for any item while running
                in worker processes (num_worker > 0). The first failure is
                chained as the cause.
        """
        if self.num_worker > 0:
            data_list = pqdm(self.tmp_data_list,
                             self.pre_transform,
                             n_jobs=self.num_worker)
            # With its default exception handling, pqdm returns exceptions as
            # items of the result list instead of raising them. Surface them
            # here instead of letting collate() fail on a non-Data object.
            errors = [res for res in data_list if isinstance(res, BaseException)]
            if errors:
                raise RuntimeError(
                    f'pre_transform failed on {len(errors)} of '
                    f'{len(data_list)} items; first error: {errors[0]!r}'
                ) from errors[0]
        else:
            data_list = [
                self.pre_transform(item) for item in tqdm(self.tmp_data_list)
            ]
        torch.save(self.collate(data_list), self.processed_paths[0])