Multiple GPU Training through PyTorch Distributed Data Parallel (PyTorch DDP)
iSiddharth20 authored Jan 2, 2024
2 parents 222fee9 + 290f3a7 commit cbb3657
Showing 3 changed files with 162 additions and 73 deletions.
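
Note on the pattern: the commit moves from single-process nn.DataParallel to one worker process per GPU with DistributedDataParallel (DDP). A minimal, self-contained sketch of that launch pattern, assuming a single machine with at least one CUDA GPU (demo_worker and the toy Linear model are illustrative names, not from this repository):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def demo_worker(rank, world_size):
    # One process per GPU; rank doubles as the device index.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group('nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    model = torch.nn.Linear(8, 8).to(rank)
    ddp_model = DDP(model, device_ids=[rank])  # gradient sync across ranks
    # ... training loop: forward, backward, optimizer.step() ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(demo_worker, args=(world_size,), nprocs=world_size, join=True)

Each rank computes on its own shard of the batch and all-reduces gradients during backward, which is why the diff below guards console output and checkpointing behind rank == 0.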
Code/autoencoder_model.py (4 changes: 2 additions & 2 deletions)
@@ -16,9 +16,9 @@ class Grey2RGBAutoEncoder(nn.Module):
     def __init__(self):
         super(Grey2RGBAutoEncoder, self).__init__()
         # Define the Encoder
-        self.encoder = self._make_layers([1, 64, 128, 256])
+        self.encoder = self._make_layers([1, 8, 16, 32])
         # Define the Decoder
-        self.decoder = self._make_layers([256, 128, 64, 3], decoder=True)
+        self.decoder = self._make_layers([32, 16, 8, 3], decoder=True)
 
     # Helper function to create the encoder or decoder layers.
     def _make_layers(self, channels, decoder=False):
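
The encoder and decoder widths drop from [1, 64, 128, 256] to [1, 8, 16, 32], presumably to offset the much larger 4000x6000 inputs configured in main.py below (an assumption; the commit message does not state a motive). Assuming 3x3 convolutions (the _make_layers body is not shown in this hunk), the weight budget changes roughly as follows:

def conv_params(channels, k=3):
    # weights (c_in * c_out * k * k) plus one bias per output channel
    return sum(c_in * c_out * k * k + c_out
               for c_in, c_out in zip(channels, channels[1:]))

print(conv_params([1, 64, 128, 256]))  # old encoder: 369,664
print(conv_params([1, 8, 16, 32]))     # new encoder: 5,888

At that input resolution the activation maps, not the weights, are likely the dominant memory cost, which the narrower channels also shrink proportionally.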
Code/main.py (196 changes: 138 additions & 58 deletions)
@@ -10,87 +10,138 @@
 from losses import LossMSE, LossMEP, SSIMLoss
 from training import Trainer
 
 
 # Import Necessary Libraries
 import os
 import traceback
 import torch
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import platform
 
 # Define Working Directories
 grayscale_dir = '../Dataset/Greyscale'
 rgb_dir = '../Dataset/RGB'
 
 # Define Universal Parameters
-image_height = 400
-image_width = 600
+image_height = 4000
+image_width = 6000
 batch_size = 2
 
 
-def main():
+def get_backend():
+    system_type = platform.system()
+    if system_type == "Linux":
+        return "nccl"
+    else:
+        return "gloo"
+
+def main_worker(rank, world_size):
+    # Set environment variables
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+    # Initialize the distributed environment.
+    torch.manual_seed(0)
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+    dist.init_process_group(backend=get_backend(), init_method="env://", world_size=world_size, rank=rank)
+    main(rank) # Call the existing main function.
+
+def main(rank):
     # Initialize Dataset Object (PyTorch Tensors)
     try:
         dataset = CustomDataset(grayscale_dir, rgb_dir, (image_height, image_width), batch_size)
-        print('Importing Dataset Complete.')
+        if rank == 0:
+            print('Importing Dataset Complete.')
     except Exception as e:
-        print(f"Importing Dataset In-Complete : \n{e}")
+        if rank == 0:
+            print(f"Importing Dataset In-Complete : \n{e}")
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
     # Import Loss Functions
     try:
         loss_mse = LossMSE() # Mean Squared Error Loss
         loss_mep = LossMEP(alpha=0.4) # Maximum Entropy Loss
         loss_ssim = SSIMLoss() # Structural Similarity Index Measure Loss
-        print('Importing Loss Functions Complete.')
+        if rank == 0:
+            print('Importing Loss Functions Complete.')
     except Exception as e:
-        print(f"Importing Loss Functions In-Complete : \n{e}")
-    print('-'*20) # Makes Output Readable
+        if rank == 0:
+            print(f"Importing Loss Functions In-Complete : \n{e}")
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
 
     # Initialize AutoEncoder Model and Import Dataloader (Training, Validation)
     data_autoencoder_train, data_autoencoder_val = dataset.get_autoencoder_batches(val_split=0.2)
-    print('AutoEncoder Model Data Imported.')
+    if rank == 0:
+        print('AutoEncoder Model Data Imported.')
     model_autoencoder = Grey2RGBAutoEncoder()
-    print('AutoEncoder Model Initialized.')
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print('AutoEncoder Model Initialized.')
+        print('-'*20) # Makes Output Readable
 
     # Initialize LSTM Model and Import Dataloader (Training, Validation)
    data_lstm_train, data_lstm_val = dataset.get_lstm_batches(val_split=0.25, sequence_length=2)
-    print('LSTM Model Data Imported.')
+    if rank == 0:
+        print('LSTM Model Data Imported.')
     model_lstm = ConvLSTM(input_dim=1, hidden_dims=[1,1,1], kernel_size=(3, 3), num_layers=3, alpha=0.5)
-    print('LSTM Model Initialized.')
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print('LSTM Model Initialized.')
+        print('-'*20) # Makes Output Readable
 
     '''
     Initialize Trainer Objects
     '''
     # Method 1 : Baseline : Mean Squared Error Loss for AutoEncoder and LSTM
     os.makedirs('../Models/Method1', exist_ok=True) # Creating Directory for Model Saving
     model_save_path_ae = '../Models/Method1/model_autoencoder_m1.pth'
-    trainer_autoencoder_baseline = Trainer(model_autoencoder, loss_mse, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
-    print('Method-1 AutoEncoder Trainer Initialized.')
+    trainer_autoencoder_baseline = Trainer(model=model_autoencoder,
+                                           loss_function=loss_mse,
+                                           optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
+                                           model_save_path=model_save_path_ae,
+                                           rank=rank)
+    if rank == 0:
+        print('Method-1 AutoEncoder Trainer Initialized.')
     model_save_path_lstm = '../Models/Method1/model_lstm_m1.pth'
-    trainer_lstm_baseline = Trainer(model_lstm, loss_mse, optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), model_save_path=model_save_path_lstm)
-    print('Method-1 LSTM Trainer Initialized.')
-    print('-'*10) # Makes Output Readable
+    trainer_lstm_baseline = Trainer(model=model_lstm,
+                                    loss_function=loss_mse,
+                                    optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
+                                    model_save_path=model_save_path_lstm,
+                                    rank=rank)
+    if rank == 0:
+        print('Method-1 LSTM Trainer Initialized.')
+        print('-'*10) # Makes Output Readable
 
     # Method 2 : Composite Loss (MSE + MaxEnt) for AutoEncoder and Mean Squared Error Loss for LSTM
     os.makedirs('../Models/Method2', exist_ok=True) # Creating Directory for Model Saving
     model_save_path_ae = '../Models/Method2/model_autoencoder_m2.pth'
-    trainer_autoencoder_m2 = Trainer(model=model_autoencoder, loss_function=loss_mep, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
-    print('Method-2 AutoEncoder Trainer Initialized.')
-    print('Method-2 LSTM == Method-1 LSTM')
-    print('-'*10) # Makes Output Readable
+    trainer_autoencoder_m2 = Trainer(model=model_autoencoder,
+                                     loss_function=loss_mep,
+                                     optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
+                                     model_save_path=model_save_path_ae,
+                                     rank=rank)
+    if rank == 0:
+        print('Method-2 AutoEncoder Trainer Initialized.')
+        print('Method-2 LSTM == Method-1 LSTM')
+        print('-'*10) # Makes Output Readable
 
     # Method 3 : Mean Squared Error Loss for AutoEncoder and SSIM Loss for LSTM
     os.makedirs('../Models/Method3', exist_ok=True) # Creating Directory for Model Saving
-    print('Method-3 AutoEncoder == Method-1 AutoEncoder')
+    if rank == 0:
+        print('Method-3 AutoEncoder == Method-1 AutoEncoder')
     model_save_path_lstm = '../Models/Method3/model_lstm_m3.pth'
-    trainer_lstm_m3 = Trainer(model_lstm, loss_ssim, optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), model_save_path=model_save_path_lstm)
-    print('Method-3 LSTM Trainer Initialized.')
-    print('-'*10) # Makes Output Readable
+    trainer_lstm_m3 = Trainer(model=model_lstm,
+                              loss_function=loss_ssim,
+                              optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
+                              model_save_path=model_save_path_lstm,
+                              rank=rank)
+    if rank == 0:
+        print('Method-3 LSTM Trainer Initialized.')
+        print('-'*10) # Makes Output Readable
 
     # Method 4 : Proposed Method : Composite Loss (MSE + MaxEnt) for AutoEncoder and SSIM Loss for LSTM
-    print('Method-4 AutoEncoder == Method-2 AutoEncoder')
-    print('Method-4 LSTM == Method-3 LSTM')
-
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print('Method-4 AutoEncoder == Method-2 AutoEncoder')
+        print('Method-4 LSTM == Method-3 LSTM')
+        print('-'*20) # Makes Output Readable
 
 
     '''
@@ -99,55 +150,84 @@ def main():
     # Method-1
     try:
         epochs = 1
-        print('Method-1 AutoEncoder Training Start')
+        if rank == 0:
+            print('Method-1 AutoEncoder Training Start')
         model_autoencoder_m1 = trainer_autoencoder_baseline.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
-        print('Method-1 AutoEncoder Training Complete.')
+        if rank == 0:
+            print('Method-1 AutoEncoder Training Complete.')
     except Exception as e:
-        print(f"Method-1 AutoEncoder Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-1 AutoEncoder Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*10) # Makes Output Readable
+    finally:
+        if rank == 0:
+            trainer_autoencoder_baseline.cleanup_ddp()
+    if rank == 0:
+        print('-'*10) # Makes Output Readable
     try:
         epochs = 1
-        print('Method-1 LSTM Training Start')
+        if rank == 0:
+            print('Method-1 LSTM Training Start')
         model_lstm_m1 = trainer_lstm_baseline.train_lstm(epochs, data_lstm_train, data_lstm_val)
-        print('Method-1 LSTM Training Complete.')
+        if rank == 0:
+            print('Method-1 LSTM Training Complete.')
     except Exception as e:
-        print(f"Method-1 LSTM Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-1 LSTM Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*20) # Makes Output Readable
+    finally:
+        if rank == 0:
+            trainer_lstm_baseline.cleanup_ddp()
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
 
     # Method-2
     try:
         epochs = 1
-        print('Method-2 AutoEncoder Training Start')
+        if rank == 0:
+            print('Method-2 AutoEncoder Training Start')
         model_autoencoder_m2 = trainer_autoencoder_m2.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
-        print('Method-2 AutoEncoder Training Complete.')
+        if rank == 0:
+            print('Method-2 AutoEncoder Training Complete.')
     except Exception as e:
-        print(f"Method-2 AutoEncoder Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-2 AutoEncoder Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*10) # Makes Output Readable
-    print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.")
-    print('-'*20) # Makes Output Readable
+    finally:
+        trainer_autoencoder_m2.cleanup_ddp()
+    if rank == 0:
+        print('-'*10) # Makes Output Readable
+        print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.")
+        print('-'*20) # Makes Output Readable
 
     # Method-3
-    print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.")
-    print('-'*10) # Makes Output Readable
+    if rank == 0:
+        print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.")
+        print('-'*10) # Makes Output Readable
     try:
         epochs = 1
-        print('Method-3 LSTM Training Start.')
+        if rank == 0:
+            print('Method-3 LSTM Training Start.')
         model_lstm_m3 = trainer_lstm_m3.train_lstm(epochs, data_lstm_train, data_lstm_val)
-        print('Method-3 LSTM Training Complete.')
+        if rank == 0:
+            print('Method-3 LSTM Training Complete.')
     except Exception as e:
-        print(f"Method-3 LSTM Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-3 LSTM Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*20) # Makes Output Readable
+    finally:
+        trainer_lstm_m3.cleanup_ddp()
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
 
     # Method-4
-    print("Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.")
-    print('-'*10) # Makes Output Readable
-    print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.")
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print("Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.")
+        print('-'*10) # Makes Output Readable
+        print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.")
+        print('-'*20) # Makes Output Readable
 
 
 if __name__ == '__main__':
-    main()
+    world_size = torch.cuda.device_count() # Number of available GPUs
+    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
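
Two caveats in this launch block, noted as observations rather than changes in the commit: mp.spawn with nprocs=0 starts no workers, so a CUDA-less machine silently does nothing; and get_backend() selects NCCL on any Linux host even though NCCL requires GPUs. A hedged sketch of a fallback guard:

if __name__ == '__main__':
    world_size = torch.cuda.device_count() # Number of available GPUs
    if world_size > 0:
        mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
    else:
        # Hypothetical CPU fallback: a single process over the gloo backend
        # (get_backend() would also need a torch.cuda.is_available() check).
        main_worker(rank=0, world_size=1)

Note also that Method-1 guards cleanup_ddp() behind rank == 0 while Methods 2 and 3 call it on every rank; dist.destroy_process_group() is meant to run on all ranks, so the rank-0 guard is likely unintended.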
Code/training.py (35 changes: 22 additions & 13 deletions)
@@ -10,30 +10,37 @@
 
 # Import Necessary Libraries
 import torch
-import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 
 # Define Training Class
 class Trainer():
-    def __init__(self, model, loss_function, optimizer=None, model_save_path=None):
-        # Use All Available CUDA GPUs for Training (if Available)
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if torch.cuda.device_count() > 1:
-            model = nn.DataParallel(model)
+    def __init__(self, model, loss_function, optimizer=None, model_save_path=None, rank=None):
+        self.rank = rank # Rank of the current process
+        self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
         self.model = model.to(self.device)
         # Define the loss function
         self.loss_function = loss_function
         # Define the optimizer
         self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.model.parameters(), lr=0.001)
+        # Wrap model with DDP
+        if torch.cuda.device_count() > 1 and rank is not None:
+            self.model = DDP(self.model, device_ids=[rank], find_unused_parameters=True)
         # Define the path to save the model
-        self.model_save_path = model_save_path
+        self.model_save_path = model_save_path if rank == 0 else None # Only save on master process
 
+    def cleanup_ddp(self):
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
     def save_model(self):
-        # Save the model
-        torch.save(self.model.state_dict(), self.model_save_path)
+        if self.rank == 0:
+            # Save the model
+            torch.save(self.model.state_dict(), self.model_save_path)
 
     def train_autoencoder(self, epochs, train_loader, val_loader):
         # Print Names of All Available GPUs (if any) to Train the Model
-        if torch.cuda.device_count() > 0:
+        if torch.cuda.device_count() > 0 and self.rank == 0:
             gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
             print("\tGPUs being used for Training : ",gpu_names)
         best_val_loss = float('inf')
@@ -54,7 +61,8 @@ def train_autoencoder(self, epochs, train_loader, val_loader):
             val_loss = sum(self.loss_function(self.model(input.to(self.device)), target.to(self.device)).item() for input, target in val_loader) # Compute Total Validation Loss
             val_loss /= len(val_loader) # Compute Average Validation Loss
             # Print epochs and losses
-            print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
+            if self.rank == 0:
+                print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
             # If the current validation loss is lower than the best validation loss, save the model
             if val_loss < best_val_loss:
                 best_val_loss = val_loss # Update the best validation loss
@@ -64,7 +72,7 @@ def train_autoencoder(self, epochs, train_loader, val_loader):
 
     def train_lstm(self, epochs, train_loader, val_loader):
         # Print Names of All Available GPUs (if any) to Train the Model
-        if torch.cuda.device_count() > 0:
+        if torch.cuda.device_count() > 0 and self.rank == 0:
             gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
             print("\tGPUs being used for Training : ",gpu_names)
         best_val_loss = float('inf')
@@ -88,7 +96,8 @@ def train_lstm(self, epochs, train_loader, val_loader):
                 val_loss += self.loss_function(output_sequence, target_sequence).item() # Accumulate loss
             val_loss /= len(val_loader) # Average validation loss
             # Print epochs and losses
-            print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
+            if self.rank == 0:
+                print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
             # Model saving based on validation loss
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
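
One further observation on save_model as written (not addressed by this commit): a DDP-wrapped model's state_dict() keys carry a 'module.' prefix, so the checkpoint will not load directly into an unwrapped Grey2RGBAutoEncoder or ConvLSTM. A common pattern, sketched here as a possible follow-up:

def save_model(self):
    if self.rank == 0 and self.model_save_path is not None:
        # Unwrap DDP (when wrapped) so checkpoint keys have no 'module.' prefix
        net = self.model.module if hasattr(self.model, 'module') else self.model
        torch.save(net.state_dict(), self.model_save_path)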
