Source code for pygod.nn.card

import torch
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits, normalize

from torch_geometric.nn import GCN
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_dense_adj, to_undirected
from torch_geometric.transforms import GDC
from .functional import double_recon_loss


[docs] class CARDBase(nn.Module): """ Community-Guided Contrastive Learning with Anomaly-Aware Reconstruction for Anomaly Detection on Attributed Networks. CARD is a contrastive learning based method and utilizes mask reconstruction and community information to make anomalies more distinct. This model is train with contrastive loss and local and global attribute reconstruction loss. Random neighbor sampling instead of random walk sampling is used to sample the subgraph corresponding to each node. Since random neighbor sampling cannot accurately control the number of neighbors for each sampling, it may run slower compared to the method implementation in the original paper. See:cite:`Wang2024Card` for details. Parameters ---------- in_dim : int Input dimension of model. subgraph_num_neigh: int, optional Number of neighbors in subgraph sampling for each node, Values not exceeding 4 are recommended for efficiency. Default: ``4``. fp: float, optional The balance parameter between the mask autoencoder module and contrastive learning. Default: ``0.6`` gama: float, optional The proportion of the local reconstruction in contrastive learning module. Default: ``0.5`` alpha: float, optional The proprotion of the community embedding in the conbine_encoder. Default: ``0.1`` hid_dim : int, optional Hidden dimension of model. Default: ``64``. num_layers : int, optional Total number of layers in model. Default: ``4``. dropout : float, optional Dropout rate. Default: ``0.``. act : callable activation function or None, optional Activation function if not None. Default: ``torch.nn.functional.relu``. backbone : torch.nn.Module The backbone of the deep detector implemented in PyG. Default: ``torch_geometric.nn.GCN``. **kwargs Other parameters for the backbone. """ def __init__(self, in_dim, subgraph_num_neigh=4, fp=0.6, gama=0.4, alpha=0.1, hid_dim=64, num_layers=4, dropout=0., act=torch.nn.functional.relu, backbone=GCN, **kwargs): super(CARDBase, self).__init__() self.alpha = alpha self.hid_dim = hid_dim self.num_layers = num_layers self.subgraph_num_neigh = subgraph_num_neigh self.fp = fp self.gama = gama # subgraph encoder self.encoder = backbone(in_channels=in_dim, hidden_channels=hid_dim, num_layers=num_layers, out_channels=hid_dim, dropout=dropout, act=act, **kwargs) self.global_feat_encoder = backbone(in_channels=in_dim, hidden_channels=hid_dim, num_layers=num_layers, out_channels=hid_dim, dropout=dropout, act=act, **kwargs) self.global_feat_decoder = backbone(in_channels=hid_dim, hidden_channels=hid_dim, num_layers=num_layers, out_channels=in_dim, dropout=dropout, act=act, **kwargs) self.local_feat_decoder = nn.Sequential( nn.Linear(hid_dim, hid_dim), nn.PReLU(), nn.Linear(hid_dim, hid_dim), nn.PReLU(), nn.Linear(hid_dim, in_dim), nn.PReLU() ) self.community_encoder = nn.Sequential( nn.Linear(self.subgraph_num_neigh, int(hid_dim / 2)), nn.PReLU(), nn.Linear(int(hid_dim / 2), hid_dim), nn.PReLU(), ) self.combine_encoder = backbone(in_channels=hid_dim, hidden_channels=hid_dim, num_layers=num_layers, out_channels=hid_dim, dropout=dropout, act=act, **kwargs) self.discriminator = nn.Bilinear(hid_dim, hid_dim, 1) self.disc_loss_func = binary_cross_entropy_with_logits self.emb = None self.disc_emb = None
[docs] def forward(self, data): """ Forward computation. Parameters ---------- Data: torch_geometric.data.Data Input graph. Returns ------- logits : torch.Tensor Discriminator logits of positive examples. neg_logits : torch.Tensor Discriminator logits of negative examples. x_: torch.Tensor feature reconstract matrix local_x_: torch.Tensor subgraph feature reconstract matrix """ self.emb, self.disc_emb = self._train_subgraph_network(data) x = data.x logits = self.discriminator(self.disc_emb, self.emb) perm_idx = torch.randperm( self.disc_emb.shape[0]).to(self.disc_emb.device) neg_logits = self.discriminator( self.disc_emb[perm_idx], self.emb) local_x_ = self.local_feat_decoder(self.emb) attr_emb = self.global_feat_encoder(x, data.edge_index) x_ = self.global_feat_decoder(attr_emb, data.edge_index) return logits.squeeze(), neg_logits.squeeze(), x_, local_x_
[docs] def loss_func(self, logits, diff_logits, x_, local_x_, x, con_label): """ The loss function proposed in the CARD paper. This implementation ignores the KL-loss as it contributes little to the accuracy. Parameters ---------- logits : torch.Tensor Discriminator logits of positive subgraphs batch. diff_logits : torch.Tensor Discriminator logits of negative subgraphs batch. x_ : torch.Tensor Global reconstructed attribute embeddings. local_x_ : torch.Tensor Local reconstructed attribute embeddings. x : torch.Tensor Input attribute embeddings. con_label : torch.Tensor Contrastive learning pseudo label Returns ------- final_loss: torch.Tensor The total loss value used to backpropagate and update the model parameters. score: torch.Tensor The anomaly score for each node. """ ori_loss = self.disc_loss_func(logits, con_label) diff_loss = self.disc_loss_func(diff_logits, con_label) logit_loss = (ori_loss + diff_loss) / 2 batch_size = int(logits.shape[0] / 2) h_1 = normalize(logits[:batch_size], dim=0, p=2) h_2 = normalize(diff_logits[:batch_size], dim=0, p=2) inter_logit_loss = 2 - 2 * (h_1 * h_2).sum(dim=-1).mean() rec_loss = double_recon_loss(x, x_, x, x, 1) local_rec_loss = double_recon_loss(x, local_x_, x, x, 1) constra_loss = torch.mean(logit_loss) + \ inter_logit_loss + self.gama * torch.mean(local_rec_loss) final_loss = (1 - self.fp) * constra_loss + \ self.fp * torch.mean(rec_loss) # + 0.5 * kl constra_score = ((logits[batch_size:] - logits[:batch_size]) + (diff_logits[batch_size:] - diff_logits[:batch_size])) / 2 score = (1 - self.fp) * (constra_score + self.gama * local_rec_loss[:batch_size]) + self.fp * rec_loss[:batch_size] return final_loss, score
def _train_subgraph_network(self, data): """ Train the model subgraph encoder and community-guided module with each node and its corresponding subgraph as input. Parameters ---------- data : torch_geometric.data.Data Input graph. Returns ------- res_emb: torch.Tensor Subgraph embedding readout. disc_emb: torch.Tensor Target node embedding. """ res_emb = [] disc_emb = [] for index in range(data.num_nodes): subgraphs = NeighborLoader( data, num_neighbors=[self.subgraph_num_neigh] * self.num_layers) subgraph = subgraphs([index]) community_idx = [] i = 0 while len(community_idx) < self.subgraph_num_neigh: community_idx.append(subgraph.n_id[i]) i = (i + 1) % len(subgraph.n_id) community_adj = subgraph.community_adj[:, community_idx] subgraph.x[0, :] = 0 x = subgraph.x edge_index = subgraph.edge_index ori_emb = self.encoder(x, edge_index) community_emb = self.community_encoder(community_adj) combine_emb = self.combine_encoder( ori_emb + self.alpha * community_emb, edge_index) # avoid nan problem if combine_emb.shape[0] > 1: res_emb.append(torch.mean(combine_emb[1:, :], 0)) else: res_emb.append(combine_emb[0, :]) disc_emb.append(combine_emb[0, :]) return torch.stack(res_emb), torch.stack(disc_emb)
[docs] @staticmethod def process_graph(data): """ Obtain the community structure matrix and the diffusion graph data. Parameters ---------- data: torch_geometric.data.Data Input graph. Returns ------- community_adj: torch.Tensor Community structure matrix, corresponding to the B matrix in the paper. diff_data: torch_geometric.data.Data Diffusion graph Data """ # only support undirected graph if not data.is_undirected(): data.edge_index = to_undirected(data.edge_index) data.s = to_dense_adj(data.edge_index)[0] k1 = torch.sum(data.s, axis=1) k2 = k1.reshape(data.num_nodes, 1) e = k1 * k2 / (2 * data.num_edges) community_adj = (data.s - e).clone().detach() transform = GDC( self_loop_weight=1, normalization_in='sym', normalization_out='col', diffusion_kwargs=dict(method='ppr', alpha=0.01, eps=0.0001), sparsification_kwargs=dict(method='topk', k=128, dim=0), exact=True) diff_data = transform(data) return community_adj, diff_data