# Source code for hdnnpy.dataset.dataset_generator

# coding: utf-8

"""Deal out datasets as needed."""

from sklearn.model_selection import KFold

from hdnnpy.dataset.hdnnp_dataset import HDNNPDataset


class DatasetGenerator(object):
    """Deal out datasets as needed.

    Stores the datasets passed to the constructor and hands them out
    either untouched (:meth:`all`, :meth:`foreach`) or split into
    ``(train, test)`` pairs (:meth:`holdout`, :meth:`kfold`).
    """

    def __init__(self, *datasets):
        """
        Args:
            *datasets (HDNNPDataset): What you want to unite.

        Raises:
            ValueError: If no dataset is given.
            TypeError: If any argument is not an ``HDNNPDataset``.
        """
        if not datasets:
            raise ValueError('No dataset are given')
        for dataset in datasets:
            # Explicit raise instead of ``assert``: assertions are removed
            # when Python runs with ``-O``, which would let a wrong object
            # through and fail much later with an unrelated error.
            if not isinstance(dataset, HDNNPDataset):
                raise TypeError(
                    'Expected HDNNPDataset, got {}'.format(
                        type(dataset).__name__))
        self._datasets = list(datasets)

    def all(self):
        """Pass all datasets an instance have.

        The internal list itself is returned (not a copy), so mutating
        the result also changes what this generator holds.

        Returns:
            list [HDNNPDataset]: All stored datasets.
        """
        return self._datasets

    def foreach(self):
        """Pass all datasets an instance have one by one.

        Returns:
            Iterator [HDNNPDataset]: a stored dataset object.
        """
        for dataset in self._datasets:
            yield dataset

    def holdout(self, ratio):
        """Split each dataset at a certain rate and pass it.

        For every stored dataset the first
        ``int(dataset.partial_size * ratio)`` entries become training
        data and the rest become test data.

        Args:
            ratio (float): Specify the rate you want to use as training
                data. Remains are test data.

        Returns:
            list [tuple [HDNNPDataset, HDNNPDataset]]:
            All stored dataset split by specified ratio into training
            and test data.

        Raises:
            ValueError: If either part of any split is empty.
        """
        split = []
        for dataset in self._datasets:
            # ``int`` truncates, so the boundary is rounded down.
            boundary = int(dataset.partial_size * ratio)
            split.append(self._take_pair(
                dataset,
                slice(None, boundary, None),
                slice(boundary, None, None),
            ))
        return split

    def kfold(self, kfold):
        """Split each dataset almost equally and pass it for cross
        validation.

        Only ``n_splits`` is passed to ``KFold``; every other option
        keeps its scikit-learn default.

        Args:
            kfold (int): Number of folds to split dataset.

        Returns:
            Iterator [list [tuple [HDNNPDataset, HDNNPDataset]]]:
            All stored dataset split into training and test data.
            It iterates k times while changing parts used for test data.

        Raises:
            ValueError: If either part of any split is empty.
        """
        kf = KFold(n_splits=kfold)
        # One lazy index generator per dataset; ``zip`` below advances
        # all of them in lockstep so each yielded list holds the same
        # fold number for every dataset.
        kfold_indices = [kf.split(range(dataset.partial_size))
                         for dataset in self._datasets]

        for indices in zip(*kfold_indices):
            split = []
            for dataset, (train_idx, test_idx) in zip(self._datasets,
                                                      indices):
                split.append(self._take_pair(dataset, train_idx, test_idx))
            yield split

    @staticmethod
    def _take_pair(dataset, train_selector, test_selector):
        """Take the train/test parts of ``dataset`` and reject empty ones.

        Args:
            dataset (HDNNPDataset): Dataset to split.
            train_selector: Slice or index array handed to
                ``dataset.take`` for the training part.
            test_selector: Slice or index array handed to
                ``dataset.take`` for the test part.

        Returns:
            tuple [HDNNPDataset, HDNNPDataset]: ``(train, test)``.

        Raises:
            ValueError: If the training or the test part is empty.
        """
        train = dataset.take(train_selector)
        test = dataset.take(test_selector)
        # ``len`` is checked explicitly (rather than truthiness) because
        # that is the only protocol the original code relied on.
        if len(train) == 0:
            raise ValueError('Training data is empty after splitting')
        if len(test) == 0:
            raise ValueError('Test data is empty after splitting')
        return train, test