Source code for pahelix.networks.compound_encoder

#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Basic Encoder for compound atom/bond features.
"""
import numpy as np

import paddle
import paddle.nn as nn
import pgl

from pahelix.utils.compound_tools import CompoundKit
from pahelix.networks.basic_block import RBF


class AtomEmbedding(nn.Layer):
    """
    Atom Encoder.

    Holds one ``nn.Embedding`` table per name in ``atom_names`` and encodes
    a node by summing the lookups of all its features.
    """

    def __init__(self, atom_names, embed_dim):
        """
        Args:
            atom_names(iterable of str): names of the atom features; each one
                is passed to ``CompoundKit.get_atom_feature_size``.
            embed_dim(int): second argument of every ``nn.Embedding`` built here.
        """
        super().__init__()
        self.atom_names = atom_names

        # One table per feature, in the same order as `atom_names`.
        # The table size is the reported feature size plus 5.
        # NOTE(review): the reason for the extra 5 rows is not visible here
        # (presumably head-room for special / unknown ids) -- confirm.
        self.embed_list = nn.LayerList([
            nn.Embedding(
                CompoundKit.get_atom_feature_size(feat_name) + 5,
                embed_dim,
                weight_attr=nn.initializer.XavierUniform())
            for feat_name in self.atom_names
        ])

    def forward(self, node_features):
        """
        Args:
            node_features(dict of tensor): node features, keyed by the names
                in ``atom_names``.

        Returns:
            The sum of the per-feature embeddings (the int ``0`` when
            ``atom_names`` is empty).
        """
        summed = 0
        for table, feat_name in zip(self.embed_list, self.atom_names):
            summed += table(node_features[feat_name])
        return summed
class AtomFloatEmbedding(nn.Layer):
    """
    Atom Float Encoder.

    Every float feature is passed through its own ``RBF`` layer and then a
    ``nn.Linear`` projection; the projections of all features are summed.
    """

    def __init__(self, atom_float_names, embed_dim, rbf_params=None):
        """
        Args:
            atom_float_names(iterable of str): names of the float features;
                each must be a key of the rbf parameter dict.
            embed_dim(int): output size of every ``nn.Linear`` built here.
            rbf_params(dict, optional): maps a feature name to a
                ``(centers, gamma)`` pair handed to ``RBF``. When ``None``
                the built-in defaults below are used.
        """
        super().__init__()
        self.atom_float_names = atom_float_names

        if rbf_params is not None:
            self.rbf_params = rbf_params
        else:
            self.rbf_params = {
                'van_der_waals_radis': (np.arange(1, 3, 0.2), 10.0),  # (centers, gamma)
                'partial_charge': (np.arange(-1, 4, 0.25), 10.0),  # (centers, gamma)
                'mass': (np.arange(0, 2, 0.1), 10.0),  # (centers, gamma)
            }

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()
        for feat_name in self.atom_float_names:
            centers, gamma = self.rbf_params[feat_name]
            self.rbf_list.append(RBF(centers, gamma))
            # The linear layer takes one input per RBF center.
            self.linear_list.append(nn.Linear(len(centers), embed_dim))

    def forward(self, feats):
        """
        Args:
            feats(dict of tensor): node float features, keyed by the names in
                ``atom_float_names``.

        Returns:
            The sum of the projected RBF expansions (the int ``0`` when
            ``atom_float_names`` is empty).
        """
        summed = 0
        layers = zip(self.rbf_list, self.linear_list)
        for feat_name, (rbf, linear) in zip(self.atom_float_names, layers):
            expanded = rbf(feats[feat_name])
            summed += linear(expanded)
        return summed
class BondEmbedding(nn.Layer):
    """
    Bond Encoder.

    Holds one ``nn.Embedding`` table per name in ``bond_names`` and encodes
    an edge by summing the lookups of all its features.
    """

    def __init__(self, bond_names, embed_dim):
        """
        Args:
            bond_names(iterable of str): names of the bond features; each one
                is passed to ``CompoundKit.get_bond_feature_size``.
            embed_dim(int): second argument of every ``nn.Embedding`` built here.
        """
        super().__init__()
        self.bond_names = bond_names

        # One table per feature, in the same order as `bond_names`.
        # The table size is the reported feature size plus 5.
        # NOTE(review): the reason for the extra 5 rows is not visible here
        # (presumably head-room for special / unknown ids) -- confirm.
        self.embed_list = nn.LayerList([
            nn.Embedding(
                CompoundKit.get_bond_feature_size(feat_name) + 5,
                embed_dim,
                weight_attr=nn.initializer.XavierUniform())
            for feat_name in self.bond_names
        ])

    def forward(self, edge_features):
        """
        Args:
            edge_features(dict of tensor): edge features, keyed by the names
                in ``bond_names``.

        Returns:
            The sum of the per-feature embeddings (the int ``0`` when
            ``bond_names`` is empty).
        """
        summed = 0
        for table, feat_name in zip(self.embed_list, self.bond_names):
            summed += table(edge_features[feat_name])
        return summed
class BondFloatRBF(nn.Layer):
    """
    Bond Float Encoder using Radial Basis Functions.

    Every float feature is passed through its own ``RBF`` layer and then a
    ``nn.Linear`` projection; the projections of all features are summed.
    """

    def __init__(self, bond_float_names, embed_dim, rbf_params=None):
        """
        Args:
            bond_float_names(iterable of str): names of the float features;
                each must be a key of the rbf parameter dict.
            embed_dim(int): output size of every ``nn.Linear`` built here.
            rbf_params(dict, optional): maps a feature name to a
                ``(centers, gamma)`` pair handed to ``RBF``. When ``None``
                the built-in default below is used.
        """
        super().__init__()
        self.bond_float_names = bond_float_names

        if rbf_params is not None:
            self.rbf_params = rbf_params
        else:
            self.rbf_params = {
                'bond_length': (np.arange(0, 2, 0.1), 10.0),  # (centers, gamma)
            }

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()
        for feat_name in self.bond_float_names:
            centers, gamma = self.rbf_params[feat_name]
            self.rbf_list.append(RBF(centers, gamma))
            # The linear layer takes one input per RBF center.
            self.linear_list.append(nn.Linear(len(centers), embed_dim))

    def forward(self, bond_float_features):
        """
        Args:
            bond_float_features(dict of tensor): bond float features, keyed
                by the names in ``bond_float_names``.

        Returns:
            The sum of the projected RBF expansions (the int ``0`` when
            ``bond_float_names`` is empty).
        """
        summed = 0
        layers = zip(self.rbf_list, self.linear_list)
        for feat_name, (rbf, linear) in zip(self.bond_float_names, layers):
            expanded = rbf(bond_float_features[feat_name])
            summed += linear(expanded)
        return summed
class BondAngleFloatRBF(nn.Layer):
    """
    Bond Angle Float Encoder using Radial Basis Functions.

    Every float feature is passed through its own ``RBF`` layer and then a
    ``nn.Linear`` projection; the projections of all features are summed.
    """

    def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
        """
        Args:
            bond_angle_float_names(iterable of str): names of the float
                features; each must be a key of the rbf parameter dict.
            embed_dim(int): output size of every ``nn.Linear`` built here.
            rbf_params(dict, optional): maps a feature name to a
                ``(centers, gamma)`` pair handed to ``RBF``. When ``None``
                the built-in default below is used.
        """
        super().__init__()
        self.bond_angle_float_names = bond_angle_float_names

        if rbf_params is not None:
            self.rbf_params = rbf_params
        else:
            self.rbf_params = {
                'bond_angle': (np.arange(0, np.pi, 0.1), 10.0),  # (centers, gamma)
            }

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()
        for feat_name in self.bond_angle_float_names:
            centers, gamma = self.rbf_params[feat_name]
            self.rbf_list.append(RBF(centers, gamma))
            # The linear layer takes one input per RBF center.
            self.linear_list.append(nn.Linear(len(centers), embed_dim))

    def forward(self, bond_angle_float_features):
        """
        Args:
            bond_angle_float_features(dict of tensor): bond angle float
                features, keyed by the names in ``bond_angle_float_names``.

        Returns:
            The sum of the projected RBF expansions (the int ``0`` when
            ``bond_angle_float_names`` is empty).
        """
        summed = 0
        layers = zip(self.rbf_list, self.linear_list)
        for feat_name, (rbf, linear) in zip(self.bond_angle_float_names, layers):
            expanded = rbf(bond_angle_float_features[feat_name])
            summed += linear(expanded)
        return summed