Source code for pahelix.model_zoo.pretrain_gnns_model

#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This is an implementation of pretrain gnns:
https://arxiv.org/abs/1905.12265
"""
import numpy as np

import paddle
import paddle.nn as nn
import pgl
from pgl.nn import GraphPool

from pahelix.networks.gnn_block import GIN
from pahelix.networks.compound_encoder import AtomEmbedding, BondEmbedding
from pahelix.utils.compound_tools import CompoundKit
from pahelix.networks.gnn_block import MeanPool, GraphNorm


class PretrainGNNModel(nn.Layer):
    """
    The basic GNN Model used in pretrain gnns.

    Args:
        model_config(dict): a dict of model configurations.
    """
    def __init__(self, model_config={}):
        super(PretrainGNNModel, self).__init__()

        self.embed_dim = model_config.get('embed_dim', 300)
        self.dropout_rate = model_config.get('dropout_rate', 0.5)
        self.norm_type = model_config.get('norm_type', 'batch_norm')
        self.graph_norm = model_config.get('graph_norm', False)
        self.residual = model_config.get('residual', False)
        self.layer_num = model_config.get('layer_num', 5)
        self.gnn_type = model_config.get('gnn_type', 'gin')
        self.JK = model_config.get('JK', 'last')
        self.readout = model_config.get('readout', 'mean')

        self.atom_names = model_config['atom_names']
        self.bond_names = model_config['bond_names']

        self.atom_embedding = AtomEmbedding(self.atom_names, self.embed_dim)
        self.bond_embedding_list = nn.LayerList()
        self.gnn_list = nn.LayerList()
        self.norm_list = nn.LayerList()
        self.graph_norm_list = nn.LayerList()
        self.dropout_list = nn.LayerList()
        for layer_id in range(self.layer_num):
            self.bond_embedding_list.append(
                    BondEmbedding(self.bond_names, self.embed_dim))

            if self.gnn_type == 'gin':
                self.gnn_list.append(GIN(self.embed_dim))
            else:
                raise ValueError(self.gnn_type)

            if self.norm_type == 'batch_norm':
                self.norm_list.append(nn.BatchNorm1D(self.embed_dim))
            elif self.norm_type == 'layer_norm':
                self.norm_list.append(nn.LayerNorm(self.embed_dim))
            else:
                raise ValueError(self.norm_type)

            if self.graph_norm:
                # TODO: pgl.nn.GraphNorm not implemented in pgl==2.1.2
                self.graph_norm_list.append(GraphNorm())

            self.dropout_list.append(nn.Dropout(self.dropout_rate))

        # TODO: use self-implemented MeanPool due to pgl bug.
        if self.readout == 'mean':
            self.graph_pool = MeanPool()
        else:
            self.graph_pool = pgl.nn.GraphPool(pool_type=self.readout)

        print('[PretrainGNNModel] embed_dim:%s' % self.embed_dim)
        print('[PretrainGNNModel] dropout_rate:%s' % self.dropout_rate)
        print('[PretrainGNNModel] norm_type:%s' % self.norm_type)
        print('[PretrainGNNModel] graph_norm:%s' % self.graph_norm)
        print('[PretrainGNNModel] residual:%s' % self.residual)
        print('[PretrainGNNModel] layer_num:%s' % self.layer_num)
        print('[PretrainGNNModel] gnn_type:%s' % self.gnn_type)
        print('[PretrainGNNModel] JK:%s' % self.JK)
        print('[PretrainGNNModel] readout:%s' % self.readout)
        print('[PretrainGNNModel] atom_names:%s' % str(self.atom_names))
        print('[PretrainGNNModel] bond_names:%s' % str(self.bond_names))

    @property
    def node_dim(self):
        """the out dim of node_repr"""
        return self.embed_dim

    @property
    def graph_dim(self):
        """the out dim of graph_repr"""
        return self.embed_dim
    def forward(self, graph):
        """
        Build the network.
        """
        node_feat = self.atom_embedding(graph.node_feat)
        node_feat_list = [node_feat]
        for layer_id in range(self.layer_num):
            edge_features = self.bond_embedding_list[layer_id](graph.edge_feat)
            node_feat = self.gnn_list[layer_id](
                    graph,
                    node_feat_list[layer_id],
                    edge_features)
            node_feat = self.norm_list[layer_id](node_feat)
            if self.graph_norm:
                node_feat = self.graph_norm_list[layer_id](graph, node_feat)
            if layer_id < self.layer_num - 1:
                node_feat = nn.functional.relu(node_feat)
            node_feat = self.dropout_list[layer_id](node_feat)
            if self.residual:
                node_feat = node_feat + node_feat_list[layer_id]
            node_feat_list.append(node_feat)

        # Jumping-knowledge style aggregation over the per-layer outputs.
        # paddle.sum/paddle.mean expect a Tensor, so stack the list first.
        if self.JK == "sum":
            node_repr = paddle.sum(paddle.stack(node_feat_list), axis=0)
        elif self.JK == "mean":
            node_repr = paddle.mean(paddle.stack(node_feat_list), axis=0)
        elif self.JK == "last":
            node_repr = node_feat_list[-1]
        else:
            raise ValueError(self.JK)

        graph_repr = self.graph_pool(graph, node_repr)
        return node_repr, graph_repr
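For orientation, here is a minimal, hedged usage sketch: it builds a toy single-molecule pgl.Graph and runs one forward pass. The feature names ('atomic_num', 'bond_dir') and the toy graph are illustrative assumptions, not part of this module; real inputs come from pahelix's compound featurizers.

import numpy as np
import paddle
import pgl

model_config = {
    'embed_dim': 64,
    'layer_num': 2,
    'atom_names': ['atomic_num'],   # assumed minimal feature set
    'bond_names': ['bond_dir'],     # assumed minimal feature set
}
model = PretrainGNNModel(model_config)
model.eval()                        # disable dropout for a deterministic pass

# Toy 3-atom chain; feature values index into the CompoundKit vocabularies.
graph = pgl.Graph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        node_feat={'atomic_num': np.array([6, 7, 8], dtype='int64')},
        edge_feat={'bond_dir': np.array([0, 0], dtype='int64')})
graph.tensor()                      # convert the numpy graph data to paddle tensors
node_repr, graph_repr = model(graph)
# node_repr: [num_nodes, embed_dim]; graph_repr: [num_graphs, embed_dim]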
class AttrmaskModel(nn.Layer):
    """
    This is a pretraining model used by pretrain gnns for attribute mask training.

    Returns:
        loss: the loss value of the model.
    """
    def __init__(self, model_config, compound_encoder):
        super(AttrmaskModel, self).__init__()
        self.compound_encoder = compound_encoder

        out_size = CompoundKit.get_atom_feature_size('atomic_num') + 3
        self.linear = nn.Linear(compound_encoder.node_dim, out_size)
        self.criterion = nn.CrossEntropyLoss()
    def forward(self, graphs, masked_node_indice, masked_node_labels):
        """
        Build the network.
        """
        node_repr, graph_repr = self.compound_encoder(graphs)
        masked_node_repr = paddle.gather(node_repr, masked_node_indice)
        logits = self.linear(masked_node_repr)
        loss = self.criterion(logits, masked_node_labels)
        return loss
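Continuing the sketch above, a hedged example of driving AttrmaskModel; the masked indices and labels below are placeholders standing in for what the attribute-masking data pipeline would produce.

compound_encoder = PretrainGNNModel(model_config)
attrmask_model = AttrmaskModel(model_config, compound_encoder)

# Placeholder mask: pretend node 1's atom type was masked out and its
# original label (an atomic-num vocabulary index) is 7.
masked_node_indice = paddle.to_tensor([1], dtype='int64')
masked_node_labels = paddle.to_tensor([7], dtype='int64')
loss = attrmask_model(graph, masked_node_indice, masked_node_labels)
loss.backward()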
class SupervisedModel(nn.Layer):
    """
    This is a pretraining model used by pretrain gnns for supervised training.

    Returns:
        loss: the loss value of the model.
    """
    def __init__(self, model_config, compound_encoder):
        super(SupervisedModel, self).__init__()
        self.task_num = model_config['task_num']
        self.compound_encoder = compound_encoder

        self.linear = nn.Linear(compound_encoder.graph_dim, self.task_num)
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')
    def forward(self, graphs, labels, valids):
        """
        Build the network.
        """
        node_repr, graph_repr = self.compound_encoder(graphs)
        logits = self.linear(graph_repr)
        loss = self.criterion(logits, labels)
        # Average only over the valid (non-missing) label entries.
        loss = paddle.sum(loss * valids) / paddle.sum(valids)
        return loss
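Finally, a hedged sketch for SupervisedModel under the same assumptions as above; task_num, labels, and valids are placeholders for a real multi-task dataset in which a valid of 0 marks a missing label.

sup_config = dict(model_config, task_num=2)   # task_num is a placeholder
sup_model = SupervisedModel(sup_config, compound_encoder)

labels = paddle.to_tensor([[1.0, 0.0]])   # per-graph, per-task binary labels
valids = paddle.to_tensor([[1.0, 0.0]])   # 0 excludes a missing label from the mean
loss = sup_model(graph, labels, valids)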