diff --git a/pytorch/README.md b/pytorch/README.md
new file mode 100644
index 0000000..0de4c73
--- /dev/null
+++ b/pytorch/README.md
@@ -0,0 +1,142 @@
# comet-pytorch-example

## Using Comet.ml to track PyTorch experiments

The following code snippets show how to use PyTorch with Comet.ml. Based on the tutorial from [Yunjey](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/01-basics/feedforward_neural_network/main.py), this code trains a feedforward neural network to detect handwritten digits from the MNIST dataset.

Once the `Experiment()` object is initialized, Comet.ml logs stdout and source code. To log hyper-parameters, metrics and visualizations, we add a few function calls such as `experiment.log_metric()` and `experiment.log_parameters()`.

Find out more about how you can customize Comet.ml in our documentation: https://www.comet.ml/docs/

## Using Comet.ml with PyTorch Distributed Data Parallel (DDP) training

### Setup
All machines on the network need to be able to talk to each other on the unprivileged port range (1024-65535) in addition to the master port.

You will need to define the IP address of the "master" node, `master_addr`, which must be reachable by all of the other machines, and the port, `master_port`, on which the workers will connect. Make sure all of the machines have the same number of GPUs.

### Logging Metrics from Multiple Machines
To capture model metrics and system metrics (GPU/CPU usage, RAM, etc.) from each machine while running distributed training, we recommend creating an Experiment object per GPU process and grouping these experiments under a user-provided run ID.

An example project can be found [here](https://www.comet.ml/team-comet-ml/pytorch-ddp-cifar10/view/Tzf2pUfV5BWOVa36eWoW0HOO1).

In order to reproduce the project, you will need to run the `comet-pytorch-ddp-cifar10.py` example. The example can be run in the following ways:

- Single Machine, Multiple GPUs
- Multiple Machines, Multiple GPUs

##### Running the Example
We will run the example based on the assumption that we have 2 machines, with a single GPU each. We will use `192.168.1.1` as our `master_addr` and `8892` as our `master_port`.

On the master node, start the script with the following command:

```
python comet-pytorch-ddp-cifar10.py \
--nodes 2 \
--gpus 1 \
--node_rank 0 \
--master_addr 192.168.1.1 \
--master_port 8892 \
--epochs 5 \
--replica_batch_size 32 \
--run_id <run_id>
```

On the worker node, run:

```
python comet-pytorch-ddp-cifar10.py \
--nodes 2 \
--gpus 1 \
--node_rank 1 \
--master_addr 192.168.1.1 \
--master_port 8892 \
--epochs 5 \
--replica_batch_size 32 \
--run_id <run_id>
```

The command line arguments are:

```
nodes: The number of available compute nodes

gpus: The number of GPUs available on each machine

node_rank: The ranking of the machine within the nodes. It starts at 0

replica_batch_size: The batch size allocated to a single GPU process

run_id: A user-provided string that allows us to group the experiments from a single run.
```

As you add machines and GPUs, you will have to run the same command on each machine while incrementing the `node_rank`. For example, in the case of N machines, we would run the script on each machine up until `node_rank = N-1`.
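The pattern recommended above (one `Experiment` per GPU process, all tagged with the same run ID) is what `comet-pytorch-ddp-cifar10.py` implements. A minimal sketch of a spawned worker is shown below; the training body is elided and the argument values are illustrative.

```python
# Sketch only: each spawned GPU process creates its own Comet Experiment and
# logs the shared run_id so the experiments can be grouped in the Comet UI.
from comet_ml import Experiment

import torch.multiprocessing as mp


def worker(local_rank, node_rank, gpus_per_node, run_id):
    # Overall rank of this GPU process across all nodes
    global_rank = node_rank * gpus_per_node + local_rank

    experiment = Experiment(auto_output_logging="simple")
    experiment.log_parameter("run_id", run_id)
    experiment.log_parameter("global_process_rank", global_rank)

    # ... initialize the process group, wrap the model in DDP and train ...

    experiment.end()


if __name__ == "__main__":
    # e.g. node_rank=0 with 2 GPUs on a single machine
    mp.spawn(worker, args=(0, 2, "my-run-id"), nprocs=2, join=True)
```

For the "Single Machine, Multiple GPUs" case listed above, the same command should work with `--nodes 1`, `--node_rank 0`, `--gpus` set to the number of local GPUs, and `--master_addr localhost` (the script's default).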
### Logging Metrics from Multiple Machines as a single Experiment

If you would like to log the metrics from each worker as a single experiment, you will need to run the `comet-pytorch-ddp-mnist-single-experiment.py` example. Keep in mind that logging system metrics (CPU/GPU usage, RAM, etc.) from multiple workers as a single experiment is not currently supported. We recommend using an Experiment per GPU process instead.

An example project can be found [here](https://www.comet.ml/team-comet-ml/pytorch-ddp-mnist-single/view/new).

##### Running the Example
We will run the example based on the assumption that we have 2 machines, with a single GPU each. We will use `192.168.1.1` as our `master_addr` and `8892` as our `master_port`.

On the master node, start the script with the following command:

```
python comet-pytorch-ddp-mnist-single-experiment.py \
--nodes 2 \
--gpus 1 \
--master_addr 192.168.1.1 \
--master_port 8892 \
--node_rank 0 \
--local_rank 0 \
--run_id <run_id>
```

In this case, a SHA-1 hash of the `run_id` string is used to create the experiment key for the single Experiment that logs the metrics from every worker.

On the worker node, run:

```
python comet-pytorch-ddp-mnist-single-experiment.py \
--nodes 2 \
--gpus 1 \
--master_addr 192.168.1.1 \
--master_port 8892 \
--node_rank 1 \
--local_rank 0 \
--run_id <run_id>
```

The command line arguments are:

```
nodes: The number of available compute nodes

gpus: The number of GPUs available on each machine

node_rank: The ranking of the machine within the nodes. It starts at 0

local_rank: The rank of the process within the current node

run_id: A user-provided string
```

### Using Comet.ml with Horovod and PyTorch
In order to run the Horovod example in a distributed manner, you will need the following:

```bash
1. At least 2 machines with 1 GPU each.
2. Passwordless SSH access from the master node to the worker machines.
3. Horovod and Comet.ml installed on all machines.
4. This script must be present in the same directory on all machines.
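# Example commands (hypothetical hosts and install steps; adjust them to your environment):
#   pip install comet_ml horovod               # on every machine
#   ssh-copy-id user@192.168.1.2               # passwordless SSH from the master to each worker
# With the 2-machine, 1-GPU-each setup used above, a concrete launch of the Horovod
# script added in this directory would look like:
#   horovodrun --gloo -np 2 -H 192.168.1.1:1,192.168.1.2:1 python comet-pytorch-horovod-mnist.py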
+``` +To the run the example + +``` +horovodrun --gloo -np -H ,, python comet-pytorch-horovod-example.py +``` + +>If you're curious about learning more on parallelized training in Pytorch, checkout our report [here](https://www.comet.ml/team-comet-ml/parallelism/reports/advanced-ml-parallelism +) \ No newline at end of file diff --git a/pytorch/comet-pytorch-ddp-cifar10.py b/pytorch/comet-pytorch-ddp-cifar10.py new file mode 100644 index 0000000..9e0cd59 --- /dev/null +++ b/pytorch/comet-pytorch-ddp-cifar10.py @@ -0,0 +1,296 @@ +# coding: utf-8 +"""Pytorch Distributed Data Parallel Example with Learning Rate Scaling + +""" +import argparse +import os + +from comet_ml import Experiment + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F +import torchvision +from torch import nn, optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import random_split +from torchvision import transforms +from tqdm import tqdm + +torch.manual_seed(0) + +# This is the batch size being used per GPU +LEARNING_RATE = 0.001 + +# Learning Rate scaling factor is computed relative to this batch size +MIN_BATCH_SIZE = 8 + + +def scale_lr(batch_size, lr): + return lr * (batch_size / MIN_BATCH_SIZE) + + +def setup(rank, world_size, backend): + # initialize the process group + dist.init_process_group(backend, rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def load_data(data_dir="./data"): + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + + trainset = torchvision.datasets.CIFAR10( + root=data_dir, train=True, download=True, transform=transform + ) + + testset = torchvision.datasets.CIFAR10( + root=data_dir, train=False, download=True, transform=transform + ) + + return trainset, testset + + +def train(model, optimizer, criterion, trainloader, epoch, gpu_id, experiment): + model.train() + total_loss = 0 + epoch_steps = 0 + for batch_idx, (images, labels) in tqdm(enumerate(trainloader)): + optimizer.zero_grad() + images = images.cuda(gpu_id, non_blocking=True) + labels = labels.cuda(gpu_id, non_blocking=True) + + pred = model(images) + + loss = criterion(pred, labels) + loss.backward() + + experiment.log_metric("train_batch_loss", loss.item()) + + total_loss += loss.item() + epoch_steps += 1 + + optimizer.step() + + return total_loss / epoch_steps + + +def evaluate(model, criterion, valloader, epoch, local_rank): + # Validation loss + total_loss = 0.0 + epoch_steps = 0 + total = 0 + correct = 0 + + model.eval() + for i, data in enumerate(valloader, 0): + with torch.no_grad(): + inputs, labels = data + inputs, labels = inputs.cuda(local_rank), labels.cuda(local_rank) + outputs = model(inputs) + + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + loss = criterion(outputs, labels) + + total_loss += loss.item() + epoch_steps += 1 + + val_acc = correct / total + val_loss = total_loss / epoch_steps + + return val_loss, val_acc + + +def test_accuracy(net, testset, device="cpu"): + testloader = torch.utils.data.DataLoader( + testset, batch_size=4, shuffle=False, num_workers=2 + ) + + correct = 0 + total = 0 + with torch.no_grad(): + for data in testloader: + images, labels = data + images, labels = images.to(device), labels.to(device) + outputs = net(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += 
(predicted == labels).sum().item() + + return correct / total + + +def run(local_rank, world_size, args): + """ + This is a single process that is linked to a single GPU + + :param local_rank: The id of the GPU on the current node + :param world_size: Total number of processes across nodes + :param args: + :return: + """ + torch.cuda.set_device(local_rank) + + # The overall rank of this GPU process across multiple nodes + global_process_rank = args.node_rank * args.gpus + local_rank + + experiment = Experiment(auto_output_logging="simple") + experiment.log_parameter("run_id", args.run_id) + experiment.log_parameter("global_process_rank", global_process_rank) + experiment.log_parameter("replica_batch_size", args.replica_batch_size) + experiment.log_parameter("batch_size", args.replica_batch_size * world_size) + + learning_rate = scale_lr(args.replica_batch_size * world_size, LEARNING_RATE) + experiment.log_parameter("learning_rate", learning_rate) + + print(f"Running DDP model on Global Process with Rank: {global_process_rank }.") + setup(global_process_rank, world_size, args.backend) + + model = Net() + model.cuda(local_rank) + ddp_model = DDP(model, device_ids=[local_rank]) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9) + + # Load training data + trainset, testset = load_data() + test_abs = int(len(trainset) * 0.8) + train_subset, val_subset = random_split( + trainset, [test_abs, len(trainset) - test_abs] + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_subset, num_replicas=world_size, rank=global_process_rank + ) + + trainloader = torch.utils.data.DataLoader( + train_subset, + batch_size=args.replica_batch_size, + sampler=train_sampler, + num_workers=8, + ) + valloader = torch.utils.data.DataLoader( + val_subset, batch_size=args.replica_batch_size, shuffle=True, num_workers=8 + ) + + for epoch in range(args.epochs): + train_loss = train( + ddp_model, optimizer, criterion, trainloader, epoch, local_rank, experiment + ) + experiment.log_metric("train_loss", train_loss) + + val_loss, val_acc = evaluate(ddp_model, criterion, valloader, epoch, local_rank) + experiment.log_metric("val_loss", val_loss, epoch=epoch) + experiment.log_metric("val_acc", val_acc, epoch=epoch) + + test_acc = test_accuracy(model, testset, f"cuda:{local_rank}") + experiment.log_metric("test_acc", test_acc, epoch=args.epochs) + + cleanup() + + +class Net(nn.Module): + def __init__(self, l1=120, l2=84): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, l1) + self.fc2 = nn.Linear(l1, l2) + self.fc3 = nn.Linear(l2, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--run_id", type=str) + parser.add_argument("-b", "--backend", type=str, default="nccl") + parser.add_argument( + "-n", + "--nodes", + default=1, + type=int, + metavar="N", + help="total number of compute nodes", + ) + parser.add_argument( + "-g", "--gpus", default=1, type=int, help="number of gpus per node" + ) + parser.add_argument( + "-nr", + "--node_rank", + default=0, + type=int, + help="ranking within the nodes, starts at 0", + ) + parser.add_argument( + "--epochs", + default=2, + type=int, + 
metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--replica_batch_size", + default=32, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="""Address of master, will default to localhost if not provided. + Master must be able to accept network traffic on the address + port.""", + ) + parser.add_argument( + "--master_port", + type=str, + default="8892", + help="""Port that master is listening on, will default to 29500 if not + provided. Master must be able to accept network traffic on the host and + port.""", + ) + return parser.parse_args() + + +def main(): + args = get_args() + world_size = args.gpus * args.nodes + + # Make sure all nodes can talk to each other on the unprivileged port range + # (1024-65535) in addition to the master port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + mp.spawn( + run, + args=( + world_size, + args, + ), + nprocs=args.gpus, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/pytorch/comet-pytorch-ddp-mnist-example.py b/pytorch/comet-pytorch-ddp-mnist-example.py new file mode 100644 index 0000000..a3b0519 --- /dev/null +++ b/pytorch/comet-pytorch-ddp-mnist-example.py @@ -0,0 +1,179 @@ +# coding: utf-8 +import argparse +import os +from collections import OrderedDict + +from comet_ml import Experiment + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torchvision +from torch import nn, optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torchvision import transforms +from tqdm import tqdm + +torch.manual_seed(0) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) + + +INPUT_SIZE = 784 +HIDDEN_SIZES = [128, 64] +OUTPUT_SIZE = 10 +BATCH_SIZE = 256 + + +def setup(rank, world_size, backend): + # initialize the process group + dist.init_process_group(backend, rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def build_model(): + model = nn.Sequential( + OrderedDict( + [ + ("linear0", nn.Linear(INPUT_SIZE, HIDDEN_SIZES[0])), + ("activ0", nn.ReLU()), + ("linear1", nn.Linear(HIDDEN_SIZES[0], HIDDEN_SIZES[1])), + ("activ1", nn.ReLU()), + ("linear2", nn.Linear(HIDDEN_SIZES[1], OUTPUT_SIZE)), + ("output", nn.LogSoftmax(dim=1)), + ] + ) + ) + + return model + + +def train(model, optimizer, criterion, trainloader, epoch): + model.train() + for batch_idx, (images, labels) in tqdm(enumerate(trainloader)): + optimizer.zero_grad() + images = images.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + images = images.view(images.size(0), -1) + pred = model(images) + + loss = criterion(pred, labels) + loss.backward() + + optimizer.step() + + +def run(gpu_id, world_size, args): + """ + This is a single process that is linked to a single GPU + + :param gpu_id: The id of the GPU on the current node + :param world_size: Total number of processes across nodes + :param args: + :return: + """ + torch.cuda.set_device(gpu_id) + + # The overall rank of this GPU process across multiple nodes + global_process_rank = args.nr * args.gpus + gpu_id + if global_process_rank == 0: + experiment = Experiment(auto_output_logging="simple") + + else: + experiment = Experiment(disabled=True) + + print(f"Running DDP model on Global Process with Rank: {global_process_rank }.") + setup(global_process_rank, world_size, args.backend) + + 
model = build_model() + model.cuda(gpu_id) + ddp_model = DDP(model, device_ids=[gpu_id]) + + criterion = nn.CrossEntropyLoss().cuda(gpu_id) + optimizer = optim.Adam(model.parameters()) + + # Load training data + train_dataset = torchvision.datasets.MNIST( + root="./data", train=True, transform=transforms.ToTensor(), download=True + ) + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=world_size, rank=global_process_rank + ) + trainloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + num_workers=0, + pin_memory=True, + sampler=train_sampler, + ) + + for epoch in range(1, args.epochs + 1): + train(ddp_model, optimizer, criterion, trainloader, epoch) + + cleanup() + experiment.end() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-b", "--backend", type=str, default="nccl") + parser.add_argument( + "-n", + "--nodes", + default=1, + type=int, + metavar="N", + help="total number of compute nodes", + ) + parser.add_argument( + "-g", "--gpus", default=1, type=int, help="number of gpus per node" + ) + parser.add_argument( + "-nr", "--nr", default=0, type=int, help="ranking within the nodes, starts at 0" + ) + parser.add_argument( + "--epochs", + default=2, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="""Address of master, will default to localhost if not provided. + Master must be able to accept network traffic on the address + port.""", + ) + parser.add_argument( + "--master_port", + type=str, + default="8892", + help="""Port that master is listening on, will default to 29500 if not + provided. Master must be able to accept network traffic on the host and + port.""", + ) + return parser.parse_args() + + +def main(): + + args = get_args() + world_size = args.gpus * args.nodes + + # Make sure all nodes can talk to each other on the unprivileged port range + # (1024-65535) in addition to the master port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + mp.spawn(run, args=(world_size, args), nprocs=args.gpus, join=True) + + +if __name__ == "__main__": + main() diff --git a/pytorch/comet-pytorch-ddp-mnist-single-experiment.py b/pytorch/comet-pytorch-ddp-mnist-single-experiment.py new file mode 100644 index 0000000..c5d1daf --- /dev/null +++ b/pytorch/comet-pytorch-ddp-mnist-single-experiment.py @@ -0,0 +1,219 @@ +# coding: utf-8 +import argparse +import hashlib +import os +from collections import OrderedDict + +import comet_ml + +import torch +import torch.distributed as dist +import torchvision +from torch import nn, optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torchvision import transforms +from tqdm import tqdm + +torch.manual_seed(0) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) + +PROJECT_NAME = os.environ.get("COMET_PROJECT_NAME", "pytorch-ddp-mnist-single") +INPUT_SIZE = 784 +HIDDEN_SIZES = [128, 64] +OUTPUT_SIZE = 10 +BATCH_SIZE = 256 + + +def get_experiment(run_id): + experiment_id = hashlib.sha1(run_id.encode("utf-8")).hexdigest() + os.environ["COMET_EXPERIMENT_KEY"] = experiment_id + + api = comet_ml.API() # Assumes API key is set in config/env + api_experiment = api.get_experiment_by_id(experiment_id) + + if api_experiment is None: + return comet_ml.Experiment(project_name=PROJECT_NAME) + + else: + return 
comet_ml.ExistingExperiment(project_name=PROJECT_NAME) + + +def setup(): + # initialize the process group + dist.init_process_group(backend="nccl", init_method="env://") + + +def cleanup(): + dist.destroy_process_group() + + +def build_model(): + model = nn.Sequential( + OrderedDict( + [ + ("linear0", nn.Linear(INPUT_SIZE, HIDDEN_SIZES[0])), + ("activ0", nn.ReLU()), + ("linear1", nn.Linear(HIDDEN_SIZES[0], HIDDEN_SIZES[1])), + ("activ1", nn.ReLU()), + ("linear2", nn.Linear(HIDDEN_SIZES[1], OUTPUT_SIZE)), + ("output", nn.LogSoftmax(dim=1)), + ] + ) + ) + + return model + + +def train(model, optimizer, criterion, trainloader, epoch, process_rank, experiment): + model.train() + for batch_idx, (images, labels) in tqdm(enumerate(trainloader)): + optimizer.zero_grad() + images = images.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + images = images.view(images.size(0), -1) + pred = model(images) + + loss = criterion(pred, labels) + loss.backward() + + experiment.log_metric(f"{process_rank}_train_batch_loss", loss.item()) + + optimizer.step() + + +def run(local_rank, args): + """ + This is a single process that is linked to a single GPU + + :param local_rank: The id of the GPU on the current node + :param args: + :return: + """ + setup() + + # The overall rank of this GPU process across multiple nodes + world_size = dist.get_world_size() + global_process_rank = dist.get_rank() + print(f"Running DDP model on Global Process with Rank: {global_process_rank }.") + + # Set GPU to local device + torch.cuda.set_device(f"cuda:{local_rank}") + + experiment = args.experiment + + model = build_model() + model.cuda(local_rank) + ddp_model = DDP(model, device_ids=[local_rank]) + + criterion = nn.CrossEntropyLoss().cuda(local_rank) + optimizer = optim.Adam(ddp_model.parameters()) + + # Load training data + train_dataset = torchvision.datasets.MNIST( + root="./data", train=True, transform=transforms.ToTensor(), download=True + ) + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=world_size, rank=global_process_rank + ) + trainloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + num_workers=0, + pin_memory=True, + sampler=train_sampler, + ) + + for epoch in range(1, args.epochs + 1): + train( + ddp_model, + optimizer, + criterion, + trainloader, + epoch, + global_process_rank, + experiment, + ) + + cleanup() + experiment.end() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--run_id", type=str) + parser.add_argument("-b", "--backend", type=str, default="nccl") + parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="""Address of master, will default to localhost if not provided. + Master must be able to accept network traffic on the address + port.""", + ) + parser.add_argument( + "--master_port", + type=str, + default="8892", + help="""Port that master is listening on, will default to 29500 if not + provided. 
Master must be able to accept network traffic on the host and + port.""", + ) + parser.add_argument( + "-n", + "--nodes", + default=1, + type=int, + metavar="N", + help="total number of compute nodes", + ) + parser.add_argument( + "-g", "--gpus", default=1, type=int, help="number of gpus per node" + ) + parser.add_argument( + "-nr", + "--node_rank", + default=0, + type=int, + help="ranking within the nodes, starts at 0", + ) + parser.add_argument( + "-lr", + "--local_rank", + default=0, + type=int, + help="rank of the process within the current node, starts at 0", + ) + parser.add_argument( + "--epochs", + default=2, + type=int, + metavar="N", + help="number of total epochs to run", + ) + return parser.parse_args() + + +def main(): + args = get_args() + + world_size = args.gpus * args.nodes + global_process_rank = args.node_rank * args.gpus + args.local_rank + + experiment = get_experiment(args.run_id) + + # Make sure all nodes can talk to each other on the unprivileged port range + # (1024-65535) in addition to the master port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(global_process_rank) + + args.experiment = experiment + run(args.local_rank, args) + + +if __name__ == "__main__": + main() diff --git a/pytorch/comet-pytorch-horovod-mnist.py b/pytorch/comet-pytorch-horovod-mnist.py new file mode 100644 index 0000000..80ecffc --- /dev/null +++ b/pytorch/comet-pytorch-horovod-mnist.py @@ -0,0 +1,260 @@ +# coding: utf-8 +import argparse + +from comet_ml import Experiment + +import horovod.torch as hvd +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data.distributed +from torchvision import datasets, transforms + +PROJECT_NAME = "pytorch-horovod" + +# Training settings +parser = argparse.ArgumentParser(description="PyTorch MNIST Example") +parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", +) +parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", +) +parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", +) +parser.add_argument( + "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)" +) +parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)", +) +parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" +) +parser.add_argument( + "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)" +) +parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", +) +parser.add_argument( + "--fp16-allreduce", + action="store_true", + default=False, + help="use fp16 compression during allreduce", +) +parser.add_argument( + "--use-adasum", + action="store_true", + default=False, + help="use adasum algorithm to do reduction", +) + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, 
x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + + +def train(epoch): + model.train() + # Horovod: set epoch to sampler for shuffling. + train_sampler.set_epoch(epoch) + for batch_idx, (data, target) in enumerate(train_loader): + if args.cuda: + data, target = data.cuda(), target.cuda() + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + # Horovod: use train_sampler to determine the number of examples in + # this worker's partition. + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_sampler), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + + +def metric_average(val, name): + tensor = torch.tensor(val) + avg_tensor = hvd.allreduce(tensor, name=name) + return avg_tensor.item() + + +def test(): + model.eval() + test_loss = 0.0 + test_accuracy = 0.0 + for data, target in test_loader: + if args.cuda: + data, target = data.cuda(), target.cuda() + output = model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, size_average=False).item() + # get the index of the max log-probability + pred = output.data.max(1, keepdim=True)[1] + test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum() + + # Horovod: use test_sampler to determine the number of examples in + # this worker's partition. + test_loss /= len(test_sampler) + test_accuracy /= len(test_sampler) + + # Horovod: average metric values across workers. + test_loss = metric_average(test_loss, "avg_loss") + test_accuracy = metric_average(test_accuracy, "avg_accuracy") + + # Horovod: print output only on first rank. + if hvd.rank() == 0: + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n".format( + test_loss, 100.0 * test_accuracy + ) + ) + + return test_loss, test_accuracy + + +if __name__ == "__main__": + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + + # Horovod: initialize library. + hvd.init() + torch.manual_seed(args.seed) + + if hvd.rank() == 0: + experiment = Experiment(project_name=PROJECT_NAME) + experiment.log_parameters(args) + + if args.cuda: + # Horovod: pin GPU to local rank. + torch.cuda.set_device(hvd.local_rank()) + torch.cuda.manual_seed(args.seed) + + # Horovod: limit # of CPU threads to be used per worker. + torch.set_num_threads(1) + + kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} + # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to + # prevent issues with Infiniband implementations that are not fork-safe + if ( + kwargs.get("num_workers", 0) > 0 + and hasattr(mp, "_supports_context") + and mp._supports_context + and "forkserver" in mp.get_all_start_methods() + ): + kwargs["multiprocessing_context"] = "forkserver" + + train_dataset = datasets.MNIST( + "data-%d" % hvd.rank(), + train=True, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + # Horovod: use DistributedSampler to partition the training data. 
+ train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=hvd.size(), rank=hvd.rank() + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs + ) + + test_dataset = datasets.MNIST( + "data-%d" % hvd.rank(), + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + # Horovod: use DistributedSampler to partition the test data. + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_dataset, num_replicas=hvd.size(), rank=hvd.rank() + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=args.test_batch_size, sampler=test_sampler, **kwargs + ) + + model = Net() + + # By default, Adasum doesn't need scaling up learning rate. + lr_scaler = hvd.size() if not args.use_adasum else 1 + + if args.cuda: + # Move model to GPU. + model.cuda() + # If using GPU Adasum allreduce, scale learning rate by local_size. + if args.use_adasum and hvd.nccl_built(): + lr_scaler = hvd.local_size() + + # Horovod: scale learning rate by lr_scaler. + optimizer = optim.SGD( + model.parameters(), lr=args.lr * lr_scaler, momentum=args.momentum + ) + + # Horovod: broadcast parameters & optimizer state. + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + # Horovod: (optional) compression algorithm. + compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + + # Horovod: wrap optimizer with DistributedOptimizer. + optimizer = hvd.DistributedOptimizer( + optimizer, + named_parameters=model.named_parameters(), + compression=compression, + op=hvd.Adasum if args.use_adasum else hvd.Average, + ) + + for epoch in range(1, args.epochs + 1): + train(epoch) + test_loss, test_accuracy = test() + if hvd.rank() == 0: + with experiment.test(): + experiment.log_metrics({"loss": test_loss, "accuracy": test_accuracy}) diff --git a/pytorch/online-pytorch-lightning-apex-example.py b/pytorch/online-pytorch-lightning-apex-example.py new file mode 100644 index 0000000..9624715 --- /dev/null +++ b/pytorch/online-pytorch-lightning-apex-example.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2020 Nvidia +# Released under BSD-3 license https://github.com/NVIDIA/apex/blob/master/LICENSE + +from comet_ml import Experiment + +import torch +from apex import amp + + +def run(): + experiment = Experiment() + + torch.cuda.set_device("cuda:0") + + torch.backends.cudnn.benchmark = True + + N, D_in, D_out = 64, 1024, 16 + + # Each process receives its own batch of "fake input data" and "fake target data." + # The "training loop" in each process just uses this fake batch over and over. + # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more + # realistic example of distributed data sampling for both training and validation. 
    x = torch.randn(N, D_in, device="cuda")
    y = torch.randn(N, D_out, device="cuda")

    model = torch.nn.Linear(D_in, D_out).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    loss_fn = torch.nn.MSELoss()

    for t in range(5000):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

    print("final loss = ", loss)
    experiment.log_metric("final_loss", loss)


if __name__ == "__main__":
    run()
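A note on running the Apex mixed-precision example above (assumptions: a CUDA-capable GPU, NVIDIA Apex installed from https://github.com/NVIDIA/apex, and a Comet API key configured). It is a single-process script, so no distributed launcher is needed:

```bash
python pytorch/online-pytorch-lightning-apex-example.py
```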