# Source code for pygod.nn.functional

# -*- coding: utf-8 -*-
"""Functional Interface for PyGOD"""
# Author: Kay Liu <zliu234@uic.edu>

import torch
import torch.nn.functional as F

def double_recon_loss(x,
                      x_,
                      s,
                      s_,
                      weight=0.5,
                      pos_weight_a=0.5,
                      pos_weight_s=0.5,
                      bce_s=False):
    r"""
    Double reconstruction loss function for feature and structure.
    The loss function is defined as :math:`\alpha \symbf{E_a} +
    (1-\alpha) \symbf{E_s}`, where :math:`\alpha` is the weight between
    0 and 1 inclusive, and :math:`\symbf{E_a}` and :math:`\symbf{E_s}`
    are the reconstruction errors for feature and structure,
    respectively. Dimension 1 is summed out, so one score is returned
    per row of the inputs.

    For feature reconstruction, the error of row :math:`i` is the square
    root of the weighted sum of squared differences:
    :math:`\symbf{E_a}_i = \sqrt{\sum_j H_{ij}
    (x_{ij} - x'_{ij})^2}`,
    where :math:`H_{ij}=\begin{cases}1 - \eta &
    \text{if }x_{ij}\le 0\\ \eta & \text{if }x_{ij}>0\end{cases}`, and
    :math:`\eta` is the positive weight for feature. When
    :math:`\eta` is exactly 0.5 the weighting step is skipped, i.e.
    :math:`H_{ij}=1`.

    For structure reconstruction, the same form is used by default:
    :math:`\symbf{E_s}_i = \sqrt{\sum_j \Theta_{ij}
    (s_{ij} - s'_{ij})^2}`, where
    :math:`\Theta_{ij}=\begin{cases}1 - \theta &
    \text{if }s_{ij}\le 0\\ \theta & \text{if }s_{ij}>0
    \end{cases}`, and :math:`\theta` is the positive weight for
    structure (again skipped, :math:`\Theta_{ij}=1`, when
    :math:`\theta` is exactly 0.5). Alternatively, the squared
    difference can be replaced by the element-wise binary cross entropy:
    :math:`\symbf{E_s}_i = \sqrt{\sum_j \Theta_{ij}
    \text{BCE}(s'_{ij}, s_{ij})}`.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth node feature
    x_ : torch.Tensor
        Reconstructed node feature
    s : torch.Tensor
        Ground truth node structure
    s_ : torch.Tensor
        Reconstructed node structure. When ``bce_s`` is True it is
        passed as the input of ``F.binary_cross_entropy``, which
        expects values in the range [0, 1].
    weight : float, optional
        Balancing weight :math:`\alpha` between 0 and 1 inclusive
        between node feature and graph structure. Default: ``0.5``.
    pos_weight_a : float, optional
        Positive weight for feature :math:`\eta`. Default: ``0.5``.
    pos_weight_s : float, optional
        Positive weight for structure :math:`\theta`. Default: ``0.5``.
    bce_s : bool, optional
        Use binary cross entropy for structure reconstruction loss.
        Default: ``False``.

    Returns
    -------
    score : torch.Tensor
        Outlier scores, one per row of the inputs (shape :math:`N` for
        inputs with :math:`N` rows). The result is built only from
        differentiable torch operations on the inputs.

    Raises
    ------
    AssertionError
        If ``weight``, ``pos_weight_a`` or ``pos_weight_s`` lies outside
        the closed interval [0, 1].
    """

    assert 0 <= weight <= 1, "weight must be a float between 0 and 1."
    assert 0 <= pos_weight_a <= 1 and 0 <= pos_weight_s <= 1, \
        "positive weight must be a float between 0 and 1."

    # attribute reconstruction loss
    diff_attr = torch.pow(x - x_, 2)

    # 0.5 means "no re-weighting": the element-wise errors are left
    # untouched rather than being uniformly scaled by 0.5.
    if pos_weight_a != 0.5:
        diff_attr = torch.where(x > 0,
                                diff_attr * pos_weight_a,
                                diff_attr * (1 - pos_weight_a))

    # collapse dimension 1 so that each row yields a single error value
    attr_error = torch.sqrt(torch.sum(diff_attr, 1))

    # structure reconstruction loss
    if bce_s:
        # reduction='none' keeps the element-wise losses so they can be
        # re-weighted and summed per row below
        diff_stru = F.binary_cross_entropy(s_, s, reduction='none')
    else:
        diff_stru = torch.pow(s - s_, 2)

    # same convention as for the attributes: 0.5 skips the re-weighting
    if pos_weight_s != 0.5:
        diff_stru = torch.where(s > 0,
                                diff_stru * pos_weight_s,
                                diff_stru * (1 - pos_weight_s))

    stru_error = torch.sqrt(torch.sum(diff_stru, 1))

    # convex combination of the two per-row errors
    score = weight * attr_error + (1 - weight) * stru_error

    return score