-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from MDutro/windows-version
Add Powershell scripts for Windows users
- Loading branch information
Showing
5 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. | ||
# This script is for networks that do NOT come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself. | ||
|
||
$model="ResNet18" | ||
$dataset="CIFAR10" | ||
$weight=1 | ||
|
||
# 0. train the baseline neural network | ||
python main.py --dataset=$dataset --arch=$model | ||
|
||
# 1. generate hieararchy -- for models without a pretrained checkpoint, use `checkpoint` | ||
nbdt-hierarchy --dataset=$dataset --checkpoint="./checkpoint/ckpt-$dataset" + "-$model.pth" | ||
|
||
# 2. train with soft tree supervision loss -- for models without a pretrained checkpoint, use `path-resume` OR just train from scratch, without `path-resume` | ||
# python main.py --lr=0.01 --dataset=${dataset} --model=${model} --hierarchy=induced-${model} --path-resume=./checkpoint/ckpt-${dataset}-${model}.pth --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} # fine-tuning | ||
python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight # training from scratch | ||
|
||
# 3. evaluate with soft then hard inference | ||
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") | ||
|
||
foreach ($analysis in $analysisRules) { | ||
python main.py --dataset=$dataset --model=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. | ||
# This script is for networks that DO come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself. | ||
|
||
$model="wrn28_10_cifar10" | ||
$dataset="CIFAR10" | ||
$weight=1 | ||
|
||
# 1. generate hieararchy | ||
nbdt-hierarchy --dataset=$dataset --arch=$model | ||
|
||
# 2. train with soft tree supervision loss | ||
python main.py --lr=0.01 --dataset=$dataset --model=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight | ||
|
||
# 3. evaluate with soft then hard inference | ||
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") | ||
|
||
foreach ($analysis in $analysisRules) { | ||
python main.py --dataset=$dataset --model=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. | ||
|
||
$MODELS = @("CIFAR10 1", "CIFAR100 1", "TinyImagenet200 10") | ||
|
||
foreach ($model in $MODELS) { | ||
|
||
$params = $model.split(" ") | ||
|
||
$dataset=$params[0] | ||
$weight=$params[1] | ||
|
||
|
||
|
||
# 1. generate hieararchy | ||
nbdt-hierarchy --dataset=$dataset --arch=ResNet18 | ||
|
||
# 2. train with soft tree supervision loss | ||
python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight | ||
|
||
# 3. evaluate with soft then hard inference | ||
|
||
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") | ||
|
||
foreach ($analysis in $analysisRules) { | ||
python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight --eval --resume --analysis=$analysis | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. | ||
|
||
$MODEL_NAME="wrn28_10" | ||
$CIFAR100="CIFAR100" + " " + $MODEL_NAME + "_cifar100 1" | ||
$CIFAR10="CIFAR10 $MODEL_NAME" + "_cifar10 1" | ||
$MODELS=@($CIFAR10, $CIFAR100, "TinyImagenet200 $MODEL_NAME 10") | ||
|
||
foreach ($model in $MODELS) { | ||
|
||
$params = $model.split(" ") | ||
|
||
$dataset=$params[0] | ||
$model=$params[1] | ||
$weight=$params[2] | ||
|
||
# 1. generate hieararchy | ||
nbdt-hierarchy --dataset=$dataset --arch=$model | ||
|
||
# 2. train with soft tree supervision loss | ||
python main.py --lr=0.01 --dataset=$dataset --arch=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight | ||
|
||
# 3. evaluate with soft then hard inference | ||
$analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") | ||
|
||
foreach ($analysis in $analysisRules) { | ||
python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
python -c "import nltk;nltk.download('wordnet')" | ||
|
||
# Generate WNIDs | ||
$DATASETS = @("CIFAR10", "CIFAR100") | ||
foreach ($dataset in $DATASETS) { | ||
nbdt-wnids --dataset=$dataset | ||
} | ||
|
||
# Generate and test hierarchies | ||
$MORE_DATASETS = @("CIFAR10", "CIFAR100", "TinyImagenet200") | ||
foreach ($dataset in $MORE_DATASETS) { | ||
nbdt-hierarchy --dataset=$dataset --method=wordnet; | ||
} |