Skip to content

Commit

Permalink
Add Powershell scripts for Windows users
Browse files Browse the repository at this point in the history
  • Loading branch information
MDutro committed Mar 19, 2021
1 parent 21dd005 commit 23b51d1
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 0 deletions.
23 changes: 23 additions & 0 deletions scripts/gen_train_eval_nopretrained.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.
# This script is for networks that do NOT come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself.

$model="ResNet18"
$dataset="CIFAR10"
$weight=1

# 0. train the baseline neural network
python main.py --dataset=$dataset --arch=$model

# 1. generate hieararchy -- for models without a pretrained checkpoint, use `checkpoint`
nbdt-hierarchy --dataset=$dataset --checkpoint="./checkpoint/ckpt-$dataset" + "-$model.pth"

# 2. train with soft tree supervision loss -- for models without a pretrained checkpoint, use `path-resume` OR just train from scratch, without `path-resume`
# python main.py --lr=0.01 --dataset=${dataset} --model=${model} --hierarchy=induced-${model} --path-resume=./checkpoint/ckpt-${dataset}-${model}.pth --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} # fine-tuning
python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight # training from scratch

# 3. evaluate with soft then hard inference
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")

foreach ($analysis in $analysisRules) {
python main.py --dataset=$dataset --model=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight
}
19 changes: 19 additions & 0 deletions scripts/gen_train_eval_pretrained.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.
# This script is for networks that DO come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself.

$model="wrn28_10_cifar10"
$dataset="CIFAR10"
$weight=1

# 1. generate hieararchy
nbdt-hierarchy --dataset=$dataset --arch=$model

# 2. train with soft tree supervision loss
python main.py --lr=0.01 --dataset=$dataset --model=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight

# 3. evaluate with soft then hard inference
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")

foreach ($analysis in $analysisRules) {
python main.py --dataset=$dataset --model=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight
}
27 changes: 27 additions & 0 deletions scripts/gen_train_eval_resnet.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.

$MODELS = @("CIFAR10 1", "CIFAR100 1", "TinyImagenet200 10")

foreach ($model in $MODELS) {

$params = $model.split(" ")

$dataset=$params[0]
$weight=$params[1]



# 1. generate hieararchy
nbdt-hierarchy --dataset=$dataset --arch=ResNet18

# 2. train with soft tree supervision loss
python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight

# 3. evaluate with soft then hard inference

$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")

foreach ($analysis in $analysisRules) {
python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight --eval --resume --analysis=$analysis
}
}
28 changes: 28 additions & 0 deletions scripts/gen_train_eval_wideresnet.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below.

$MODEL_NAME="wrn28_10"
$CIFAR100="CIFAR100" + " " + $MODEL_NAME + "_cifar100 1"
$CIFAR10="CIFAR10 $MODEL_NAME" + "_cifar10 1"
$MODELS=@($CIFAR10, $CIFAR100, "TinyImagenet200 $MODEL_NAME 10")

foreach ($model in $MODELS) {

$params = $model.split(" ")

$dataset=$params[0]
$model=$params[1]
$weight=$params[2]

# 1. generate hieararchy
nbdt-hierarchy --dataset=$dataset --arch=$model

# 2. train with soft tree supervision loss
python main.py --lr=0.01 --dataset=$dataset --arch=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight

# 3. evaluate with soft then hard inference
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules")

foreach ($analysis in $analysisRules) {
python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight}
}
}
13 changes: 13 additions & 0 deletions scripts/generate_hierarchies_wordnet.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
python -c "import nltk;nltk.download('wordnet')"

# Generate WNIDs
$DATASETS = @("CIFAR10", "CIFAR100")
foreach ($dataset in $DATASETS) {
nbdt-wnids --dataset=$dataset
}

# Generate and test hierarchies
$MORE_DATASETS = @("CIFAR10", "CIFAR100", "TinyImagenet200")
foreach ($dataset in $MORE_DATASETS) {
nbdt-hierarchy --dataset=$dataset --method=wordnet;
}

0 comments on commit 23b51d1

Please sign in to comment.