# -*- coding: utf-8 -*-
"""A set of utility functions to support outlier detection.
"""
# Author: Yue Zhao <zhaoy@cmu.edu>
# License: BSD 2 clause
from __future__ import division
from __future__ import print_function
import os
import requests
import warnings
import torch
import numpy as np
import numbers
import shutil
# Largest value representable by a signed 32-bit integer.
MAX_INT = np.iinfo(np.int32).max
# Mirror image of MAX_INT (note: one above the true int32 minimum).
MIN_INT = -MAX_INT
def validate_device(gpu_id):
    """Validate the requested GPU id and return a usable device string.

    Parameters
    ----------
    gpu_id : int or str
        GPU id to be used. ``-1`` or the string ``'cpu'`` selects the CPU.
        Any other value is cast to ``int``; when CUDA is available it must
        lie in ``[0, torch.cuda.device_count())``.

    Returns
    -------
    device_id : str
        Valid device id, e.g., 'cuda:0' or 'cpu'. When CUDA is not
        available, a warning is issued and 'cpu' is returned.

    Raises
    ------
    ValueError
        If ``gpu_id`` is a string that cannot be parsed as an integer, or
        if CUDA is available and the id is outside the valid GPU range.
    TypeError
        If ``gpu_id`` cannot be converted to ``int`` at all.
    """
    # explicit CPU requests need no further validation
    if gpu_id == -1:
        return 'cpu'
    # accept the device string itself; previously this crashed in int()
    if isinstance(gpu_id, str) and gpu_id == 'cpu':
        return 'cpu'

    # cast to int for checking
    gpu_id = int(gpu_id)

    if torch.cuda.is_available():
        # the id must be between 0 (inclusive) and the number of GPUs
        # (exclusive)
        check_parameter(gpu_id, 0, torch.cuda.device_count(),
                        param_name='gpu id', include_left=True,
                        include_right=False)
        return 'cuda:{}'.format(gpu_id)

    # a GPU was requested but none is usable: fall back to the CPU
    warnings.warn('The cuda is not available. Set to cpu.')
    return 'cpu'
def check_parameter(param, low=MIN_INT, high=MAX_INT, param_name='',
                    include_left=False, include_right=False):
    """Check if an input is within the defined range.

    Parameters
    ----------
    param : int, float
        The input parameter to check.

    low : int, float
        The lower bound of the range.

    high : int, float
        The higher bound of the range.

    param_name : str, optional (default='')
        The name of the parameter.

    include_left : bool, optional (default=False)
        Whether includes the lower bound (lower bound <=).

    include_right : bool, optional (default=False)
        Whether includes the higher bound (<= higher bound).

    Returns
    -------
    within_range : bool
        Always ``True``; a parameter outside the range raises instead.

    Raises
    ------
    TypeError
        If ``param``, ``low`` or ``high`` is not a real number.
    ValueError
        If both bounds are left at their defaults, if ``low > high``, or
        if ``param`` is not inside the requested interval (this includes
        NaN, which is never inside any interval).
    """
    # Python and NumPy real scalars (np.float32 is not a Python float, so
    # the NumPy abstract types are listed explicitly).
    numeric_types = (numbers.Real, np.integer, np.floating)

    # param, low and high should all be numerical
    if not isinstance(param, numeric_types):
        raise TypeError('{param_name} is set to {param} Not numerical'.format(
            param=param, param_name=param_name))

    if not isinstance(low, numeric_types):
        raise TypeError('low is set to {low}. Not numerical'.format(low=low))

    if not isinstance(high, numeric_types):
        raise TypeError('high is set to {high}. Not numerical'.format(
            high=high))

    # at least one of the bounds should be specified; compare by value,
    # since identity of equal integers is an implementation detail
    if low == MIN_INT and high == MAX_INT:
        raise ValueError('Neither low nor high bounds is defined')

    # if wrong bound values are used
    if low > high:
        raise ValueError(
            'Lower bound > Higher bound')

    # Express the test positively so that NaN (for which every comparison
    # is False) is treated as out of range rather than silently accepted.
    above_low = (param >= low) if include_left else (param > low)
    below_high = (param <= high) if include_right else (param < high)

    if not (above_low and below_high):
        left_bracket = '[' if include_left else '('
        right_bracket = ']' if include_right else ')'
        raise ValueError(
            '{param_name} is set to {param}. '
            'Not in the range of {lb}{low}, {high}{rb}.'.format(
                param=param, low=low, high=high, param_name=param_name,
                lb=left_bracket, rb=right_bracket))

    return True
def load_data(name, cache_dir=None):
    """
    Data loading function. See `data repository
    <https://github.com/pygod-team/data>`_ for supported datasets.
    For injected/generated datasets, the labels meanings are as follows.

    - 0: inlier
    - 1: contextual outlier only
    - 2: structural outlier only
    - 3: both contextual outlier and structural outlier

    The dataset is read from ``<cache_dir>/<name>.pt``. If that file does
    not exist, ``<name>.pt.zip`` is downloaded from the data repository
    into ``cache_dir`` and unpacked there first.

    Parameters
    ----------
    name : str
        The name of the dataset.
    cache_dir : str, optional
        The directory for dataset caching. When ``None``, the directory
        ``~/.pygod/data`` is used.
        Default: ``None``.

    Returns
    -------
    data : torch_geometric.data.Data
        The outlier dataset.

    Raises
    ------
    RuntimeError
        If the download request does not return HTTP status 200.

    Examples
    --------
    >>> from pygod.utils import load_data
    >>> data = load_data(name='weibo') # in PyG format
    >>> y = data.y.bool()    # binary labels (inlier/outlier)
    >>> yc = data.y >> 0 & 1 # contextual outliers
    >>> ys = data.y >> 1 & 1 # structural outliers
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.pygod/data')
    file_path = os.path.join(cache_dir, name + '.pt')
    zip_path = os.path.join(cache_dir, name + '.pt.zip')

    if not os.path.exists(file_path):
        url = "https://github.com/pygod-team/data/raw/main/" + name + ".pt.zip"
        # exist_ok avoids a race between the existence check and creation
        os.makedirs(cache_dir, exist_ok=True)
        # The context manager releases the connection even on failure; the
        # timeout keeps a stalled server from blocking the caller forever.
        with requests.get(url, stream=True, timeout=30) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            with open(zip_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        shutil.unpack_archive(zip_path, cache_dir)

    # NOTE(review): torch.load unpickles the file, and the file may have
    # just been downloaded; only use dataset names/sources you trust.
    return torch.load(file_path)