
Commit: update to the latest version
Minqi824 committed Aug 10, 2022
1 parent d3ff989 commit 666c5bb
Showing 4 changed files with 32 additions and 19 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -86,7 +86,7 @@ from data_generator import DataGenerator
from myutils import Utils

# one can use our already included datasets
-data_generator = DataGenerator(dataset='1_abalone.npz')
+data_generator = DataGenerator(dataset='6_cardio.npz')
# specify the ratio of labeled anomalies to generate X and y
# la could be any float number in [0.0, 1.0]
data = data_generator.generator(la=0.1)
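For orientation, a minimal end-to-end sketch of how this generated split is typically consumed downstream. The PyOD `IForest` detector is an illustrative assumption and not part of this commit; `utils.metric` and the `data` dict keys match the README snippets shown in this diff.

```python
from pyod.models.iforest import IForest  # assumed detector, not part of this commit

from data_generator import DataGenerator
from myutils import Utils

utils = Utils()
data_generator = DataGenerator(dataset='6_cardio.npz')
data = data_generator.generator(la=0.1)  # 10% of the anomalies carry labels

model = IForest().fit(data['X_train'])           # unsupervised fit on the train split
score = model.decision_function(data['X_test'])  # anomaly scores for the test split
result = utils.metric(y_true=data['y_test'], y_score=score)
print(result)
```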
@@ -113,7 +113,7 @@ result = utils.metric(y_true=data['y_test'], y_score=score)
**_Angle II: Types of Anomalies_**
```python
# For Angle II, different types of anomalies are generated as follows
-data_generator = DataGenerator(dataset='1_abalone.npz')
+data_generator = DataGenerator(dataset='6_cardio.npz')
# the type of anomalies could be 'local', 'global', 'dependency' or 'cluster'.
data = data_generator.generator(realistic_synthetic_mode='local')
```
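As a quick illustration of sweeping all four modes listed in the comment (a sketch, assuming the imports and `data` dict from Angle I):

```python
# illustrative sweep over the four synthetic anomaly types
for mode in ['local', 'global', 'dependency', 'cluster']:
    data = DataGenerator(dataset='6_cardio.npz').generator(realistic_synthetic_mode=mode)
    print(mode, data['X_train'].shape, data['y_train'].mean())  # anomaly ratio per mode
```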
@@ -122,7 +122,7 @@ data = data_generator.generator(realistic_synthetic_mode='local')
**_Angle III: Model Robustness with Noisy and Corrupted Data_**
```python
# For Angle III, different data noises and corruptions are added as follows
-data_generator = DataGenerator(dataset='1_abalone.npz')
+data_generator = DataGenerator(dataset='6_cardio.npz')
# the noise type can be 'duplicated_anomalies', 'irrelevant_features' or 'label_contamination'.
data = data_generator.generator(noise_type='duplicated_anomalies')
```
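Analogously, a sketch of iterating over the three corruption settings (same assumptions as the Angle II sketch above):

```python
# illustrative sweep over the three noise/corruption settings
for noise in ['duplicated_anomalies', 'irrelevant_features', 'label_contamination']:
    data = DataGenerator(dataset='6_cardio.npz').generator(noise_type=noise)
    print(noise, data['X_train'].shape, data['X_test'].shape)
```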
@@ -146,7 +146,7 @@ Pretrained models are applied to extract data embedding from NLP and CV datasets
Please see the [datasets](datasets) folder and our [paper](https://arxiv.org/abs/2206.09426) for detailed information.

- We organize the above 57 datasets into a user-friendly format. All datasets are named "number_data.npz" in the
-[datasets](datasets) folder. For example, one can evaluate AD algorithms on the abalone dataset with the following code.
+[datasets](datasets) folder. For example, one can evaluate AD algorithms on the cardio dataset with the following code.
For a multi-class dataset like CIFAR10, the class number is additionally encoded in the filename, as "number_data_class.npz".
Please see the folder for more details.

@@ -157,7 +157,7 @@ reproduce our procedures via the free GPUs. We hope this could be helpful for th

```python
import numpy as np
-data = np.load('1_abalone.npz', allow_pickle=True)
+data = np.load('6_cardio.npz', allow_pickle=True)
X, y = data['X'], data['y']
```
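A short sanity check of the loaded arrays can follow directly (plain NumPy; assumes only that `y` is binary 0/1, as in these benchmark datasets):

```python
print(f'X shape: {X.shape}, y shape: {y.shape}')
print(f'anomaly ratio: {y.mean():.4f}')  # y is 0/1, so the mean is the anomaly ratio
```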

28 changes: 18 additions & 10 deletions data_generator.py
@@ -32,6 +32,14 @@ def __init__(self, seed:int=42, dataset:str=None, test_size:float=0.3,
        self.generate_duplicates = generate_duplicates
        self.n_samples_threshold = n_samples_threshold

+        # dataset list
+        self.dataset_list_classical = [os.path.splitext(_)[0] for _ in os.listdir('datasets/Classical')
+                                       if os.path.splitext(_)[1] == '.npz'] # classical AD datasets
+        self.dataset_list_cv = [os.path.splitext(_)[0] for _ in os.listdir('datasets/CV(by ResNet-18)')
+                                if os.path.splitext(_)[1] == '.npz'] # CV datasets
+        self.dataset_list_nlp = [os.path.splitext(_)[0] for _ in os.listdir('datasets/NLP(by BERT)')
+                                 if os.path.splitext(_)[1] == '.npz'] # NLP datasets

        # myutils function
        self.utils = Utils()
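The three added comprehensions share one pattern: scan a subfolder and keep the stems of its `.npz` files. A hypothetical helper factoring that out (`_scan_npz` is not in the commit):

```python
import os

def _scan_npz(folder: str) -> list:
    """Hypothetical helper: stems of all .npz files in `folder`."""
    return [os.path.splitext(f)[0] for f in os.listdir(folder)
            if os.path.splitext(f)[1] == '.npz']

# e.g. self.dataset_list_classical = _scan_npz('datasets/Classical')
```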

@@ -210,17 +218,16 @@ def generator(self, X=None, y=None, minmax=True,
        # load dataset
        if self.dataset is None:
            assert X is not None and y is not None, "For customized dataset, you should provide the X and y!"
-        # datasets from https://github.com/GuansongPang/ADRepository-Anomaly-detection-datasets/tree/main/numerical%20data/DevNet%20datasets
-        elif self.dataset in ['bank-additional-full_normalised', 'celeba_baldvsnonbald_normalised',
-                              'census-income-full-mixed-binarized', 'creditcardfraud_normalised',
-                              'KDD2014_donors_10feat_nomissing_normalised', 'UNSW_NB15_traintest_backdoor']:
-            data = pd.read_csv(os.path.join('datasets', self.dataset+'.csv'))
-            X = data.drop(['class'], axis=1).values
-            y = data['class'].values
-
-            minmax = False
        else:
-            data = np.load(os.path.join('datasets', self.dataset+'.npz'), allow_pickle=True)
+            if self.dataset in self.dataset_list_classical:
+                data = np.load(os.path.join('datasets', 'Classical', self.dataset + '.npz'), allow_pickle=True)
+            elif self.dataset in self.dataset_list_cv:
+                data = np.load(os.path.join('datasets', 'CV(by ResNet-18)', self.dataset + '.npz'), allow_pickle=True)
+            elif self.dataset in self.dataset_list_nlp:
+                data = np.load(os.path.join('datasets', 'NLP(by BERT)', self.dataset + '.npz'), allow_pickle=True)
+            else:
+                raise NotImplementedError

            X = data['X']
            y = data['y']
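A design note on the new branch: it is a name-to-subfolder dispatch, which could equivalently be table-driven. A hypothetical standalone version (`load_adbench_npz` and its signature are assumptions, not the commit's API):

```python
import os
import numpy as np

def load_adbench_npz(dataset: str, dataset_lists: dict):
    """Hypothetical table-driven equivalent of the if/elif chain above.

    `dataset_lists` maps a subfolder to the dataset stems found in it, e.g.
    {'Classical': [...], 'CV(by ResNet-18)': [...], 'NLP(by BERT)': [...]}.
    """
    for subfolder, stems in dataset_lists.items():
        if dataset in stems:
            return np.load(os.path.join('datasets', subfolder, dataset + '.npz'),
                           allow_pickle=True)
    raise NotImplementedError(f'unknown dataset: {dataset}')
```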

@@ -264,6 +271,7 @@ def generator(self, X=None, y=None, minmax=True,
                X = data_dependency['X']; y = data_dependency['y']

            except:
+                # raise NotImplementedError
                print(f'Generating dependency anomalies...')
                X, y = self.generate_realistic_synthetic(X, y,
                                                         realistic_synthetic_mode=realistic_synthetic_mode,
4 changes: 2 additions & 2 deletions other_utils/utils.py
@@ -211,7 +211,7 @@ def filter_names(name):
linewidth=linewidth)
if labels:
# text(textspace + 0.3, chei - 0.075, format(ssums[i], '.4f'), ha="right", va="center", size=10)
-text(textspace + 0.7, chei - 0.075, format(avmetrics[i], '.4f'), ha="right", va="center", size=12) # bug fix: show the average metric
+text(textspace + 0.6, chei - 0.075, format(avmetrics[i], '.2f'), ha="right", va="center", size=12) # bug fix: show the average metric
text(textspace - 0.2, chei, filter_names(nnames[i]), ha="right", va="center", size=16)

for i in range(math.ceil(k / 2), k):
@@ -222,7 +222,7 @@ def filter_names(name):
linewidth=linewidth)
if labels:
# text(textspace + scalewidth - 0.3, chei - 0.075, format(ssums[i], '.4f'), ha="left", va="center", size=10)
-text(textspace + scalewidth - 0.7, chei - 0.075, format(avmetrics[i], '.4f'), ha="left", va="center", size=12) # bug fix: show the average metric
+text(textspace + scalewidth - 0.6, chei - 0.075, format(avmetrics[i], '.2f'), ha="left", va="center", size=12) # bug fix: show the average metric
text(textspace + scalewidth + 0.2, chei, filter_names(nnames[i]),
ha="left", va="center", size=16)

9 changes: 7 additions & 2 deletions run.py
@@ -132,8 +132,13 @@ def __init__(self, suffix:str=None, mode:str='rla', parallel:str=None,
    # dataset filter for deleting datasets that do not satisfy the experimental requirements
    def dataset_filter(self):
        # dataset list in the current folder
-        dataset_list_org = [os.path.splitext(_)[0] for _ in os.listdir('datasets') if os.path.splitext(_)[1] in ['.npz', '.csv']]
-        # dataset_list_org = [_ for _ in dataset_list_org if not _.split('_')[0].isdigit()]
+        dataset_list_org = [os.path.splitext(_)[0] for _ in os.listdir('datasets/Classical')
+                            if os.path.splitext(_)[1] == '.npz'] # classical AD datasets
+        dataset_list_org.extend([os.path.splitext(_)[0] for _ in os.listdir('datasets/CV(by ResNet-18)')
+                                 if os.path.splitext(_)[1] == '.npz']) # CV datasets
+        dataset_list_org.extend([os.path.splitext(_)[0] for _ in os.listdir('datasets/NLP(by BERT)')
+                                 if os.path.splitext(_)[1] == '.npz']) # NLP datasets

        dataset_list = []
        dataset_size = []
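Since `dataset_filter` now repeats the scan from `DataGenerator.__init__`, a hypothetical one-comprehension version of the three scans (a sketch, not part of the commit):

```python
subfolders = ['datasets/Classical', 'datasets/CV(by ResNet-18)', 'datasets/NLP(by BERT)']
dataset_list_org = [os.path.splitext(f)[0]
                    for folder in subfolders
                    for f in os.listdir(folder)
                    if os.path.splitext(f)[1] == '.npz']
```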

