Source code for pahelix.utils.splitters

#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
| Splitters
"""

import random
import numpy as np
from itertools import compress
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
from sklearn.model_selection import StratifiedKFold

__all__ = [
    'RandomSplitter',
    'IndexSplitter',
    'ScaffoldSplitter',
    'RandomScaffoldSplitter',
]


def generate_scaffold(smiles, include_chirality=False):
    """
    Compute the Bemis-Murcko scaffold of a molecule given as smiles.

    Args:
        smiles: smiles sequence of the molecule.
        include_chirality: forwarded to RDKit as ``includeChirality``.
            Default=False.

    Return:
        the scaffold produced by ``MurckoScaffold.MurckoScaffoldSmiles``
        for the given smiles.
    """
    return MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles, includeChirality=include_chirality)
class Splitter(object):
    """
    The abstract class of splitters which split up dataset into
    train/valid/test subsets.

    Concrete splitters are expected to override :meth:`split`.
    """

    def __init__(self):
        super(Splitter, self).__init__()

    def split(self, dataset, frac_train=None, frac_valid=None, frac_test=None):
        """
        Split ``dataset`` into train/valid/test subsets.

        The base class only declares the contract; calling it directly fails
        with an explicit error instead of an opaque ``AttributeError``.

        Args:
            dataset: the dataset to split.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.

        Raises:
            NotImplementedError: always, since this class is abstract.
        """
        raise NotImplementedError(
            "%s must implement split()" % type(self).__name__)
class RandomSplitter(Splitter):
    """
    Splitter that shuffles the sample indices and cuts them into
    train/valid/test subsets.
    """

    def __init__(self):
        super(RandomSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None,
              seed=None):
        """
        Args:
            dataset(InMemoryDataset): the dataset to split.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.
            seed(int|None): the random seed.

        Returns:
            tuple of (train_dataset, valid_dataset, test_dataset), each obtained
            by indexing ``dataset`` with a list of shuffled indices.
        """
        # The three fractions must add up to 1 (within numpy's tolerance).
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

        num_samples = len(dataset)

        # Shuffle a plain Python list in place with a seeded generator so the
        # split is reproducible for a given seed.
        order = list(range(num_samples))
        shuffler = np.random.RandomState(seed)
        shuffler.shuffle(order)

        # Cut positions are truncated towards zero.
        first_cut = int(frac_train * num_samples)
        second_cut = int((frac_train + frac_valid) * num_samples)

        train_part = order[:first_cut]
        valid_part = order[first_cut:second_cut]
        test_part = order[second_cut:]
        return dataset[train_part], dataset[valid_part], dataset[test_part]
class IndexSplitter(Splitter):
    """
    Split datasets that have already been ordered. The first `frac_train`
    proportion is used for train set, the next `frac_valid` for valid set
    and the final `frac_test` for test set.
    """

    def __init__(self):
        super(IndexSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None):
        """
        Args:
            dataset(InMemoryDataset): the dataset to split.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.

        Returns:
            tuple of (train_dataset, valid_dataset, test_dataset), each obtained
            by indexing ``dataset`` with a contiguous list of indices.
        """
        # The three fractions must add up to 1 (within numpy's tolerance).
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

        num_samples = len(dataset)

        # Cut positions are truncated towards zero.
        first_cut = int(frac_train * num_samples)
        second_cut = int((frac_train + frac_valid) * num_samples)

        # Keep the existing order: consecutive index ranges form the subsets.
        positions = list(range(num_samples))
        train_part = positions[:first_cut]
        valid_part = positions[first_cut:second_cut]
        test_part = positions[second_cut:]
        return dataset[train_part], dataset[valid_part], dataset[test_part]
class ScaffoldSplitter(Splitter):
    """
    Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py

    Split dataset by Bemis-Murcko scaffolds
    """

    def __init__(self):
        super(ScaffoldSplitter, self).__init__()

    def split(self, dataset, frac_train=None, frac_valid=None, frac_test=None):
        """
        Group samples by scaffold and assign whole groups to train/valid/test,
        so that samples sharing a scaffold always land in the same subset.

        Args:
            dataset(InMemoryDataset): the dataset to split. Make sure each element in
                the dataset has key "smiles" which will be used to calculate the
                scaffold.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.

        Returns:
            tuple of (train_dataset, valid_dataset, test_dataset), each obtained
            by indexing ``dataset`` with a list of indices.
        """
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
        N = len(dataset)

        # Build {scaffold: [idx1, idx2, ...]}. Indices are appended in
        # increasing order, so every group is already sorted ascending.
        # Index-based access is kept because only len() and __getitem__ of
        # the dataset are relied upon.
        all_scaffolds = defaultdict(list)
        for i in range(N):
            scaffold = generate_scaffold(
                dataset[i]['smiles'], include_chirality=True)
            all_scaffolds[scaffold].append(i)

        # Order groups from largest to smallest; groups of equal size are
        # ordered by their first index, highest first (reverse=True applies
        # to the whole key tuple). First indices are unique per group, so the
        # ordering is fully deterministic.
        all_scaffold_sets = sorted(
            all_scaffolds.values(),
            key=lambda index_group: (len(index_group), index_group[0]),
            reverse=True)

        # Greedily fill train, then valid; whatever does not fit goes to test.
        train_cutoff = frac_train * N
        valid_cutoff = (frac_train + frac_valid) * N
        train_idx, valid_idx, test_idx = [], [], []
        for scaffold_set in all_scaffold_sets:
            if len(train_idx) + len(scaffold_set) <= train_cutoff:
                train_idx.extend(scaffold_set)
            elif (len(train_idx) + len(valid_idx) + len(scaffold_set)
                  <= valid_cutoff):
                valid_idx.extend(scaffold_set)
            else:
                test_idx.extend(scaffold_set)

        # Sanity checks on the partition: valid shares nothing with train/test.
        assert len(set(train_idx).intersection(set(valid_idx))) == 0
        assert len(set(test_idx).intersection(set(valid_idx))) == 0

        train_dataset = dataset[train_idx]
        valid_dataset = dataset[valid_idx]
        test_dataset = dataset[test_idx]
        return train_dataset, valid_dataset, test_dataset
class RandomScaffoldSplitter(Splitter):
    """
    Adapted from https://github.com/pfnet-research/chainer-chemistry/blob/master/chainer_chemistry/dataset/splitters/scaffold_splitter.py

    Split dataset by Bemis-Murcko scaffolds
    """

    def __init__(self):
        super(RandomScaffoldSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None,
              seed=None):
        """
        Group samples by scaffold, shuffle the groups and assign whole groups
        to the valid, test and train subsets.

        Args:
            dataset(InMemoryDataset): the dataset to split. Make sure each element in
                the dataset has key "smiles" which will be used to calculate the
                scaffold.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.
            seed(int|None): the random seed.

        Returns:
            tuple of (train_dataset, valid_dataset, test_dataset), each obtained
            by indexing ``dataset`` with a list of indices.
        """
        # The three fractions must add up to 1 (within numpy's tolerance).
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

        num_samples = len(dataset)
        rng = np.random.RandomState(seed)

        # Collect sample positions per scaffold, in first-seen order.
        groups = {}
        for position in range(num_samples):
            key = generate_scaffold(
                dataset[position]['smiles'], include_chirality=True)
            groups.setdefault(key, []).append(position)

        # Shuffle the scaffold groups with the seeded generator.
        shuffled_groups = rng.permutation(
            np.array(list(groups.values()), dtype=object))

        # Size budgets for valid and test; train receives everything else.
        valid_budget = int(np.floor(frac_valid * num_samples))
        test_budget = int(np.floor(frac_test * num_samples))

        train_idx, valid_idx, test_idx = [], [], []
        for members in shuffled_groups:
            group_size = len(members)
            # Valid is tried first, then test; a group that fits neither
            # budget goes to train.
            if len(valid_idx) + group_size <= valid_budget:
                bucket = valid_idx
            elif len(test_idx) + group_size <= test_budget:
                bucket = test_idx
            else:
                bucket = train_idx
            bucket.extend(members)

        return dataset[train_idx], dataset[valid_idx], dataset[test_idx]