# Source code for pygod.nn.encoder
# -*- coding: utf-8 -*-
"""Graph Neural Networks Encoders"""
# Author: Kay Liu <zliu234@uic.edu>
# License: BSD 2 clause
import torch
import torch.nn.functional as F
from .conv import GNAConv
class GNA(torch.nn.Module):
    """
    Graph Node Attention Network (GNA). See :cite:`yuan2021higher` for
    more details.

    The network is a stack of ``GNAConv`` layers. After every layer
    (including the last one) dropout is applied, followed by ``act`` when
    it is not ``None``.

    Parameters
    ----------
    in_channels : int
        Input embedding size, passed to the first ``GNAConv`` layer.
    hidden_channels : int
        Embedding size of the intermediate layers.
    num_layers : int
        Total number of ``GNAConv`` layers; must be at least 1. With a
        single layer the network maps ``in_channels`` directly to
        ``out_channels``.
    out_channels : int
        Output embedding size, produced by the last ``GNAConv`` layer.
    dropout : float
        Dropout probability applied after each layer.
    act : callable or None
        Activation applied after each layer; ``None`` disables it.

    Raises
    ------
    ValueError
        If ``num_layers`` is smaller than 1.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 num_layers,
                 out_channels,
                 dropout,
                 act):
        super().__init__()
        if num_layers < 1:
            raise ValueError(
                f"num_layers must be at least 1, got {num_layers}.")

        self.layers = torch.nn.ModuleList()
        if num_layers == 1:
            # Handled separately so the layer count matches num_layers
            # exactly: the general branch below always builds >= 2 layers.
            self.layers.append(GNAConv(in_channels, out_channels))
        else:
            self.layers.append(GNAConv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.layers.append(GNAConv(hidden_channels,
                                           hidden_channels))
            self.layers.append(GNAConv(hidden_channels, out_channels))

        self.dropout = dropout
        self.act = act

    def forward(self, s, edge_index):
        """
        Forward computation.

        Parameters
        ----------
        s : torch.Tensor
            Input node embeddings.
        edge_index : torch.Tensor
            Edge index.

        Returns
        -------
        s : torch.Tensor
            Updated node embeddings.
        """
        for layer in self.layers:
            s = layer(s, edge_index)
            # Dropout is only active in training mode (self.training).
            s = F.dropout(s, self.dropout, training=self.training)
            if self.act is not None:
                s = self.act(s)
        return s