# -*- coding: utf-8 -*-
"""A set of utility functions to support outlier detection.
"""
# Author: Yue Zhao <zhaoy@cmu.edu>
# License: BSD 2 clause
import os
import torch
import shutil
import numbers
import requests
import warnings
import numpy as np
from importlib import import_module
from ..metric import *
# Sentinel bounds taken from the representable range of a signed 32-bit
# integer; they serve as the "no bound given" defaults of check_parameter.
MIN_INT, MAX_INT = np.iinfo(np.int32).min, np.iinfo(np.int32).max
def validate_device(gpu_id):
    """Translate a GPU ID into a device string usable on this machine.

    Parameters
    ----------
    gpu_id : int
        GPU ID to check. It is cast with ``int()`` first, so anything
        ``int()`` accepts works. ``-1`` selects the CPU.

    Returns
    -------
    device : str
        Valid device, e.g., 'cuda:0' or 'cpu'.

    Raises
    ------
    ValueError
        If CUDA is available and ``gpu_id`` is not in
        ``[0, torch.cuda.device_count())`` (raised by ``check_parameter``).

    Warns
    -----
    UserWarning
        If a GPU is requested but CUDA is not available; 'cpu' is
        returned in that case.
    """
    # cast to int for checking
    gpu_id = int(gpu_id)

    # -1 is the explicit request for the CPU
    if gpu_id == -1:
        return 'cpu'

    if not torch.cuda.is_available():
        # gpu_id is an int other than -1 here, i.e. a GPU was requested.
        # (The former ``gpu_id != 'cpu'`` test compared an int with a str
        # and was therefore always true, so the warning is unconditional.)
        warnings.warn('The cuda is not available. Set to cpu.')
        return 'cpu'

    # the GPU ID must lie in [0, number of visible GPUs)
    check_parameter(gpu_id, 0, torch.cuda.device_count(),
                    param_name='gpu id', include_left=True,
                    include_right=False)
    return 'cuda:{}'.format(gpu_id)
def check_parameter(param, low=MIN_INT, high=MAX_INT, param_name='',
                    include_left=False, include_right=False):
    """Check if an input is within the defined range.

    Parameters
    ----------
    param : int, float
        The input parameter to check.
    low : int, float
        The lower bound of the range.
    high : int, float
        The higher bound of the range.
    param_name : str, optional (default='')
        The name of the parameter, used only in error messages.
    include_left : bool, optional (default=False)
        Whether includes the lower bound (lower bound <=).
    include_right : bool, optional (default=False)
        Whether includes the higher bound (<= higher bound).

    Returns
    -------
    within_range : bool
        ``True`` when the parameter lies within the requested range.
        Out-of-range values never return; they raise ``ValueError``.

    Raises
    ------
    TypeError
        If ``param``, ``low`` or ``high`` is not an integral or float value.
    ValueError
        If neither bound is specified, if ``low > high``, or if ``param``
        falls outside the range.
    """
    # param, low and high should all be numerical
    # (int is already covered by numbers.Integral)
    numeric_types = (numbers.Integral, float)
    if not isinstance(param, numeric_types):
        raise TypeError('{param_name} is set to {param}. Not numerical'.format(
            param=param, param_name=param_name))
    if not isinstance(low, numeric_types):
        raise TypeError('low is set to {low}. Not numerical'.format(low=low))
    if not isinstance(high, numeric_types):
        raise TypeError('high is set to {high}. Not numerical'.format(
            high=high))

    # at least one of the bounds should be specified; compare by value,
    # since identity (`is`) on integers only matches the default objects
    if low == MIN_INT and high == MAX_INT:
        raise ValueError('Neither low nor high bounds is defined')

    # if wrong bound values are used
    if low > high:
        raise ValueError('Lower bound > Higher bound')

    # A closed end rejects values strictly beyond the bound, an open end
    # also rejects the bound itself. The violations are tested directly
    # (rather than negating "inside") to keep the comparison semantics of
    # the original four-way branch.
    too_low = param < low if include_left else param <= low
    too_high = param > high if include_right else param >= high
    if too_low or too_high:
        raise ValueError(
            '{param_name} is set to {param}. '
            'Not in the range of {left}{low}, {high}{right}.'.format(
                param=param, low=low, high=high, param_name=param_name,
                left='[' if include_left else '(',
                right=']' if include_right else ')'))
    return True
