diff --git a/scripts/gen_train_eval_nopretrained.ps1 b/scripts/gen_train_eval_nopretrained.ps1
new file mode 100644
index 0000000..33f2c6c
--- /dev/null
+++ b/scripts/gen_train_eval_nopretrained.ps1
@@ -0,0 +1,23 @@
+# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.
+# This script is for networks that do NOT come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself.
+
+$model="ResNet18"
+$dataset="CIFAR10"
+$weight=1
+
+# 0. train the baseline neural network
+python main.py --dataset=$dataset --arch=$model
+
+# 1. generate hierarchy -- for models without a pretrained checkpoint, use `checkpoint` (single interpolated path: a bare `+` in PowerShell argument mode is passed as a literal argument, not string concatenation)
+nbdt-hierarchy --dataset=$dataset --checkpoint="./checkpoint/ckpt-$dataset-$model.pth"
+
+# 2. train with soft tree supervision loss -- for models without a pretrained checkpoint, use `path-resume` OR just train from scratch, without `path-resume`
+# python main.py --lr=0.01 --dataset=${dataset} --model=${model} --hierarchy=induced-${model} --path-resume=./checkpoint/ckpt-${dataset}-${model}.pth --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} # fine-tuning
+python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight # training from scratch
+
+# 3. evaluate with soft then hard inference (--arch matches the training runs above)
+$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")
+
+foreach ($analysis in $analysisRules) {
+    python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight
+}
diff --git a/scripts/gen_train_eval_pretrained.ps1 b/scripts/gen_train_eval_pretrained.ps1
new file mode 100644
index 0000000..b242504
--- /dev/null
+++ b/scripts/gen_train_eval_pretrained.ps1
@@ -0,0 +1,19 @@
+# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.
+# This script is for networks that DO come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself.
+
+$model="wrn28_10_cifar10"
+$dataset="CIFAR10"
+$weight=1
+
+# 1. generate hierarchy
+nbdt-hierarchy --dataset=$dataset --arch=$model
+
+# 2. train with soft tree supervision loss
+python main.py --lr=0.01 --dataset=$dataset --model=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight
+
+# 3. evaluate with soft then hard inference
+$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")
+
+foreach ($analysis in $analysisRules) {
+    python main.py --dataset=$dataset --model=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight
+}
diff --git a/scripts/gen_train_eval_resnet.ps1 b/scripts/gen_train_eval_resnet.ps1
new file mode 100644
index 0000000..e8c5517
--- /dev/null
+++ b/scripts/gen_train_eval_resnet.ps1
@@ -0,0 +1,27 @@
+# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.
+
+$MODELS = @("CIFAR10 1", "CIFAR100 1", "TinyImagenet200 10")
+
+foreach ($config in $MODELS) {
+
+    $params = $config.split(" ")
+
+    $dataset=$params[0]
+    $weight=$params[1]
+
+    # Every entry in $MODELS trains ResNet18 -- the loop variable only carries
+    # "DATASET WEIGHT" -- so the induced hierarchy is always induced-ResNet18.
+    # 1. generate hierarchy
+    nbdt-hierarchy --dataset=$dataset --arch=ResNet18
+
+    # 2. train with soft tree supervision loss
+    python main.py --dataset=$dataset --arch=ResNet18 --hierarchy=induced-ResNet18 --loss=SoftTreeSupLoss --tree-supervision-weight=$weight
+
+    # 3. evaluate with soft then hard inference
+
+    $analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")
+
+    foreach ($analysis in $analysisRules) {
+        python main.py --dataset=$dataset --arch=ResNet18 --hierarchy=induced-ResNet18 --loss=SoftTreeSupLoss --tree-supervision-weight=$weight --eval --resume --analysis=$analysis
+    }
+}
diff --git a/scripts/gen_train_eval_wideresnet.ps1 b/scripts/gen_train_eval_wideresnet.ps1
new file mode 100644
index 0000000..4fa7221
--- /dev/null
+++ b/scripts/gen_train_eval_wideresnet.ps1
@@ -0,0 +1,28 @@
+# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.
+
+$MODEL_NAME="wrn28_10"
+$CIFAR100="CIFAR100" + " " + $MODEL_NAME + "_cifar100 1"
+$CIFAR10="CIFAR10 $MODEL_NAME" + "_cifar10 1"
+$MODELS=@($CIFAR10, $CIFAR100, "TinyImagenet200 $MODEL_NAME 10")
+
+foreach ($config in $MODELS) {
+
+    # Each entry is "DATASET MODEL WEIGHT"; split into its three fields.
+    $params = $config.split(" ")
+
+    $dataset=$params[0]
+    $model=$params[1]
+    $weight=$params[2]
+
+    # 1. generate hierarchy
+    nbdt-hierarchy --dataset=$dataset --arch=$model
+
+    # 2. train with soft tree supervision loss
+    python main.py --lr=0.01 --dataset=$dataset --arch=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight
+
+    # 3. evaluate with soft then hard inference
+    $analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")
+
+    foreach ($analysis in $analysisRules) {
+        python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight}
+    }
+}
\ No newline at end of file
diff --git a/scripts/generate_hierarchies_wordnet.ps1 b/scripts/generate_hierarchies_wordnet.ps1
new file mode 100644
index 0000000..492b9f8
--- /dev/null
+++ b/scripts/generate_hierarchies_wordnet.ps1
@@ -0,0 +1,13 @@
+python -c "import nltk;nltk.download('wordnet')"
+
+# Generate WNIDs
+$DATASETS = @("CIFAR10", "CIFAR100")
+foreach ($dataset in $DATASETS) {
+    nbdt-wnids --dataset=$dataset
+}
+
+# Generate and test hierarchies
+$MORE_DATASETS = @("CIFAR10", "CIFAR100", "TinyImagenet200")
+foreach ($dataset in $MORE_DATASETS) {
+    nbdt-hierarchy --dataset=$dataset --method=wordnet
+}