# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Basic Encoder for compound atom/bond features.
"""
import numpy as np
import paddle
import paddle.nn as nn
import pgl
from pahelix.utils.compound_tools import CompoundKit
from pahelix.networks.basic_block import RBF
[docs]class AtomEmbedding(nn.Layer):
"""
Atom Encoder
"""
def __init__(self, atom_names, embed_dim):
super(AtomEmbedding, self).__init__()
self.atom_names = atom_names
self.embed_list = nn.LayerList()
for name in self.atom_names:
embed = nn.Embedding(
CompoundKit.get_atom_feature_size(name) + 5,
embed_dim,
weight_attr=nn.initializer.XavierUniform())
self.embed_list.append(embed)
[docs] def forward(self, node_features):
"""
Args:
node_features(dict of tensor): node features.
"""
out_embed = 0
for i, name in enumerate(self.atom_names):
out_embed += self.embed_list[i](node_features[name])
return out_embed
[docs]class AtomFloatEmbedding(nn.Layer):
"""
Atom Float Encoder
"""
def __init__(self, atom_float_names, embed_dim, rbf_params=None):
super(AtomFloatEmbedding, self).__init__()
self.atom_float_names = atom_float_names
if rbf_params is None:
self.rbf_params = {
'van_der_waals_radis': (np.arange(1, 3, 0.2), 10.0), # (centers, gamma)
'partial_charge': (np.arange(-1, 4, 0.25), 10.0), # (centers, gamma)
'mass': (np.arange(0, 2, 0.1), 10.0), # (centers, gamma)
}
else:
self.rbf_params = rbf_params
self.linear_list = nn.LayerList()
self.rbf_list = nn.LayerList()
for name in self.atom_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers, gamma)
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim)
self.linear_list.append(linear)
[docs] def forward(self, feats):
"""
Args:
feats(dict of tensor): node float features.
"""
out_embed = 0
for i, name in enumerate(self.atom_float_names):
x = feats[name]
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
return out_embed
[docs]class BondEmbedding(nn.Layer):
"""
Bond Encoder
"""
def __init__(self, bond_names, embed_dim):
super(BondEmbedding, self).__init__()
self.bond_names = bond_names
self.embed_list = nn.LayerList()
for name in self.bond_names:
embed = nn.Embedding(
CompoundKit.get_bond_feature_size(name) + 5,
embed_dim,
weight_attr=nn.initializer.XavierUniform())
self.embed_list.append(embed)
[docs] def forward(self, edge_features):
"""
Args:
edge_features(dict of tensor): edge features.
"""
out_embed = 0
for i, name in enumerate(self.bond_names):
out_embed += self.embed_list[i](edge_features[name])
return out_embed
[docs]class BondFloatRBF(nn.Layer):
"""
Bond Float Encoder using Radial Basis Functions
"""
def __init__(self, bond_float_names, embed_dim, rbf_params=None):
super(BondFloatRBF, self).__init__()
self.bond_float_names = bond_float_names
if rbf_params is None:
self.rbf_params = {
'bond_length': (np.arange(0, 2, 0.1), 10.0), # (centers, gamma)
}
else:
self.rbf_params = rbf_params
self.linear_list = nn.LayerList()
self.rbf_list = nn.LayerList()
for name in self.bond_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers, gamma)
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim)
self.linear_list.append(linear)
[docs] def forward(self, bond_float_features):
"""
Args:
bond_float_features(dict of tensor): bond float features.
"""
out_embed = 0
for i, name in enumerate(self.bond_float_names):
x = bond_float_features[name]
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
return out_embed
[docs]class BondAngleFloatRBF(nn.Layer):
"""
Bond Angle Float Encoder using Radial Basis Functions
"""
def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
super(BondAngleFloatRBF, self).__init__()
self.bond_angle_float_names = bond_angle_float_names
if rbf_params is None:
self.rbf_params = {
'bond_angle': (np.arange(0, np.pi, 0.1), 10.0), # (centers, gamma)
}
else:
self.rbf_params = rbf_params
self.linear_list = nn.LayerList()
self.rbf_list = nn.LayerList()
for name in self.bond_angle_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers, gamma)
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim)
self.linear_list.append(linear)
[docs] def forward(self, bond_angle_float_features):
"""
Args:
bond_angle_float_features(dict of tensor): bond angle float features.
"""
out_embed = 0
for i, name in enumerate(self.bond_angle_float_names):
x = bond_angle_float_features[name]
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
return out_embed