# [docs]  (stray link marker, presumably left over from the rendered documentation page)
def load_data(name, cache_dir=None):
    """
    Data loading function. See `data repository
    <https://github.com/pygod-team/data>`_ for supported datasets.

    The dataset is read from ``<cache_dir>/<name>.pt`` when that file
    exists. Otherwise ``<name>.pt.zip`` is downloaded from the data
    repository into ``cache_dir``, unpacked there, and then loaded.

    For injected/generated datasets, the labels meanings are as follows.

    - 0: inlier
    - 1: contextual outlier only
    - 2: structural outlier only
    - 3: both contextual outlier and structural outlier

    Parameters
    ----------
    name : str
        The name of the dataset.
    cache_dir : str, optional
        The directory for dataset caching. When ``None``, the
        ``.pygod/data`` folder under the user's home directory is used.
        Default: ``None``.

    Returns
    -------
    data : torch_geometric.data.Data
        The outlier dataset.

    Raises
    ------
    RuntimeError
        If the download request does not return HTTP status 200.

    Examples
    --------
    >>> from pygod.utils import load_data
    >>> data = load_data(name='weibo') # in PyG format
    >>> y = data.y.bool()    # binary labels (inlier/outlier)
    >>> yc = data.y >> 0 & 1 # contextual outliers
    >>> ys = data.y >> 1 & 1 # structural outliers
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.pygod/data')
    file_path = os.path.join(cache_dir, name + '.pt')

    if not os.path.exists(file_path):
        zip_path = os.path.join(cache_dir, name + '.pt.zip')
        url = "https://github.com/pygod-team/data/raw/main/" + name + ".pt.zip"
        # exist_ok avoids the check-then-create race between processes
        os.makedirs(cache_dir, exist_ok=True)
        # the context manager releases the streamed connection even when
        # the status check or the file write fails
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            with open(zip_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        shutil.unpack_archive(zip_path, cache_dir)

    # NOTE(review): torch.load unpickles the file, so the cached/downloaded
    # data must come from a trusted source.
    return torch.load(file_path)
def logger(epoch=0,
           loss=0,
           score=None,
           target=None,
           time=None,
           verbose=0,
           train=True,
           deep=True):
    """
    Print one line of progress information for a detector.

    Nothing is printed unless ``verbose > 0``. Each higher verbosity
    level appends more fields to the same output line.

    Parameters
    ----------
    epoch : int, optional
        The current epoch.
    loss : float or tuple, optional
        The current epoch loss value. A tuple is printed as two
        separate losses ("Loss I" and "Loss O").
    score : torch.Tensor, optional
        The current outlier scores.
    target : torch.Tensor, optional
        The ground truth labels. Metrics are only computed when given.
    time : float, optional
        The current epoch time; printed only when ``verbose > 1``.
    verbose : int, optional
        Verbosity mode. Range in [0, 3]. Larger value for printing out
        more log information. Default: ``0``.
    train : bool, optional
        Whether the logger is used for training.
    deep : bool, optional
        Whether the logger is used for deep detector.
    """
    # silent mode: emit nothing at all, not even the trailing newline
    if not verbose > 0:
        return

    if deep:
        # line header: epoch counter while training, fixed tag otherwise
        if train:
            header = "Epoch {:04d}: ".format(epoch)
        else:
            header = "Test: "
        print(header, end='')

        if isinstance(loss, tuple):
            loss_text = "Loss I {:.4f} | Loss O {:.4f} | " \
                .format(loss[0], loss[1])
        else:
            loss_text = "Loss {:.4f} | ".format(loss)
        print(loss_text, end='')

    if verbose > 1:
        labelled = target is not None

        if labelled:
            print("AUC {:.4f}".format(eval_roc_auc(target, score)), end='')

        if verbose > 2 and labelled:
            # k for the @k metrics is the number of nonzero labels
            k = target.nonzero().size(0)
            recall = eval_recall_at_k(target, score, k)
            precision = eval_precision_at_k(target, score, k)
            avg_precision = eval_average_precision(target, score)

            # binarize the scores at the percentile given by the
            # label ratio, then score the hard predictions with F1
            outlier_ratio = sum(target) / len(target)
            cutoff = np.percentile(score, 100 * (1 - outlier_ratio))
            hard_pred = (score > cutoff).long()
            f1_value = eval_f1(target, hard_pred)

            print(" | Recall {:.4f} | Precision {:.4f} "
                  "| AP {:.4f} | F1 {:.4f}"
                  .format(recall, precision, avg_precision, f1_value),
                  end='')

        if time is not None:
            print(" | Time {:.2f}".format(time), end='')

    # terminate the log line
    print()
def init_detector(name, **kwargs):
    """
    Detector initialization function.

    Looks up ``name`` among the public names (``__all__``) of
    ``pygod.detector`` and calls it with ``kwargs``.

    Parameters
    ----------
    name : str
        The name of the detector to instantiate.
    **kwargs
        Keyword arguments forwarded unchanged to the detector.

    Returns
    -------
    detector : object
        The object returned by calling ``pygod.detector.<name>``.

    Raises
    ------
    AssertionError
        If ``name`` is not listed in ``pygod.detector.__all__``.
    """
    module = import_module('pygod.detector')
    # Raise explicitly instead of using an `assert` statement so the check
    # still runs under `python -O`; the exception type is unchanged.
    if name not in module.__all__:
        raise AssertionError("Detector {} not found".format(name))
    return getattr(module, name)(**kwargs)
def init_nn(name, **kwargs):
    """
    Neural network initialization function.

    Looks up ``name`` among the public names (``__all__``) of
    ``pygod.nn`` and calls it with ``kwargs``.

    Parameters
    ----------
    name : str
        The name of the neural network to instantiate.
    **kwargs
        Keyword arguments forwarded unchanged to the neural network.

    Returns
    -------
    nn : object
        The object returned by calling ``pygod.nn.<name>``.

    Raises
    ------
    AssertionError
        If ``name`` is not listed in ``pygod.nn.__all__``.
    """
    module = import_module('pygod.nn')
    # Raise explicitly instead of using an `assert` statement so the check
    # still runs under `python -O`; the exception type is unchanged.
    if name not in module.__all__:
        raise AssertionError("Neural network {} not found".format(name))
    return getattr(module, name)(**kwargs)
def pprint(params, offset=0, printer=repr):
    """Render the dictionary 'params' as a wrapped ``key=value`` listing.

    Entries are emitted in sorted key order and separated by ``', '``.
    A new line is started when the running line length would reach 75
    characters or when an entry itself contains a newline.

    Parameters
    ----------
    params : dict
        The dictionary to pretty print
    offset : int, optional
        The offset at the beginning of each line.
    printer : callable, optional
        The function to convert entries to strings, typically
        the builtin str or repr.

    Returns
    -------
    lines : str
        The formatted text, with trailing spaces stripped from each line.
    """
    wrap_sep = ',\n' + ' ' * (offset + 1)
    pieces = []
    width = offset

    for position, (key, value) in enumerate(sorted(params.items())):
        # Exact floats go through str for a representation that is
        # consistent across architectures and versions; everything
        # else is handed to the caller-supplied printer.
        render = str if type(value) is float else printer
        entry = '%s=%s' % (key, render(value))

        # abbreviate very long entries, keeping head and tail
        if len(entry) > 500:
            entry = entry[:300] + '...' + entry[-100:]

        # every entry but the first is preceded by a separator
        if position > 0:
            must_wrap = width + len(entry) >= 75 or '\n' in entry
            if must_wrap:
                pieces.append(wrap_sep)
                width = len(wrap_sep)
            else:
                pieces.append(', ')
                width += 2

        pieces.append(entry)
        width += len(entry)

    # Strip trailing space to avoid nightmare in doctests
    joined = ''.join(pieces)
    return '\n'.join(row.rstrip(' ') for row in joined.split('\n'))
def is_fitted(detector, attributes=None):
    """
    Check if the detector is fitted.

    A detector counts as fitted when every listed attribute exists on it
    and is not ``None``.

    Parameters
    ----------
    detector : pygod.detector.Detector
        The detector to check.
    attributes : list, optional
        The attribute names to check. When ``None``, ``['model']`` is used.
        Default: ``None``.

    Returns
    -------
    None
        The function returns nothing when the detector is fitted.

    Raises
    ------
    AssertionError
        If any listed attribute is missing or ``None``.
    """
    if attributes is None:
        attributes = ['model']

    for attr in attributes:
        # getattr with a default covers both "missing" and "set to None"
        # in one lookup and avoids evaluating a string built from `attr`.
        if getattr(detector, attr, None) is None:
            # Raised explicitly (instead of via an `assert` statement) so
            # the check still runs under `python -O`.
            raise AssertionError("The detector is not fitted yet")