From 4ca28099d1271f78a7599746693a2295f0a2340a Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 12:56:28 -0400 Subject: [PATCH 01/11] add grid.ai examples --- .gitignore | 2 ++ README.md | 23 ++++++++++++++++++++--- requirements.txt | 7 ++++--- 3 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3647e99 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +data/* +lightning_logs/* diff --git a/README.md b/README.md index 7d67f86..d1b3f26 100644 --- a/README.md +++ b/README.md @@ -5,21 +5,38 @@ Simple MNIST classifier written in PyTorch, PyTorch Lightning, and Keras. ## Install Dependencies ```bash +conda create --yes --name mnist python=3.8 +conda activate mnist pip install -r requirements.txt +pip install lightning-grid --upgrade ``` -## PyTorch / Lightning +## PyTorch ```bash -# pytorch +# Local Run python pytorch.py -# lightning +# Grid.ai Run +grid run pytorch.py +``` + +# PyTorch Lightning + +```bash +# Local Run python pl_mnist.py + +# Grid.ai Run +grid run pl_mnist.py ``` ## Keras ```bash +# Local Run python keras.py + +# Grid.ai Run +grid run keras.py ``` diff --git a/requirements.txt b/requirements.txt index f66081b..eddbaa3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -torch==1.7.1 -pytorch-lightning==1.1.2 -torchvision==0.8.2 +torch +pytorch-lightning +torchvision +tensorflow From 5bd4db06d1cf7ed3989c39e4756a82d5fb2c94c1 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 13:04:45 -0400 Subject: [PATCH 02/11] Update README.md --- README.md | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d1b3f26..b8cb60a 100644 --- a/README.md +++ b/README.md @@ -2,23 +2,29 @@ Simple MNIST classifier written in PyTorch, PyTorch Lightning, and Keras. -## Install Dependencies +# Install Dependencies ```bash conda create --yes --name mnist python=3.8 conda activate mnist pip install -r requirements.txt -pip install lightning-grid --upgrade +pip install lightning-grid ``` -## PyTorch +# PyTorch ```bash # Local Run python pytorch.py # Grid.ai Run -grid run pytorch.py +grid run pytorch.py | tee /tmp/grid.run.log + +# Grid.ai Run Monitor +RUN_NAME=$(grep /tmp/grid_name grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') +watch grid status --details $RUN_NAME +grid history | grep $RUN_NAME +grid logs ${RUN_NAME}-exp0 | tee /tmp/grid.exp0.log ``` # PyTorch Lightning @@ -28,15 +34,27 @@ grid run pytorch.py python pl_mnist.py # Grid.ai Run -grid run pl_mnist.py +grid run pl_mnist.py | tee /tmp/grid.run.log + +# Grid.ai Run Monitor +RUN_NAME=$(grep /tmp/grid_name grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') +watch grid status --details $RUN_NAME +grid history | grep $RUN_NAME +grid logs ${RUN_NAME}-exp0 | tee /tmp/grid.exp0.log ``` -## Keras +# Keras ```bash # Local Run python keras.py # Grid.ai Run -grid run keras.py +grid run keras.py | tee /tmp/grid.run.log + +# Grid.ai Run Monitor +RUN_NAME=$(grep /tmp/grid_name grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') +watch grid status --details $RUN_NAME +grid history | grep $RUN_NAME +grid logs ${RUN_NAME}-exp0 | tee /tmp/grid.exp0.log ``` From bb6aa8d5295458d324f170d316e037f22da76efd Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 13:29:57 -0400 Subject: [PATCH 03/11] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 3647e99..4ccda91 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ data/* lightning_logs/* +runs/* From a0d62bf5c5faebfb93e29909472930482bcb9146 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 16:05:38 -0400 Subject: [PATCH 04/11] add num_workers --- .gitignore | 2 ++ pl_cifar10.py | 6 ++++-- pl_mnist.py | 5 +++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 4ccda91..c10e9d0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ data/* lightning_logs/* runs/* +MNIST/* +cifar-10-* diff --git a/pl_cifar10.py b/pl_cifar10.py index 7236114..5ec30ac 100644 --- a/pl_cifar10.py +++ b/pl_cifar10.py @@ -40,15 +40,17 @@ def training_step(self, batch, batch_idx): from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument('--gpus', type=int, default=None) + parser.add_argument('--gpus', type=int, default=0) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--max_epochs', type=int, default=10) parser.add_argument('--data_dir', type=str, default=os.getcwd()) + parser.add_argument('--num_workers', type=int, default=8) + args = parser.parse_args() dataset = CIFAR10(args.data_dir, download=True, transform=transforms.ToTensor()) - train_loader = DataLoader(dataset, batch_size=args.batch_size) + train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) # init model model = LitModel(lr=args.lr) diff --git a/pl_mnist.py b/pl_mnist.py index cbb54c0..c5152f1 100644 --- a/pl_mnist.py +++ b/pl_mnist.py @@ -39,15 +39,16 @@ def training_step(self, batch, batch_idx): from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument('--gpus', type=int, default=None) + parser.add_argument('--gpus', type=int, default=0) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--max_epochs', type=int, default=10) parser.add_argument('--data_dir', type=str, default=os.getcwd()) + parser.add_argument('--num_workers', type=int, default=8) args = parser.parse_args() dataset = MNIST(args.data_dir, download=True, transform=transforms.ToTensor()) - train_loader = DataLoader(dataset, batch_size=args.batch_size) + train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) # init model model = LitModel(lr=args.lr) From 8f5f4b9c837af589f1f0e6bc67548391700a6039 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 16:11:18 -0400 Subject: [PATCH 05/11] Update README.md --- README.md | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index b8cb60a..b900ab4 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,6 @@ python pytorch.py # Grid.ai Run grid run pytorch.py | tee /tmp/grid.run.log - -# Grid.ai Run Monitor -RUN_NAME=$(grep /tmp/grid_name grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') -watch grid status --details $RUN_NAME -grid history | grep $RUN_NAME -grid logs ${RUN_NAME}-exp0 | tee /tmp/grid.exp0.log ``` # PyTorch Lightning @@ -35,12 +29,6 @@ python pl_mnist.py # Grid.ai Run grid run pl_mnist.py | tee /tmp/grid.run.log - -# Grid.ai Run Monitor -RUN_NAME=$(grep /tmp/grid_name grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') -watch grid status --details $RUN_NAME -grid history | grep $RUN_NAME -grid logs ${RUN_NAME}-exp0 | tee /tmp/grid.exp0.log ``` # Keras @@ -51,9 +39,11 @@ python keras.py # Grid.ai Run grid run keras.py | tee /tmp/grid.run.log +``` # Grid.ai Run Monitor -RUN_NAME=$(grep /tmp/grid_name grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') +```bash +RUN_NAME=$(grep grid_name /tmp/grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') watch grid status --details $RUN_NAME grid history | grep $RUN_NAME grid logs ${RUN_NAME}-exp0 | tee /tmp/grid.exp0.log From 925c9347d00a2c64c029f13be27afcb9dbe273ce Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 20:38:45 -0400 Subject: [PATCH 06/11] use same arg across the models --- README.md | 12 ++++++++++++ pytorch.py | 14 +++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index b900ab4..5722cc5 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ conda create --yes --name mnist python=3.8 conda activate mnist pip install -r requirements.txt pip install lightning-grid +grid login ``` # PyTorch @@ -41,7 +42,18 @@ python keras.py grid run keras.py | tee /tmp/grid.run.log ``` +# CIFAR-10 Bonus + +```bash +# Local Run +python pl_cifar10.py + +# Grid.ai Run +grid run pl_cifar10.py | tee /tmp/grid.run.log +``` + # Grid.ai Run Monitor + ```bash RUN_NAME=$(grep grid_name /tmp/grid.run.log | cut -d':' -f 2 | sed -e 's/^[[:space:]]*//') watch grid status --details $RUN_NAME diff --git a/pytorch.py b/pytorch.py index 862e56a..600381a 100644 --- a/pytorch.py +++ b/pytorch.py @@ -88,11 +88,11 @@ def test(model, device, test_loader, epoch): def main(): # Training settings parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', + parser.add_argument('--batch_size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', + parser.add_argument('--max_epochs', type=int, default=14, metavar='N', help='number of epochs to train (default: 14)') parser.add_argument('--lr', type=float, default=1.0, metavar='LR', help='learning rate (default: 1.0)') @@ -100,13 +100,13 @@ def main(): help='Learning rate step gamma (default: 0.7)') parser.add_argument('--cuda', action='store_true', default=False, help='disables CUDA training') - parser.add_argument('--dry-run', action='store_true', default=False, + parser.add_argument('--dry_run', action='store_true', default=False, help='quickly check a single pass') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', + parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') - parser.add_argument('--save-model', action='store_true', default=False, + parser.add_argument('--save_model', action='store_true', default=False, help='For Saving the current Model') args = parser.parse_args() use_cuda = not args.cuda and torch.cuda.is_available() @@ -137,7 +137,7 @@ def main(): optimizer = optim.Adadelta(model.parameters(), lr=args.lr) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - for epoch in range(1, args.epochs + 1): + for epoch in range(1, args.max_epochs + 1): train(args, model, device, train_loader, optimizer, epoch) test(model, device, test_loader, epoch) scheduler.step() From 6cbeebc74035cc802ddcacb852e8c284e243e4cf Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Tue, 27 Jul 2021 21:57:56 -0400 Subject: [PATCH 07/11] Update README.md --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/README.md b/README.md index 5722cc5..b63be36 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,24 @@ python pl_cifar10.py grid run pl_cifar10.py | tee /tmp/grid.run.log ``` +# Default Command Line Argument Values per Script + +| Argument Name  | keras.py | pl_cifar10.py | pl_mnist.py | pytorch.py| +| --:| :--: | :--: | :--: | :--: | +| --max_epochs | 5 | 10 | 10 | 14| +| --lr | 1.00E-03 | 1.00E-03 | 1.00E-03 | 1 | +| --batch_size | 32 | 32 | 32 | 64 | +| --data_dir | ./data/ | os.getcwd | os.getcwd |  | +| --num_workers |   | 8 | 8 |  | +| --gpus |   | 0 | 0 |  | +| --test_batch_size |   |   |   | 1000| +| --seed |   |   |   | 1| +| --save_model |   |   |   | FALSE| +| --log_interval |   |   |   | 10| +| --gamma |   |   |   | 0.7| +| --dry_run |   |   |   | FALSE| +| --cuda |   |   |   | FALSE | + # Grid.ai Run Monitor ```bash From da1a2faca0c372eb2c7617935381b2301e8f3b22 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Sun, 1 Aug 2021 14:09:05 -0400 Subject: [PATCH 08/11] num_workers --- README.md | 16 ++++++++++++++++ keras.py | 1 + requirements.txt | 6 ++++-- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b63be36..1c754d2 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ grid login # PyTorch +Use the CLI commands below or click +[![PyTorch](https://img.shields.io/badge/rid_AI-run-78FF96.svg?labelColor=black&logo=data:image/svg%2bxml;base64,PHN2ZyB3aWR0aD0iNDgiIGhlaWdodD0iNDgiIGZpbGw9Im5vbmUiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+PHBhdGggZD0iTTEgMTR2MjBhMTQgMTQgMCAwMDE0IDE0aDlWMzYuOEgxMi42VjExaDIyLjV2N2gxMS4yVjE0QTE0IDE0IDAgMDAzMi40IDBIMTVBMTQgMTQgMCAwMDEgMTR6IiBmaWxsPSIjZmZmIi8+PHBhdGggZD0iTTM1LjIgNDhoMTEuMlYyNS41SDIzLjl2MTEuM2gxMS4zVjQ4eiIgZmlsbD0iI2ZmZiIvPjwvc3ZnPg==)]( +https://platform.grid.ai/#/runs?script=https://github.com/robert-s-lee/hello_mnists/blob/6cbeebc74035cc802ddcacb852e8c284e243e4cf/pytorch.py&cloud=grid&instance=t2.medium&accelerators=1&disk_size=200&framework=lightning +) + ```bash # Local Run python pytorch.py @@ -24,6 +29,11 @@ grid run pytorch.py | tee /tmp/grid.run.log # PyTorch Lightning +Use the CLI commands below or click +[![Lightning](https://img.shields.io/badge/rid_AI-run-78FF96.svg?labelColor=black&logo=data:image/svg%2bxml;base64,PHN2ZyB3aWR0aD0iNDgiIGhlaWdodD0iNDgiIGZpbGw9Im5vbmUiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+PHBhdGggZD0iTTEgMTR2MjBhMTQgMTQgMCAwMDE0IDE0aDlWMzYuOEgxMi42VjExaDIyLjV2N2gxMS4yVjE0QTE0IDE0IDAgMDAzMi40IDBIMTVBMTQgMTQgMCAwMDEgMTR6IiBmaWxsPSIjZmZmIi8+PHBhdGggZD0iTTM1LjIgNDhoMTEuMlYyNS41SDIzLjl2MTEuM2gxMS4zVjQ4eiIgZmlsbD0iI2ZmZiIvPjwvc3ZnPg==)]( +https://platform.grid.ai/#/runs?script=https://github.com/robert-s-lee/hello_mnists/blob/6cbeebc74035cc802ddcacb852e8c284e243e4cf/pl_mnist.py&cloud=grid&instance=t2.medium&accelerators=1&disk_size=200&framework=lightning +) + ```bash # Local Run python pl_mnist.py @@ -34,6 +44,11 @@ grid run pl_mnist.py | tee /tmp/grid.run.log # Keras +Use the CLI commands below or click +[![Keras](https://img.shields.io/badge/rid_AI-run-78FF96.svg?labelColor=black&logo=data:image/svg%2bxml;base64,PHN2ZyB3aWR0aD0iNDgiIGhlaWdodD0iNDgiIGZpbGw9Im5vbmUiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+PHBhdGggZD0iTTEgMTR2MjBhMTQgMTQgMCAwMDE0IDE0aDlWMzYuOEgxMi42VjExaDIyLjV2N2gxMS4yVjE0QTE0IDE0IDAgMDAzMi40IDBIMTVBMTQgMTQgMCAwMDEgMTR6IiBmaWxsPSIjZmZmIi8+PHBhdGggZD0iTTM1LjIgNDhoMTEuMlYyNS41SDIzLjl2MTEuM2gxMS4zVjQ4eiIgZmlsbD0iI2ZmZiIvPjwvc3ZnPg==)]( +https://platform.grid.ai/#/runs?script=https://github.com/robert-s-lee/hello_mnists/blob/6cbeebc74035cc802ddcacb852e8c284e243e4cf/keras.py&cloud=grid&instance=t2.medium&accelerators=1&disk_size=200&framework=lightning&script_args=keras.py +) + ```bash # Local Run python keras.py @@ -44,6 +59,7 @@ grid run keras.py | tee /tmp/grid.run.log # CIFAR-10 Bonus + ```bash # Local Run python pl_cifar10.py diff --git a/keras.py b/keras.py index b156a94..3959dd4 100644 --- a/keras.py +++ b/keras.py @@ -9,6 +9,7 @@ parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--max_epochs', type=int, default=5) parser.add_argument('--data_dir', type=str, default="./data/") +parser.add_argument('--num_workers', type=int, default=8) args = parser.parse_args() # Make sure data_dir is absolute + create it if it doesn't exist diff --git a/requirements.txt b/requirements.txt index eddbaa3..09fc7ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ -torch +# tensorflow version matches Grid.ai grid run --framework tensorflow 2.2.0 +tensorflow +# lightning version matches Grid.ai grid run --framework lightning 1.2.1 pytorch-lightning +torch torchvision -tensorflow From fb6d50d8e171d5d2e01ff7f436434daf8badaf32 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Thu, 5 Aug 2021 10:57:16 -0400 Subject: [PATCH 09/11] add dockerfile --- Dockerfile | 10 ++++++++++ requirements.txt | 5 ++++- 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2fa4a81 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +# test locally for syntax error by running +# `docker build -t gridray:latest -f gridray.dockerfile .` + +FROM python:3.8 + +# mandatory for Grid.ai +WORKDIR /gridai/project +COPY . . +# mandatory for Grid. +RUN pip install --ignore-requires-python -v -r requirements.txt \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 09fc7ba..f11fae9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,9 @@ # tensorflow version matches Grid.ai grid run --framework tensorflow 2.2.0 tensorflow # lightning version matches Grid.ai grid run --framework lightning 1.2.1 -pytorch-lightning +# raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") +# ValueError: The `preds` should be probabilities, but values were detected outside of [0,1] range. +# https://github.com/PyTorchLightning/lightning-bolts/issues/551 +pytorch-lightning<1.2 torch torchvision From 58b311ff32fbfeb3e427fbe5be5464b26a5d16a4 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Thu, 5 Aug 2021 11:17:00 -0400 Subject: [PATCH 10/11] remove dockerfile --- Dockerfile | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 2fa4a81..0000000 --- a/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -# test locally for syntax error by running -# `docker build -t gridray:latest -f gridray.dockerfile .` - -FROM python:3.8 - -# mandatory for Grid.ai -WORKDIR /gridai/project -COPY . . -# mandatory for Grid. -RUN pip install --ignore-requires-python -v -r requirements.txt \ No newline at end of file From 9561562efba6b00ad13622cd5236372d66ae8444 Mon Sep 17 00:00:00 2001 From: Robert S Lee Date: Thu, 5 Aug 2021 11:47:16 -0400 Subject: [PATCH 11/11] use configargparser --- keras.py | 8 ++++---- pl_cifar10.py | 8 ++++---- pl_mnist.py | 8 ++++---- pytorch.py | 10 +++++----- requirements.txt | 1 + 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/keras.py b/keras.py index 3959dd4..f6a654f 100644 --- a/keras.py +++ b/keras.py @@ -1,13 +1,13 @@ -from argparse import ArgumentParser +from configargparse import ArgumentParser from pathlib import Path from tensorflow import keras # Define this script's flags parser = ArgumentParser() -parser.add_argument('--lr', type=float, default=1e-3) -parser.add_argument('--batch_size', type=int, default=32) -parser.add_argument('--max_epochs', type=int, default=5) +parser.add_argument('--lr', type=float, default=1e-3, env_var="MNIST_LR") +parser.add_argument('--batch_size', type=int, default=32, env_var="MNIST_BATCH_SIZE") +parser.add_argument('--max_epochs', type=int, default=5, env_var="MNIST_MAX_EPOCHS") parser.add_argument('--data_dir', type=str, default="./data/") parser.add_argument('--num_workers', type=int, default=8) args = parser.parse_args() diff --git a/pl_cifar10.py b/pl_cifar10.py index 5ec30ac..aecec4b 100644 --- a/pl_cifar10.py +++ b/pl_cifar10.py @@ -37,13 +37,13 @@ def training_step(self, batch, batch_idx): return loss if __name__ == '__main__': - from argparse import ArgumentParser + from configargparse import ArgumentParser parser = ArgumentParser() parser.add_argument('--gpus', type=int, default=0) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--max_epochs', type=int, default=10) + parser.add_argument('--lr', type=float, default=1e-3, env_var="CIFAR_LR") + parser.add_argument('--batch_size', type=int, default=32, env_var="CIFAR_BATCH_SIZE") + parser.add_argument('--max_epochs', type=int, default=10, env_var="CIFAR_MAX_EPOCHS") parser.add_argument('--data_dir', type=str, default=os.getcwd()) parser.add_argument('--num_workers', type=int, default=8) diff --git a/pl_mnist.py b/pl_mnist.py index c5152f1..9f8eba1 100644 --- a/pl_mnist.py +++ b/pl_mnist.py @@ -36,13 +36,13 @@ def training_step(self, batch, batch_idx): return loss if __name__ == '__main__': - from argparse import ArgumentParser + from configargparse import ArgumentParser parser = ArgumentParser() parser.add_argument('--gpus', type=int, default=0) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--max_epochs', type=int, default=10) + parser.add_argument('--lr', type=float, default=1e-3, env_var="MNIST_LR") + parser.add_argument('--batch_size', type=int, default=32, env_var="MNIST_BATCH_SIZE") + parser.add_argument('--max_epochs', type=int, default=10, env_var="MNIST_MAX_EPOCHS") parser.add_argument('--data_dir', type=str, default=os.getcwd()) parser.add_argument('--num_workers', type=int, default=8) args = parser.parse_args() diff --git a/pytorch.py b/pytorch.py index 600381a..5aad4b2 100644 --- a/pytorch.py +++ b/pytorch.py @@ -4,7 +4,7 @@ """ from __future__ import print_function -import argparse +from configargparse import ArgumentParser import torch import torch.nn as nn import torch.nn.functional as F @@ -87,15 +87,15 @@ def test(model, device, test_loader, epoch): def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser = ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch_size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') + help='input batch size for training (default: 64)', env_var="MNIST_BATCH_SIZE") parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--max_epochs', type=int, default=14, metavar='N', - help='number of epochs to train (default: 14)') + help='number of epochs to train (default: 14)', env_var="MNIST_MAX_EPOCHS") parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') + help='learning rate (default: 1.0)', env_var="MNIST_LR") parser.add_argument('--gamma', type=float, default=0.7, metavar='M', help='Learning rate step gamma (default: 0.7)') parser.add_argument('--cuda', action='store_true', default=False, diff --git a/requirements.txt b/requirements.txt index f11fae9..23d7b2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ tensorflow pytorch-lightning<1.2 torch torchvision +configargparser