Source code for pygod.nn.gae

import math
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MLP, GCN
from torch_geometric.utils import to_dense_adj

from ..nn.decoder import DotProductDecoder


class GAEBase(nn.Module):
    """
    Graph Autoencoder

    See :cite:`kipf2016variational` for details.

    Parameters
    ----------
    in_dim : int
        Input dimension of model.
    hid_dim : int
        Hidden dimension of model. Default: ``64``.
    num_layers : int, optional
        Total number of layers in model. Default: ``4``.
    dropout : float, optional
        Dropout rate. Default: ``0.``.
    act : callable activation function or None, optional
        Activation function if not None.
        Default: ``torch.nn.functional.relu``.
    backbone : torch.nn.Module, optional
        The backbone of the deep detector implemented in PyG.
        Default: ``torch_geometric.nn.GCN``.
    recon_s : bool, optional
        Reconstruct the structure instead of the node features.
        Default: ``False``.
    sigmoid_s : bool, optional
        Whether to use the sigmoid function to scale the reconstructed
        structure. Default: ``False``.
    **kwargs : optional
        Other parameters for the backbone.
    """

    def __init__(self,
                 in_dim,
                 hid_dim=64,
                 num_layers=4,
                 dropout=0.,
                 act=F.relu,
                 backbone=GCN,
                 recon_s=False,
                 sigmoid_s=False,
                 **kwargs):
        super(GAEBase, self).__init__()

        self.backbone = backbone

        # split the number of layers between the encoder and the decoder
        assert num_layers >= 2, \
            "Number of layers must be greater than or equal to 2."
        encoder_layers = math.floor(num_layers / 2)
        decoder_layers = math.ceil(num_layers / 2)

        self.encoder = self.backbone(in_channels=in_dim,
                                     hidden_channels=hid_dim,
                                     out_channels=hid_dim,
                                     num_layers=encoder_layers,
                                     dropout=dropout,
                                     act=act,
                                     **kwargs)
        self.recon_s = recon_s
        if self.recon_s:
            self.decoder = DotProductDecoder(in_dim=hid_dim,
                                             hid_dim=hid_dim,
                                             num_layers=decoder_layers,
                                             dropout=dropout,
                                             act=act,
                                             sigmoid_s=sigmoid_s,
                                             backbone=self.backbone,
                                             **kwargs)
        else:
            self.decoder = self.backbone(in_channels=hid_dim,
                                         hidden_channels=hid_dim,
                                         out_channels=in_dim,
                                         num_layers=decoder_layers,
                                         dropout=dropout,
                                         act=act,
                                         **kwargs)

        self.loss_func = F.mse_loss
        self.emb = None
    def forward(self, x, edge_index):
        """
        Forward computation.

        Parameters
        ----------
        x : torch.Tensor
            Input attribute embeddings.
        edge_index : torch.Tensor
            Edge index.

        Returns
        -------
        x_ : torch.Tensor
            Reconstructed embeddings.
        """
        if self.backbone == MLP:
            # an MLP backbone ignores the graph structure
            self.emb = self.encoder(x, None)
            x_ = self.decoder(self.emb, None)
        else:
            self.emb = self.encoder(x, edge_index)
            x_ = self.decoder(self.emb, edge_index)
        return x_
    @staticmethod
    def process_graph(data, recon_s=False):
        """
        Obtain the dense adjacency matrix of the graph.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Input graph.
        recon_s : bool, optional
            Reconstruct the structure instead of the node features.
        """
        if recon_s:
            data.s = to_dense_adj(data.edge_index)[0]
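
A minimal usage sketch, not part of the module: in PyGOD this class is normally driven by a detector wrapper rather than called directly, and the toy tensors below are purely illustrative. The sketch shows both reconstruction modes: attribute reconstruction, where the decoder maps ``hid_dim`` back to ``in_dim``, and structure reconstruction, where ``process_graph`` first attaches the dense adjacency matrix as ``data.s`` for the dot-product decoder to recover.

    import torch
    from torch_geometric.data import Data

    # toy graph: 5 nodes with 16-dimensional features (illustrative only)
    x = torch.randn(5, 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 4]])

    # attribute reconstruction: output has the same shape as the input
    model = GAEBase(in_dim=16, hid_dim=64, num_layers=4)
    x_ = model(x, edge_index)
    loss = model.loss_func(x_, x)  # element-wise MSE on node features

    # structure reconstruction: compare against the dense adjacency matrix
    data = Data(x=x, edge_index=edge_index)
    GAEBase.process_graph(data, recon_s=True)  # attaches data.s (N x N)
    model_s = GAEBase(in_dim=16, recon_s=True, sigmoid_s=True)
    s_ = model_s(data.x, data.edge_index)
    loss_s = model_s.loss_func(s_, data.s)

Note that ``loss_func`` is plain ``F.mse_loss`` in both modes; only the decoder and the reconstruction target change.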