Commit 57b415b

ADD: distributed torch Trainer and decorator

Matteo Bunino committed on Aug 9, 2023
1 parent: c744f85

Showing 13 changed files with 676 additions and 110 deletions.
.gitignore:

@@ -8,6 +8,8 @@ tmp*
 *.txt
 checkpoints/
 mamba*
+MNIST
+mllogs

 # Custom envs
 .venv*
Conda environment file (path not shown in this excerpt):

@@ -20,6 +20,7 @@ dependencies:
 - lightning=2.0.0
 - torchmetrics
 - mlflow>=2
+- wandb
 - typer
 - rich
 - pyyaml
Second conda environment file (path not shown in this excerpt):

@@ -15,6 +15,7 @@ dependencies:
 - lightning=2.0.0
 - torchmetrics
 - mlflow>=2
+- wandb
 - typer
 - rich
 - pyyaml
(Diffs for several other changed files are not visible in this excerpt.)
New file: test_decorator.py (+109 lines)
""" | ||
Test Trainer class. To run this script, use the following command: | ||
>>> torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d \ | ||
--rdzv_endpoint=localhost:29400 test_decorator.py | ||
""" | ||
|
||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms, datasets | ||
import torch.optim as optim | ||
from torch.optim.lr_scheduler import StepLR | ||
|
||
from itwinai.backend.torch.trainer import distributed | ||
|
||
|
||
class Net(nn.Module): | ||
|
||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) | ||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5) | ||
self.conv2_drop = nn.Dropout2d() | ||
self.fc1 = nn.Linear(320, 50) | ||
self.fc2 = nn.Linear(50, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(F.max_pool2d(self.conv1(x), 2)) | ||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) | ||
x = x.view(-1, 320) | ||
x = F.relu(self.fc1(x)) | ||
x = F.dropout(x, training=self.training) | ||
x = self.fc2(x) | ||
return F.log_softmax(x, dim=0) | ||
|
||
|
||
def train(model, device, train_loader, optimizer, epoch): | ||
model.train() | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
data, target = data.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = F.nll_loss(output, target) | ||
loss.backward() | ||
optimizer.step() | ||
if batch_idx % 100 == 0: | ||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( | ||
epoch, batch_idx * len(data), len(train_loader.dataset), | ||
100. * batch_idx / len(train_loader), loss.item())) | ||
|
||
|
||
def test(model, device, test_loader): | ||
model.eval() | ||
test_loss = 0 | ||
correct = 0 | ||
with torch.no_grad(): | ||
for data, target in test_loader: | ||
data, target = data.to(device), target.to(device) | ||
output = model(data) | ||
# sum up batch loss | ||
test_loss += F.nll_loss(output, target, reduction='sum').item() | ||
# get the index of the max log-probability | ||
pred = output.argmax(dim=1, keepdim=True) | ||
correct += pred.eq(target.view_as(pred)).sum().item() | ||
|
||
test_loss /= len(test_loader.dataset) | ||
|
||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( | ||
test_loss, correct, len(test_loader.dataset), | ||
100. * correct / len(test_loader.dataset))) | ||
|
||
|
||
@distributed | ||
def train_func( | ||
model, train_dataloader, validation_dataloader, device, | ||
optimizer, scheduler, epochs=10 | ||
): | ||
for epoch in range(1, epochs + 1): | ||
train(model, device, train_dataloader, optimizer, epoch) | ||
test(model, device, validation_dataloader) | ||
scheduler.step() | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
train_set = datasets.MNIST( | ||
'.tmp/', train=True, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)) | ||
])) | ||
val_set = datasets.MNIST( | ||
'.tmp/', train=False, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)) | ||
])) | ||
model = Net() | ||
train_dataloader = DataLoader(train_set, batch_size=32, pin_memory=True) | ||
validation_dataloader = DataLoader(val_set, batch_size=32, pin_memory=True) | ||
optimizer = optim.Adadelta(model.parameters(), lr=1e-3) | ||
scheduler = StepLR(optimizer, step_size=1, gamma=0.9) | ||
|
||
# Train distributed | ||
train_func(model, train_dataloader, validation_dataloader, 'cuda', | ||
optimizer, scheduler=scheduler) |
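
The implementation of the distributed decorator imported above is not part of this excerpt. As a rough orientation, the sketch below shows one way such a decorator can be written: it assumes the wrapped function takes (model, train_dataloader, validation_dataloader, device, ...) as train_func does, and it is an illustration rather than the actual itwinai code.

import functools
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def distributed(func):
    """Hypothetical sketch: run a training function under torchrun with DDP."""
    @functools.wraps(func)
    def wrapper(model, train_dataloader, validation_dataloader,
                device, *args, **kwargs):
        # torchrun exports LOCAL_RANK, RANK and WORLD_SIZE, so
        # init_process_group can discover peers via the c10d rendezvous.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(local_rank)
        # Override the device argument: each worker uses its local GPU.
        device = torch.device('cuda', local_rank)

        # Replicate the model on this worker and wrap it with DDP.
        model = DDP(model.to(device), device_ids=[local_rank])

        # Re-create the training dataloader with a DistributedSampler
        # so that each worker sees a distinct shard of the dataset.
        train_dataloader = DataLoader(
            train_dataloader.dataset,
            batch_size=train_dataloader.batch_size,
            sampler=DistributedSampler(train_dataloader.dataset),
            pin_memory=True,
        )

        try:
            func(model, train_dataloader, validation_dataloader,
                 device, *args, **kwargs)
        finally:
            dist.destroy_process_group()

    return wrapper

Each torchrun worker executes this wrapper independently; all distribution logic (process group setup, model replication, data sharding) stays in the decorator, so the decorated training function body remains plain PyTorch.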
New file: test_trainer.py (+61 lines)
""" | ||
Test Trainer class. To run this script, use the following command: | ||
>>> torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d \ | ||
--rdzv_endpoint=localhost:29400 test_trainer.py | ||
""" | ||
|
||
from torch import nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms, datasets | ||
|
||
from itwinai.backend.torch.trainer import TorchTrainer | ||
|
||
|
||
class Net(nn.Module): | ||
|
||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) | ||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5) | ||
self.conv2_drop = nn.Dropout2d() | ||
self.fc1 = nn.Linear(320, 50) | ||
self.fc2 = nn.Linear(50, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(F.max_pool2d(self.conv1(x), 2)) | ||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) | ||
x = x.view(-1, 320) | ||
x = F.relu(self.fc1(x)) | ||
x = F.dropout(x, training=self.training) | ||
x = self.fc2(x) | ||
return F.log_softmax(x, dim=0) | ||
|
||
|
||
if __name__ == '__main__': | ||
train_set = datasets.MNIST( | ||
'.tmp/', train=True, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)) | ||
])) | ||
val_set = datasets.MNIST( | ||
'.tmp/', train=False, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)) | ||
])) | ||
trainer = TorchTrainer( | ||
model=Net(), | ||
train_dataloader=DataLoader(train_set, batch_size=32, pin_memory=True), | ||
validation_dataloader=DataLoader( | ||
val_set, batch_size=32, pin_memory=True), | ||
strategy='ddp', | ||
backend='nccl', | ||
loss='NLLLoss', | ||
epochs=20, | ||
checkpoint_every=1 | ||
) | ||
trainer.train() |
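
The TorchTrainer class itself lives in itwinai.backend.torch.trainer and is not shown in this commit excerpt. The sketch below is a hypothetical, single-process stand-in that illustrates the constructor interface exercised above; the class name MinimalTorchTrainer, its default optimizer, and its checkpoint naming are assumptions, and the real trainer additionally implements the 'ddp' strategy, the 'nccl' backend, and validation.

import torch
from torch import nn


class MinimalTorchTrainer:
    """Hypothetical single-process stand-in for itwinai's TorchTrainer."""

    def __init__(self, model, train_dataloader, validation_dataloader,
                 strategy='ddp', backend='nccl', loss='NLLLoss',
                 epochs=10, checkpoint_every=1):
        self.model = model
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        # Resolve the loss class by name from torch.nn (e.g. 'NLLLoss').
        self.loss_fn = getattr(nn, loss)()
        self.epochs = epochs
        self.checkpoint_every = checkpoint_every
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Assumed default optimizer; the real trainer may configure this.
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def train(self):
        self.model.to(self.device)
        for epoch in range(1, self.epochs + 1):
            self.model.train()
            for data, target in self.train_dataloader:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                loss = self.loss_fn(self.model(data), target)
                loss.backward()
                self.optimizer.step()
            # Periodically persist model weights, as checkpoint_every=1
            # in the usage above suggests.
            if epoch % self.checkpoint_every == 0:
                torch.save(self.model.state_dict(),
                           'checkpoint_{}.pt'.format(epoch))

Resolving the loss from its name with getattr(nn, loss) mirrors the string argument loss='NLLLoss' used above and keeps the constructor friendly to config-driven instantiation.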