# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
| Splitters
"""
import random
import numpy as np
from itertools import compress
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
from sklearn.model_selection import StratifiedKFold
# Public API of this module.
__all__ = [
'RandomSplitter',
'IndexSplitter',
'ScaffoldSplitter',
'RandomScaffoldSplitter',
]
def generate_scaffold(smiles, include_chirality=False):
    """
    Obtain the Bemis-Murcko scaffold of a molecule from its SMILES string.

    Args:
        smiles(str): SMILES sequence of the molecule.
        include_chirality(bool): whether chirality is kept in the scaffold.
            Default=False.

    Returns:
        str: the scaffold SMILES of the given molecule.
    """
    scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles, includeChirality=include_chirality)
    return scaffold
class Splitter(object):
    """
    Abstract base class for splitters that divide a dataset into
    train/valid/test subsets. Concrete splitters implement ``split``.
    """

    def __init__(self):
        super(Splitter, self).__init__()
class RandomSplitter(Splitter):
    """
    Random splitter: shuffles the dataset indices before carving out the
    train/valid/test subsets.
    """

    def __init__(self):
        super(RandomSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None,
              seed=None):
        """
        Randomly split ``dataset`` into three disjoint subsets.

        Args:
            dataset(InMemoryDataset): the dataset to split.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.
            seed(int|None): the random seed.

        Returns:
            tuple: (train_dataset, valid_dataset, test_dataset).
        """
        # The three fractions must partition the whole dataset.
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
        N = len(dataset)
        indices = list(range(N))
        # Dedicated RandomState: reproducible for a given seed and does not
        # disturb the global numpy RNG state.
        rng = np.random.RandomState(seed)
        rng.shuffle(indices)
        train_cutoff = int(frac_train * N)
        valid_cutoff = int((frac_train + frac_valid) * N)
        train_dataset = dataset[indices[:train_cutoff]]
        valid_dataset = dataset[indices[train_cutoff:valid_cutoff]]
        test_dataset = dataset[indices[valid_cutoff:]]
        return train_dataset, valid_dataset, test_dataset
class IndexSplitter(Splitter):
    """
    Split datasets that have already been ordered. The first `frac_train`
    proportion is used for the train set, the next `frac_valid` for the valid
    set and the final `frac_test` for the test set.
    """

    def __init__(self):
        super(IndexSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None):
        """
        Split ``dataset`` by position, without shuffling.

        Args:
            dataset(InMemoryDataset): the dataset to split.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.

        Returns:
            tuple: (train_dataset, valid_dataset, test_dataset).
        """
        # The three fractions must partition the whole dataset.
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
        N = len(dataset)
        indices = list(range(N))
        train_cutoff = int(frac_train * N)
        valid_cutoff = int((frac_train + frac_valid) * N)
        train_dataset = dataset[indices[:train_cutoff]]
        valid_dataset = dataset[indices[train_cutoff:valid_cutoff]]
        test_dataset = dataset[indices[valid_cutoff:]]
        return train_dataset, valid_dataset, test_dataset
class ScaffoldSplitter(Splitter):
    """
    Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py

    Split dataset by Bemis-Murcko scaffolds so that molecules sharing a
    scaffold never end up in different subsets.
    """

    def __init__(self):
        super(ScaffoldSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None):
        """
        Deterministically split ``dataset`` by scaffold group.

        Args:
            dataset(InMemoryDataset): the dataset to split. Make sure each element in
                the dataset has key "smiles" which will be used to calculate the
                scaffold.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.

        Returns:
            tuple: (train_dataset, valid_dataset, test_dataset).
        """
        # The three fractions must partition the whole dataset.
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
        N = len(dataset)
        # Group sample indices by scaffold: {scaffold_i: [idx1, idx2, ...]}.
        all_scaffolds = {}
        for i in range(N):
            scaffold = generate_scaffold(dataset[i]['smiles'], include_chirality=True)
            if scaffold not in all_scaffolds:
                all_scaffolds[scaffold] = [i]
            else:
                all_scaffolds[scaffold].append(i)
        # Order scaffold sets from largest to smallest (ties broken by the
        # smallest contained index) so the biggest groups fill the train split.
        all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()}
        all_scaffold_sets = [
            scaffold_set for (scaffold, scaffold_set) in sorted(
                all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
        ]
        # Greedily assign whole scaffold sets to train, then valid, then test.
        # (The original file repeated this section twice verbatim; the second
        # pass recomputed identical results and has been removed.)
        train_cutoff = frac_train * N
        valid_cutoff = (frac_train + frac_valid) * N
        train_idx, valid_idx, test_idx = [], [], []
        for scaffold_set in all_scaffold_sets:
            if len(train_idx) + len(scaffold_set) > train_cutoff:
                if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:
                    test_idx.extend(scaffold_set)
                else:
                    valid_idx.extend(scaffold_set)
            else:
                train_idx.extend(scaffold_set)
        # Sanity checks: the subsets must be disjoint.
        assert len(set(train_idx).intersection(set(valid_idx))) == 0
        assert len(set(test_idx).intersection(set(valid_idx))) == 0
        train_dataset = dataset[train_idx]
        valid_dataset = dataset[valid_idx]
        test_dataset = dataset[test_idx]
        return train_dataset, valid_dataset, test_dataset
class RandomScaffoldSplitter(Splitter):
    """
    Adapted from https://github.com/pfnet-research/chainer-chemistry/blob/master/chainer_chemistry/dataset/splitters/scaffold_splitter.py

    Split dataset by Bemis-Murcko scaffolds, assigning whole scaffold groups
    to subsets in a randomly permuted order.
    """

    def __init__(self):
        super(RandomScaffoldSplitter, self).__init__()

    def split(self,
              dataset,
              frac_train=None,
              frac_valid=None,
              frac_test=None,
              seed=None):
        """
        Split ``dataset`` by scaffold group in random order.

        Args:
            dataset(InMemoryDataset): the dataset to split. Make sure each element in
                the dataset has key "smiles" which will be used to calculate the
                scaffold.
            frac_train(float): the fraction of data to be used for the train split.
            frac_valid(float): the fraction of data to be used for the valid split.
            frac_test(float): the fraction of data to be used for the test split.
            seed(int|None): the random seed.

        Returns:
            tuple: (train_dataset, valid_dataset, test_dataset).
        """
        # The three fractions must partition the whole dataset.
        np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
        N = len(dataset)
        rng = np.random.RandomState(seed)
        # Group sample indices by scaffold.
        scaffolds = defaultdict(list)
        for ind in range(N):
            scaffold = generate_scaffold(dataset[ind]['smiles'], include_chirality=True)
            scaffolds[scaffold].append(ind)
        # Randomly permute the scaffold groups; dtype=object keeps each group
        # as a Python list so ragged group sizes are preserved.
        scaffold_sets = rng.permutation(np.array(list(scaffolds.values()), dtype=object))
        n_total_valid = int(np.floor(frac_valid * len(dataset)))
        n_total_test = int(np.floor(frac_test * len(dataset)))
        train_idx = []
        valid_idx = []
        test_idx = []
        # Fill valid first, then test; everything that does not fit goes to train.
        for scaffold_set in scaffold_sets:
            if len(valid_idx) + len(scaffold_set) <= n_total_valid:
                valid_idx.extend(scaffold_set)
            elif len(test_idx) + len(scaffold_set) <= n_total_test:
                test_idx.extend(scaffold_set)
            else:
                train_idx.extend(scaffold_set)
        train_dataset = dataset[train_idx]
        valid_dataset = dataset[valid_idx]
        test_dataset = dataset[test_idx]
        return train_dataset, valid_dataset, test_dataset