# -*- coding: utf-8 -*-
"""Functional Interface for PyGOD"""
# Author: Kay Liu <zliu234@uic.edu>
# License: BSD 2 clause
import torch
import torch.nn.functional as F
def double_recon_loss(x,
                      x_,
                      s,
                      s_,
                      weight=0.5,
                      pos_weight_a=0.5,
                      pos_weight_s=0.5,
                      bce_s=False):
    r"""
    Double reconstruction loss function for feature and structure.

    The loss function is defined as :math:`\alpha \symbf{E_a} +
    (1-\alpha) \symbf{E_s}`, where :math:`\alpha` is the weight between
    0 and 1 inclusive, and :math:`\symbf{E_a}` and :math:`\symbf{E_s}`
    are the reconstruction loss for feature and structure, respectively.
    The errors are reduced over dimension 1 only, so the first dimension
    is kept for outlier scores of each node.

    For feature reconstruction, the error of node :math:`i` is the
    square root of the weighted sum of squared errors:
    :math:`\symbf{E_a}_i = \sqrt{\sum_j H_{ij}
    (x_{ij} - x'_{ij})^2}`,
    where :math:`H_{ij}=\begin{cases}1 - \eta &
    \text{if }x_{ij}\le 0\\ \eta & \text{if }x_{ij}>0\end{cases}`, and
    :math:`\eta` is the positive weight for feature.

    For structure reconstruction, the same form is used by default:
    :math:`\symbf{E_s}_i = \sqrt{\sum_j \Theta_{ij}
    (s_{ij} - s'_{ij})^2}`, where
    :math:`\Theta_{ij}=\begin{cases}1 - \theta &
    \text{if }s_{ij}\le 0\\ \theta & \text{if }s_{ij}>0
    \end{cases}`, and :math:`\theta` is the positive weight for
    structure. Alternatively, the squared error can be replaced by the
    element-wise binary cross entropy:
    :math:`\symbf{E_s}_i = \sqrt{\sum_j \Theta_{ij}
    \text{BCE}(s'_{ij}, s_{ij})}`.

    .. note::
        When a positive weight is exactly ``0.5`` the corresponding
        re-weighting is skipped altogether (every entry keeps weight 1,
        not 0.5), so scores are not continuous in the positive weight
        around ``0.5``.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth node feature
    x_ : torch.Tensor
        Reconstructed node feature
    s : torch.Tensor
        Ground truth node structure
    s_ : torch.Tensor
        Reconstructed node structure. When ``bce_s`` is ``True`` its
        entries must be probabilities in :math:`[0, 1]`, as required by
        ``torch.nn.functional.binary_cross_entropy``.
    weight : float, optional
        Balancing weight :math:`\alpha` between 0 and 1 inclusive
        between node feature and graph structure. Default: ``0.5``.
    pos_weight_a : float, optional
        Positive weight for feature :math:`\eta`. Default: ``0.5``.
    pos_weight_s : float, optional
        Positive weight for structure :math:`\theta`. Default: ``0.5``.
    bce_s : bool, optional
        Use binary cross entropy for structure reconstruction loss.
        Default: ``False``.

    Returns
    -------
    score : torch.Tensor
        Outlier scores of shape :math:`N` with gradients.

    Raises
    ------
    AssertionError
        If ``weight``, ``pos_weight_a`` or ``pos_weight_s`` is outside
        the closed interval :math:`[0, 1]`.
    """

    assert 0 <= weight <= 1, "weight must be a float between 0 and 1."
    assert 0 <= pos_weight_a <= 1 and 0 <= pos_weight_s <= 1, \
        "positive weight must be a float between 0 and 1."

    # attribute reconstruction loss
    diff_attr = torch.pow(x - x_, 2)

    # 0.5 means "treat positive and non-positive entries alike", so the
    # element-wise re-weighting is skipped entirely in that case
    if pos_weight_a != 0.5:
        diff_attr = torch.where(x > 0,
                                diff_attr * pos_weight_a,
                                diff_attr * (1 - pos_weight_a))

    # NOTE: the derivative of sqrt is unbounded at 0, so a row that is
    # reconstructed exactly may produce non-finite gradients
    attr_error = torch.sqrt(torch.sum(diff_attr, 1))

    # structure reconstruction loss
    if bce_s:
        # reduction='none' keeps the element-wise loss so it can be
        # re-weighted and reduced per node below
        diff_stru = F.binary_cross_entropy(s_, s, reduction='none')
    else:
        diff_stru = torch.pow(s - s_, 2)

    # applies to both the squared-error and the BCE variant
    if pos_weight_s != 0.5:
        diff_stru = torch.where(s > 0,
                                diff_stru * pos_weight_s,
                                diff_stru * (1 - pos_weight_s))

    stru_error = torch.sqrt(torch.sum(diff_stru, 1))

    # convex combination of the two per-node errors
    score = weight * attr_error + (1 - weight) * stru_error

    return score